Files
genetic-algorithms/ga/generate_plots.py
2026-04-06 12:31:28 +03:00

141 lines
4.4 KiB
Python

#!/usr/bin/env python3
"""Generate plots from GA history for the course work report."""
import json
from collections import Counter
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
# Global figure style applied to every plot produced by this script.
# The plot labels are in Russian, so the font has to cover Cyrillic --
# presumably why the family is pinned explicitly (confirm with the author).
matplotlib.rcParams["font.family"] = "DejaVu Sans"
# Light background grid on all axes.
matplotlib.rcParams["axes.grid"] = True
matplotlib.rcParams["grid.alpha"] = 0.3
# Human-readable labels for the model identifiers stored in the GA configs.
# Lookups use ``MODEL_DISPLAY.get(name, name)``, so identifiers missing here
# are shown verbatim. An embedded "\n" splits a long label across two lines
# when it is used as an axis tick label.
MODEL_DISPLAY: dict[str, str] = {
    "whisper-large-v3": "Whisper large-v3",
    "whisper-medium": "Whisper medium",
    "faster-whisper-large-v3": "Faster-Whisper\nlarge-v3",
    "gigaam-ctc": "GigaAM-CTC",
    "gigaam-rnnt": "GigaAM-RNN-T",
    "pyannote-3.1": "pyannote 3.1",
    "pyannote-community-1": "pyannote\nCommunity-1",
    "sortformer": "Sortformer",
}
def main():
    """Load the GA run history and render all report figures.

    Reads ``history.json`` from the directory that contains this script and
    writes the PNG figures into ``../report/img`` relative to that
    directory, creating the output directory if it does not exist.

    Raises:
        FileNotFoundError: If ``history.json`` is missing.
        KeyError: If the JSON lacks the ``"history"`` or ``"all_configs"``
            top-level keys.
    """
    script_dir = Path(__file__).parent

    # JSON is UTF-8 by specification; decode it explicitly instead of
    # relying on the platform's locale encoding.
    history_path = script_dir / "history.json"
    data = json.loads(history_path.read_text(encoding="utf-8"))
    history = data["history"]
    all_configs = data["all_configs"]

    img_dir = script_dir.parent / "report" / "img"
    img_dir.mkdir(parents=True, exist_ok=True)

    plot_convergence(history, img_dir)
    plot_wer_der_scatter(all_configs, img_dir)
    plot_model_frequency(all_configs, img_dir)
def plot_convergence(history: list[dict], img_dir: Path):
    """Plot best and mean objective value per generation.

    The stored fitness values are negated before plotting, so the curves
    show the quantity described by the y-axis label (an error where lower
    is better). The figure is saved as ``convergence.png`` in *img_dir*.

    Args:
        history: Per-generation records; each dict must provide the keys
            ``"generation"``, ``"best_fitness"`` and ``"mean_fitness"``.
        img_dir: Directory that receives the PNG file.
    """
    generations = [record["generation"] for record in history]
    # Flip the sign: fitness is stored negated relative to the plotted error.
    best_error = [-record["best_fitness"] for record in history]
    mean_error = [-record["mean_fitness"] for record in history]

    out_path = img_dir / "convergence.png"

    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.plot(
        generations,
        best_error,
        "b-o",
        markersize=4,
        linewidth=1.5,
        label="Лучшая особь",
    )
    ax.plot(
        generations,
        mean_error,
        "r--s",
        markersize=3,
        linewidth=1.2,
        label="Среднее по популяции",
    )
    ax.set_xlabel("Поколение", fontsize=12)
    ax.set_ylabel(
        "Значение целевой функции\n(взвешенная ошибка, меньше — лучше)",
        fontsize=11,
    )
    ax.legend(fontsize=11)

    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print(f"Saved {out_path}")
def plot_wer_der_scatter(all_configs: list[dict], img_dir: Path):
    """Scatter every evaluated configuration in the WER/DER plane.

    Points are coloured by fitness, and the configuration with the highest
    fitness is highlighted with a blue star. The figure is saved as
    ``wer_der_scatter.png`` in *img_dir*.

    When *all_configs* is empty there is nothing to draw and no best
    configuration to pick, so the function prints a notice and returns
    without creating a figure or writing a file.

    Args:
        all_configs: Evaluated configurations; each dict must provide the
            keys ``"wer"``, ``"der"`` and ``"fitness"``. The axis labels and
            legend present WER/DER with a percent sign -- assumed to match
            the stored units; confirm against the producer of the data.
        img_dir: Directory that receives the PNG file.
    """
    out_path = img_dir / "wer_der_scatter.png"

    # Guard first: max() below raises on an empty sequence, and doing so
    # after plt.subplots() would also leave an unclosed figure behind.
    if not all_configs:
        print(f"No configurations to plot; skipped {out_path}")
        return

    wers = [c["wer"] for c in all_configs]
    ders = [c["der"] for c in all_configs]
    fits = [c["fitness"] for c in all_configs]

    fig, ax = plt.subplots(figsize=(7, 5.5))
    sc = ax.scatter(
        wers,
        ders,
        c=fits,
        cmap="RdYlGn",
        alpha=0.7,
        edgecolors="gray",
        linewidth=0.5,
        s=40,
    )

    # Highlight the configuration with the highest fitness on top of the cloud.
    best = max(all_configs, key=lambda c: c["fitness"])
    ax.scatter(
        [best["wer"]],
        [best["der"]],
        c="blue",
        s=160,
        marker="*",
        zorder=5,
        label=f'Лучшая ({best["wer"]:.1f}%, {best["der"]:.1f}%)',
    )

    ax.set_xlabel("WER, %", fontsize=12)
    ax.set_ylabel("DER, %", fontsize=12)
    ax.legend(fontsize=11)
    cbar = fig.colorbar(sc, ax=ax)
    cbar.set_label("Фитнес", fontsize=11)

    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print(f"Saved {out_path}")
def _draw_count_bars(ax, counts, *, color: str, title: str, xlabel: str) -> None:
    """Draw one horizontal bar panel of model counts, most frequent on top."""
    # most_common() sorts by count, descending; equal counts keep the order
    # in which the models were first encountered.
    ranked = counts.most_common()
    labels = [MODEL_DISPLAY.get(name, name) for name, _ in ranked]
    values = [count for _, count in ranked]

    ax.barh(labels, values, color=color)
    ax.set_xlabel(xlabel, fontsize=11)
    ax.set_title(title, fontsize=12)
    # barh places the first bar at the bottom; flip so the leader is on top.
    ax.invert_yaxis()


def plot_model_frequency(all_configs: list[dict], img_dir: Path):
    """Plot how often each model appears among the fittest configurations.

    Takes the (at most) 20 configurations with the highest fitness and
    draws two side-by-side horizontal bar charts: occurrence counts of
    transcription models and of diarization models. The figure is saved as
    ``model_frequency.png`` in *img_dir*.

    Args:
        all_configs: Evaluated configurations; each dict must provide
            ``"fitness"`` and a nested ``"config"`` dict with the keys
            ``"transcription_model"`` and ``"diarization_model"``.
        img_dir: Directory that receives the PNG file.
    """
    top_n = min(20, len(all_configs))
    top = sorted(all_configs, key=lambda c: c["fitness"], reverse=True)[:top_n]

    t_counts = Counter(c["config"]["transcription_model"] for c in top)
    d_counts = Counter(c["config"]["diarization_model"] for c in top)

    xlabel = f"Количество в топ-{top_n}"
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4.5))
    _draw_count_bars(
        ax1,
        t_counts,
        color="steelblue",
        title="Модели транскрибации",
        xlabel=xlabel,
    )
    _draw_count_bars(
        ax2,
        d_counts,
        color="coral",
        title="Модели диаризации",
        xlabel=xlabel,
    )

    out_path = img_dir / "model_frequency.png"
    fig.tight_layout()
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print(f"Saved {out_path}")
# Generate the figures only when run as a script, not when imported.
if __name__ == "__main__":
    main()