119 lines
4.3 KiB
Python
119 lines
4.3 KiB
Python
import random
|
||
from typing import Sequence
|
||
|
||
from .primitive import Primitive
|
||
from .types import Context, Value
|
||
|
||
|
||
class Node:
|
||
def __init__(self, value: Primitive):
|
||
self.value = value
|
||
self.parent: Node | None = None
|
||
self.children: list[Node] = []
|
||
|
||
def add_child(self, child: Node) -> None:
|
||
self.children.append(child)
|
||
child.parent = self
|
||
|
||
def remove_child(self, child: Node) -> None:
|
||
self.children.remove(child)
|
||
child.parent = None
|
||
|
||
def replace_child(self, old_child: Node, new_child: Node) -> None:
|
||
self.children[self.children.index(old_child)] = new_child
|
||
old_child.parent = None
|
||
new_child.parent = self
|
||
|
||
def remove_children(self) -> None:
|
||
for child in self.children:
|
||
child.parent = None
|
||
self.children = []
|
||
|
||
def copy_subtree(self) -> Node:
|
||
node = Node(self.value)
|
||
for child in self.children:
|
||
node.add_child(child.copy_subtree())
|
||
return node
|
||
|
||
def list_nodes(self) -> list[Node]:
|
||
"""Список всех узлов поддерева, начиная с текущего (aka depth-first-search)."""
|
||
nodes: list[Node] = [self]
|
||
for child in self.children:
|
||
nodes.extend(child.list_nodes())
|
||
return nodes
|
||
|
||
def prune(self, terminals: Sequence[Primitive], max_depth: int) -> None:
|
||
"""Усечение поддерева до заданной глубины.
|
||
|
||
Заменяет операции на глубине max_depth на случайные терминалы.
|
||
"""
|
||
|
||
def prune_recursive(node: Node, current_depth: int) -> None:
|
||
if node.value.arity == 0: # Терминалы остаются без изменений
|
||
return
|
||
|
||
if current_depth >= max_depth:
|
||
node.remove_children()
|
||
node.value = random.choice(terminals)
|
||
return
|
||
|
||
for child in node.children:
|
||
prune_recursive(child, current_depth + 1)
|
||
|
||
prune_recursive(self, 1)
|
||
|
||
def get_depth(self) -> int:
|
||
"""Вычисляет глубину поддерева, начиная с текущего узла."""
|
||
return (
|
||
max(child.get_depth() for child in self.children) + 1
|
||
if self.children
|
||
else 1
|
||
)
|
||
|
||
def get_size(self) -> int:
|
||
"""Вычисляет размер поддерева, начиная с текущего узла."""
|
||
return sum(child.get_size() for child in self.children) + 1
|
||
|
||
def get_level(self) -> int:
|
||
"""Вычисляет уровень узла в дереве (расстояние от корня). Корень имеет уровень 1."""
|
||
return self.parent.get_level() + 1 if self.parent else 1
|
||
|
||
def eval(self, context: Context) -> Value:
|
||
return self.value.eval(
|
||
[child.eval(context) for child in self.children], context
|
||
)
|
||
|
||
def __str__(self) -> str:
|
||
"""Рекурсивный перевод древовидного вида формулы в строку в инфиксной форме."""
|
||
if self.value.arity == 0:
|
||
return self.value.name
|
||
|
||
if self.value.arity == 2:
|
||
return f"({self.children[0]} {self.value.name} {self.children[1]})"
|
||
|
||
return f"{self.value.name}({', '.join(str(child) for child in self.children)})"
|
||
|
||
def to_str_tree(self, prefix="", is_last: bool = True) -> str:
|
||
"""Строковое представление древовидной структуры."""
|
||
lines = prefix + ("└── " if is_last else "├── ") + self.value.name + "\n"
|
||
child_prefix = prefix + (" " if is_last else "│ ")
|
||
for i, child in enumerate(self.children):
|
||
is_child_last = i == len(self.children) - 1
|
||
lines += child.to_str_tree(child_prefix, is_child_last)
|
||
|
||
return lines
|
||
|
||
|
||
def swap_subtrees(a: Node, b: Node) -> None:
|
||
if a.parent is None or b.parent is None:
|
||
raise ValueError("Нельзя обменять корни деревьев")
|
||
|
||
# Сохраняем ссылки на родителей
|
||
a_parent = a.parent
|
||
b_parent = b.parent
|
||
|
||
i = a_parent.children.index(a)
|
||
j = b_parent.children.index(b)
|
||
a_parent.children[i], b_parent.children[j] = b, a
|
||
a.parent, b.parent = b_parent, a_parent
|