from dataclasses import dataclass from typing import Callable, Sequence from .types import Context, Value type OperationFn = Callable[[Sequence[Value]], Value] @dataclass(frozen=True) class Primitive: name: str arity: int operation_fn: OperationFn | None def eval(self, args: Sequence[Value], context: Context) -> Value: if self.operation_fn is None: return context[self] return self.operation_fn(args) def __post_init__(self) -> None: if self.arity != 0 and self.operation_fn is None: raise ValueError("Operation is required for primitive with non-zero arity") def Var(name: str) -> Primitive: return Primitive(name=name, arity=0, operation_fn=None) def Const(name: str, val: Value) -> Primitive: return Primitive(name=name, arity=0, operation_fn=lambda _args: val) def Operation(name: str, arity: int, operation_fn: OperationFn) -> Primitive: return Primitive(name=name, arity=arity, operation_fn=operation_fn)