59 lines
1.5 KiB
Python
59 lines
1.5 KiB
Python
import numpy as np
|
||
from numpy.typing import NDArray
|
||
|
||
from .primitive import Operation
|
||
|
||
# Type alias for the float64 NumPy arrays handled by the _safe_* helpers
# below (used in their signatures). Uses the PEP 695 `type` statement,
# which requires Python 3.12+.
type Value = NDArray[np.float64]
|
||
|
||
# Unary operations. Each callable receives an indexable collection of
# operand values and uses only its first element. The second argument to
# Operation looks like the arity (1 here, 2 for the binary ops) — TODO
# confirm against .primitive.Operation.
NEG = Operation("-", 1, lambda operands: -operands[0])
SIN = Operation("sin", 1, lambda operands: np.sin(operands[0]))
COS = Operation("cos", 1, lambda operands: np.cos(operands[0]))
|
||
|
||
|
||
def _safe_exp(v: Value) -> Value:
|
||
v_clipped = np.clip(v, -10.0, 10.0)
|
||
out = np.exp(v_clipped)
|
||
out[np.isnan(out) | np.isinf(out)] = 0.0
|
||
return out
|
||
|
||
|
||
# Exponential of the single operand; input clamping and non-finite
# handling live in _safe_exp.
EXP = Operation("exp", 1, lambda operands: _safe_exp(operands[0]))
|
||
|
||
|
||
# Binary operations. Each callable receives an indexable collection of
# operand values and combines its first two elements (left, right).
ADD = Operation("+", 2, lambda operands: operands[0] + operands[1])
SUB = Operation("-", 2, lambda operands: operands[0] - operands[1])
MUL = Operation("*", 2, lambda operands: operands[0] * operands[1])
|
||
|
||
|
||
def _safe_div(a: Value, b: Value) -> Value:
|
||
eps = 1e-12
|
||
denom = np.where(np.abs(b) >= eps, b, eps)
|
||
out = np.divide(a, denom)
|
||
out = np.where(np.isnan(out) | np.isinf(out), 0.0, out)
|
||
return out
|
||
|
||
|
||
# Division of the first operand by the second; tiny/zero denominators and
# non-finite results are handled by _safe_div.
DIV = Operation("/", 2, lambda operands: _safe_div(operands[0], operands[1]))
|
||
|
||
|
||
def _safe_pow(a: Value, b: Value) -> Value:
|
||
a_clip = np.clip(a, -1e3, 1e3)
|
||
b_clip = np.clip(b, -3.0, 3.0)
|
||
|
||
# 0 в отрицательной степени → 0
|
||
mask_zero_neg = (a_clip == 0.0) & (b_clip < 0.0)
|
||
with np.errstate(over="ignore", invalid="ignore", divide="ignore", under="ignore"):
|
||
out = np.power(a_clip, b_clip)
|
||
|
||
out[mask_zero_neg] = 0.0
|
||
out[np.isnan(out) | np.isinf(out)] = 0.0
|
||
return out
|
||
|
||
|
||
# First operand raised to the second; base/exponent clamping and
# non-finite handling are done in _safe_pow.
POW = Operation("^", 2, lambda operands: _safe_pow(operands[0], operands[1]))
|
||
|
||
# Every operation defined in this module.
ALL = (
    NEG, SIN, COS, EXP,  # unary (arity 1)
    ADD, SUB, MUL, DIV, POW,  # binary (arity 2)
)
|