га
This commit is contained in:
5
ga/.gitignore
vendored
Normal file
5
ga/.gitignore
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
*
|
||||
|
||||
!**/
|
||||
!.gitignore
|
||||
!*.py
|
||||
361
ga/ga.py
Normal file
361
ga/ga.py
Normal file
@@ -0,0 +1,361 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Genetic algorithm for optimizing meeting transcription+diarization pipeline.
|
||||
|
||||
Searches over a mixed discrete configuration space of transcription and
|
||||
diarization models and their parameters. Uses module-level caching and batch
|
||||
scheduling grouped by model to minimize redundant computations.
|
||||
"""
|
||||
|
||||
import json
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import run_pipeline
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration space
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# --- Transcription genes ----------------------------------------------------
# Speech-to-text models the search may pick from.
TRANSCRIPTION_MODELS = [
    "whisper-large-v3",
    "whisper-medium",
    "faster-whisper-large-v3",
    "gigaam-ctc",
    "gigaam-rnnt",
]
# Decoder beam widths; ignored (treated as 1) for models outside WHISPER_MODELS.
BEAM_SIZES = [1, 3, 5, 7, 10]
# Voice-activity-detection thresholds; part of both the transcription and
# the diarization cache keys.
VAD_THRESHOLDS = [0.3, 0.4, 0.5, 0.6, 0.7]

# --- Diarization genes ------------------------------------------------------
DIARIZATION_MODELS = ["pyannote-3.1", "pyannote-community-1", "sortformer"]
MIN_SPEECH_DURATIONS = [0.25, 0.5, 0.75, 1.0, 1.5]
CLUSTERING_THRESHOLDS = [0.3, 0.45, 0.6, 0.75, 0.9]

# Transcription models for which ``beam_size`` is a meaningful parameter.
WHISPER_MODELS = {
    "whisper-large-v3",
    "whisper-medium",
    "faster-whisper-large-v3",
}

# Ordered gene table: position i of a chromosome's ``genes`` list is an index
# into the value list of entry i here.
GENES: list[tuple[str, list]] = [
    ("transcription_model", TRANSCRIPTION_MODELS),
    ("beam_size", BEAM_SIZES),
    ("vad_threshold", VAD_THRESHOLDS),
    ("diarization_model", DIARIZATION_MODELS),
    ("min_speech_duration", MIN_SPEECH_DURATIONS),
    ("clustering_threshold", CLUSTERING_THRESHOLDS),
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GA hyper-parameters
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Search schedule.
POPULATION_SIZE = 15  # chromosomes kept in every generation
NUM_GENERATIONS = 25  # generations evaluated by run_ga

# Variation and selection operators.
TOURNAMENT_SIZE = 3  # candidates sampled per tournament selection
MUTATION_PROB = 0.15  # per-gene probability of mutation
ELITE_COUNT = 2  # best chromosomes carried over unchanged

# Weights of the three cost terms combined by compute_fitness.
ALPHA = 0.4  # WER weight
BETA = 0.4  # DER weight
GAMMA = 0.2  # time weight
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Chromosome
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chromosome:
|
||||
genes: list[int]
|
||||
fitness: float | None = None
|
||||
wer: float | None = None
|
||||
der: float | None = None
|
||||
time_min: float | None = None
|
||||
|
||||
def to_config(self) -> dict:
|
||||
return {
|
||||
name: values[self.genes[i]] for i, (name, values) in enumerate(GENES)
|
||||
}
|
||||
|
||||
def transcription_key(self) -> tuple:
|
||||
cfg = self.to_config()
|
||||
model = cfg["transcription_model"]
|
||||
beam = cfg["beam_size"] if model in WHISPER_MODELS else 1
|
||||
return (model, beam, cfg["vad_threshold"])
|
||||
|
||||
def diarization_key(self) -> tuple:
|
||||
cfg = self.to_config()
|
||||
return (
|
||||
cfg["diarization_model"],
|
||||
cfg["min_speech_duration"],
|
||||
cfg["clustering_threshold"],
|
||||
cfg["vad_threshold"],
|
||||
)
|
||||
|
||||
def copy(self) -> "Chromosome":
|
||||
return Chromosome(genes=self.genes.copy())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class Cache:
|
||||
def __init__(self, cache_dir: Path):
|
||||
self.cache_dir = cache_dir
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.transcription: dict[str, dict] = {}
|
||||
self.diarization: dict[str, dict] = {}
|
||||
self._load()
|
||||
|
||||
def _load(self):
|
||||
for name in ("transcription", "diarization"):
|
||||
path = self.cache_dir / f"{name}.json"
|
||||
if path.exists():
|
||||
setattr(self, name, json.loads(path.read_text()))
|
||||
|
||||
def save(self):
|
||||
for name in ("transcription", "diarization"):
|
||||
path = self.cache_dir / f"{name}.json"
|
||||
path.write_text(json.dumps(getattr(self, name), indent=2))
|
||||
|
||||
def get_transcription(self, key: tuple) -> dict | None:
|
||||
return self.transcription.get(str(key))
|
||||
|
||||
def set_transcription(self, key: tuple, result: dict):
|
||||
self.transcription[str(key)] = result
|
||||
|
||||
def get_diarization(self, key: tuple) -> dict | None:
|
||||
return self.diarization.get(str(key))
|
||||
|
||||
def set_diarization(self, key: tuple, result: dict):
|
||||
self.diarization[str(key)] = result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GA operators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def random_chromosome() -> Chromosome:
    """Build a chromosome with a uniformly random index for every gene."""
    picks = []
    for _, alleles in GENES:
        picks.append(random.randint(0, len(alleles) - 1))
    return Chromosome(genes=picks)
|
||||
|
||||
|
||||
def tournament_select(population: list[Chromosome]) -> Chromosome:
    """Pick a parent by tournament selection.

    Samples up to ``TOURNAMENT_SIZE`` distinct chromosomes from ``population``
    and returns the one with the highest ``fitness``.

    Args:
        population: evaluated chromosomes (``fitness`` must be set).

    Returns:
        The fittest chromosome among the sampled candidates.

    Raises:
        ValueError: if ``population`` is empty.
    """
    if not population:
        raise ValueError("cannot run a tournament on an empty population")
    # Clamp the tournament so a population smaller than TOURNAMENT_SIZE does
    # not make random.sample fail; for larger populations the sample size
    # (and therefore the random stream) is unchanged.
    sample_size = min(TOURNAMENT_SIZE, len(population))
    candidates = random.sample(population, sample_size)
    return max(candidates, key=lambda c: c.fitness)
|
||||
|
||||
|
||||
def crossover(p1: Chromosome, p2: Chromosome) -> Chromosome:
    """Uniform crossover: each child gene is taken at random from one parent."""
    child_genes = []
    for g1, g2 in zip(p1.genes, p2.genes):
        inherited = random.choice([g1, g2])
        child_genes.append(inherited)
    return Chromosome(genes=child_genes)
|
||||
|
||||
|
||||
def mutate(chrom: Chromosome) -> Chromosome:
    """Return a mutated copy of ``chrom``; the input is left untouched.

    Every gene is mutated independently with probability ``MUTATION_PROB``.
    A mutating gene with more than two possible values moves one index up or
    down (clamped to the valid range) 70% of the time; otherwise its index is
    redrawn uniformly from the full range.
    """
    mutated = chrom.genes.copy()
    for idx, (_, alleles) in enumerate(GENES):
        if random.random() >= MUTATION_PROB:
            continue  # this gene is not selected for mutation

        last_index = len(alleles) - 1
        # The second draw happens only for genes with more than two values.
        take_step = len(alleles) > 2 and random.random() < 0.7
        if take_step:
            shifted = mutated[idx] + random.choice([-1, 1])
            mutated[idx] = max(0, min(last_index, shifted))
        else:
            mutated[idx] = random.randint(0, last_index)
    return Chromosome(genes=mutated)
|
||||
|
||||
|
||||
def compute_fitness(wer: float, der: float, time_min: float) -> float:
    """Return the negated weighted sum of WER, DER and time.

    The weights are ``ALPHA``, ``BETA`` and ``GAMMA``; the sum is negated so
    that a lower cost gives a higher fitness.
    """
    weighted_cost = ALPHA * wer + BETA * der + GAMMA * time_min
    return -weighted_cost
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch scheduler
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _store_batch_results(model, items, results, store) -> int:
    """Store one batch's results under their cache keys; return the count stored."""
    results = list(results)
    # Fail fast: a short result list would otherwise leave keys uncached and
    # surface later as an unrelated error when the metrics are read back.
    if len(results) != len(items):
        raise ValueError(
            f"{model}: pipeline returned {len(results)} results "
            f"for {len(items)} configs"
        )
    for (key, _), result in zip(items, results):
        store(key, result)
    return len(items)


def _assign_metrics(population, cache) -> None:
    """Copy cached metrics onto every chromosome and compute its fitness."""
    for chrom in population:
        t_res = cache.get_transcription(chrom.transcription_key())
        d_res = cache.get_diarization(chrom.diarization_key())
        chrom.wer = t_res["wer"]
        chrom.der = d_res["der"]
        chrom.time_min = t_res["time"] + d_res["time"]
        chrom.fitness = compute_fitness(chrom.wer, chrom.der, chrom.time_min)


def schedule_evaluations(
    population: list[Chromosome], cache: Cache, audio_paths: list[str]
) -> int:
    """Evaluate chromosomes using the cache and batching by model.

    1. Collect the unique transcription and diarization configs that are not
       in ``cache`` yet.
    2. Group them by model name and hand each group to ``run_pipeline`` as a
       single batch call.
    3. Store the results in ``cache`` (saving it when anything new was
       computed) and set ``wer``, ``der``, ``time_min`` and ``fitness`` on
       every chromosome.

    Args:
        population: chromosomes to evaluate; updated in place.
        cache: result cache consulted before, and filled after, each batch.
        audio_paths: forwarded unchanged to the ``run_pipeline`` batch calls.

    Returns:
        The number of new (uncached) module evaluations performed.

    Raises:
        ValueError: if a batch call returns a different number of results
            than the number of configs it was given.
    """
    # model name -> [(cache key, config passed to the pipeline), ...]
    pending_t: dict[str, list[tuple[tuple, dict]]] = defaultdict(list)
    pending_d: dict[str, list[tuple[tuple, dict]]] = defaultdict(list)
    # Keys already queued in this call, so duplicates run only once.
    queued_t: set[tuple] = set()
    queued_d: set[tuple] = set()

    for chrom in population:
        cfg = chrom.to_config()

        t_key = chrom.transcription_key()
        if t_key not in queued_t and cache.get_transcription(t_key) is None:
            queued_t.add(t_key)
            # t_key is (model, beam, vad_threshold); its beam is already
            # normalised to 1 for non-Whisper models.
            pending_t[cfg["transcription_model"]].append(
                (t_key, {"beam_size": t_key[1], "vad_threshold": cfg["vad_threshold"]})
            )

        d_key = chrom.diarization_key()
        if d_key not in queued_d and cache.get_diarization(d_key) is None:
            queued_d.add(d_key)
            pending_d[cfg["diarization_model"]].append(
                (
                    d_key,
                    {
                        "min_speech_duration": cfg["min_speech_duration"],
                        "clustering_threshold": cfg["clustering_threshold"],
                        "vad_threshold": cfg["vad_threshold"],
                    },
                )
            )

    new_evals = 0

    for model, items in pending_t.items():
        results = run_pipeline.evaluate_transcription_batch(
            model, [c for _, c in items], audio_paths
        )
        new_evals += _store_batch_results(
            model, items, results, cache.set_transcription
        )

    for model, items in pending_d.items():
        results = run_pipeline.evaluate_diarization_batch(
            model, [c for _, c in items], audio_paths
        )
        new_evals += _store_batch_results(
            model, items, results, cache.set_diarization
        )

    # Only touch the files on disk when something new was computed.
    if new_evals > 0:
        cache.save()

    _assign_metrics(population, cache)
    return new_evals
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main GA loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def run_ga(audio_paths: list[str] | None = None, seed: int = 42) -> list[dict]:
    """Run the genetic search and return the per-generation history.

    Seeds the global ``random`` generator, evolves a population of
    ``POPULATION_SIZE`` chromosomes for ``NUM_GENERATIONS`` generations
    (elitism, tournament selection, crossover and mutation), writes
    ``history.json`` next to this file and prints the five best
    configurations of the final population.

    Args:
        audio_paths: audio files passed on to the evaluation step; an empty
            list is used when omitted.
        seed: seed for the ``random`` module.

    Returns:
        One summary dict per generation (the ``history`` list that is also
        written to ``history.json``).
    """
    random.seed(seed)
    audio_paths = [] if audio_paths is None else audio_paths

    cache = Cache(Path(__file__).parent / "cache")

    def fitness_of(chrom):
        return chrom.fitness

    history: list[dict] = []
    all_configs: list[dict] = []
    seen_genes: set[tuple[int, ...]] = set()

    population = [random_chromosome() for _ in range(POPULATION_SIZE)]
    new_evals = schedule_evaluations(population, cache, audio_paths)
    total_evals = new_evals

    final_gen = NUM_GENERATIONS - 1
    for gen in range(NUM_GENERATIONS):
        population.sort(key=fitness_of, reverse=True)

        # Record every gene combination the first time it shows up.
        for chrom in population:
            signature = tuple(chrom.genes)
            if signature in seen_genes:
                continue
            seen_genes.add(signature)
            all_configs.append(
                {
                    "config": chrom.to_config(),
                    "wer": chrom.wer,
                    "der": chrom.der,
                    "time": chrom.time_min,
                    "fitness": chrom.fitness,
                    "generation": gen,
                }
            )

        best = population[0]
        worst = population[-1]
        mean_fit = sum(c.fitness for c in population) / len(population)

        history.append(
            {
                "generation": gen,
                "best_fitness": round(best.fitness, 4),
                "mean_fitness": round(mean_fit, 4),
                "worst_fitness": round(worst.fitness, 4),
                "best_config": best.to_config(),
                "best_wer": best.wer,
                "best_der": best.der,
                "best_time": best.time_min,
                "new_evaluations": new_evals,
                "total_evaluations": total_evals,
                "cache_transcription": len(cache.transcription),
                "cache_diarization": len(cache.diarization),
            }
        )

        print(
            f"Gen {gen:3d} | best={best.fitness:.3f} mean={mean_fit:.3f} | "
            f"WER={best.wer:.1f}% DER={best.der:.1f}% | "
            f"new={new_evals} cache_t={len(cache.transcription)} "
            f"cache_d={len(cache.diarization)}"
        )

        # The last generation is only reported, never bred.
        if gen == final_gen:
            break

        next_gen: list[Chromosome] = []

        # Elitism: carry the best chromosomes over together with their metrics.
        for rank in range(ELITE_COUNT):
            source = population[rank]
            elite = source.copy()
            elite.fitness, elite.wer = source.fitness, source.wer
            elite.der, elite.time_min = source.der, source.time_min
            next_gen.append(elite)

        # Fill the remaining slots with mutated offspring of tournament winners.
        while len(next_gen) < POPULATION_SIZE:
            p1 = tournament_select(population)
            p2 = tournament_select(population)
            offspring = crossover(p1, p2)
            next_gen.append(mutate(offspring))

        population = next_gen
        new_evals = schedule_evaluations(population, cache, audio_paths)
        total_evals += new_evals

    out_path = Path(__file__).parent / "history.json"
    serialized = json.dumps(
        {"history": history, "all_configs": all_configs},
        indent=2,
        ensure_ascii=False,
    )
    out_path.write_text(serialized)
    print(f"\nResults saved to {out_path}")

    population.sort(key=fitness_of, reverse=True)
    print("\n=== Top 5 configurations ===")
    for i, ch in enumerate(population[:5]):
        print(
            f"\n#{i + 1}: fitness={ch.fitness:.3f} "
            f"WER={ch.wer:.2f}% DER={ch.der:.2f}% time={ch.time_min:.2f}min"
        )
        for k, v in ch.to_config().items():
            print(f" {k}: {v}")

    return history
|
||||
|
||||
|
||||
# Script entry point: run the search with the default arguments.
if __name__ == "__main__":
    run_ga()
|
||||
154
ga/generate_plots.py
Normal file
154
ga/generate_plots.py
Normal file
@@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate plots from GA history for the course work report."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Global plot styling shared by every figure produced by this script.
_PLOT_STYLE = {
    "font.family": "DejaVu Sans",
    "axes.grid": True,
    "grid.alpha": 0.3,
}
matplotlib.rcParams.update(_PLOT_STYLE)
|
||||
|
||||
# Axis labels for transcription model identifiers.
_TRANSCRIPTION_LABELS = {
    "whisper-large-v3": "Whisper large-v3",
    "whisper-medium": "Whisper medium",
    "faster-whisper-large-v3": "Faster-Whisper\nlarge-v3",
    "gigaam-ctc": "GigaAM-CTC",
    "gigaam-rnnt": "GigaAM-RNN-T",
}

# Axis labels for diarization model identifiers.
_DIARIZATION_LABELS = {
    "pyannote-3.1": "pyannote 3.1",
    "pyannote-community-1": "pyannote\nCommunity-1",
    "sortformer": "Sortformer",
}

# Model identifier -> human-readable label used on plot axes.
MODEL_DISPLAY = {**_TRANSCRIPTION_LABELS, **_DIARIZATION_LABELS}
|
||||
|
||||
|
||||
def main():
    """Load ``history.json`` next to this script and render all report plots.

    Images are written to ``report/img`` one level above this script's
    directory; the directory is created when missing.
    """
    script_dir = Path(__file__).parent

    data = json.loads((script_dir / "history.json").read_text())
    history, all_configs = data["history"], data["all_configs"]

    img_dir = script_dir.parent / "report" / "img"
    img_dir.mkdir(parents=True, exist_ok=True)

    plot_convergence(history, img_dir)
    plot_wer_der_scatter(all_configs, img_dir)
    plot_model_frequency(all_configs, img_dir)
|
||||
|
||||
|
||||
def plot_convergence(history: list[dict], img_dir: Path):
    """Plot best and mean objective value per generation.

    Fitness values are negated before plotting, so the curves show the
    weighted error. The figure is saved as ``convergence.png`` in ``img_dir``.
    """
    generations, best_curve, mean_curve = [], [], []
    for entry in history:
        generations.append(entry["generation"])
        best_curve.append(-entry["best_fitness"])
        mean_curve.append(-entry["mean_fitness"])

    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.plot(
        generations,
        best_curve,
        "b-o",
        markersize=4,
        linewidth=1.5,
        label="Лучшая особь",
    )
    ax.plot(
        generations,
        mean_curve,
        "r--s",
        markersize=3,
        linewidth=1.2,
        label="Среднее по популяции",
    )

    ax.set_xlabel("Поколение", fontsize=12)
    ax.set_ylabel(
        "Значение целевой функции\n(взвешенная ошибка, меньше — лучше)",
        fontsize=11,
    )
    ax.legend(fontsize=11)

    fig.tight_layout()
    fig.savefig(img_dir / "convergence.png", dpi=150)
    plt.close(fig)
    print(f"Saved {img_dir / 'convergence.png'}")
|
||||
|
||||
|
||||
def _pareto_front(configs: list[dict]) -> list[dict]:
    """Return the configs not dominated in (WER, DER), ordered by rising WER."""
    front: list[dict] = []
    # Sorting by (wer, der) puts the lowest-DER config first among equal WER
    # values, so a config with the same WER but a higher DER is never added.
    for cand in sorted(configs, key=lambda c: (c["wer"], c["der"])):
        if not front or cand["der"] < front[-1]["der"]:
            front.append(cand)
    return front


def plot_wer_der_scatter(all_configs: list[dict], img_dir: Path):
    """Scatter every evaluated config in the WER/DER plane.

    Points are coloured by fitness, the config with the highest fitness is
    marked with a star, and the Pareto front (minimising both WER and DER) is
    drawn as a dashed line when it has more than one point. The figure is
    saved as ``wer_der_scatter.png`` in ``img_dir``.

    Args:
        all_configs: dicts with at least ``wer``, ``der`` and ``fitness``.
        img_dir: existing directory that receives the image.
    """
    wers = [c["wer"] for c in all_configs]
    ders = [c["der"] for c in all_configs]
    fits = [c["fitness"] for c in all_configs]

    fig, ax = plt.subplots(figsize=(7, 5.5))
    sc = ax.scatter(
        wers,
        ders,
        c=fits,
        cmap="RdYlGn",
        alpha=0.7,
        edgecolors="gray",
        linewidth=0.5,
        s=40,
    )

    # Highlight the single best configuration by fitness.
    best = max(all_configs, key=lambda c: c["fitness"])
    ax.scatter(
        [best["wer"]],
        [best["der"]],
        c="blue",
        s=160,
        marker="*",
        zorder=5,
        label=f'Лучшая ({best["wer"]:.1f}%, {best["der"]:.1f}%)',
    )

    pareto = _pareto_front(all_configs)
    if len(pareto) > 1:
        ax.plot(
            [c["wer"] for c in pareto],
            [c["der"] for c in pareto],
            "k--",
            alpha=0.5,
            linewidth=1.2,
            label="Парето-фронт",
        )

    ax.set_xlabel("WER, %", fontsize=12)
    ax.set_ylabel("DER, %", fontsize=12)
    ax.legend(fontsize=11)
    cbar = fig.colorbar(sc, ax=ax)
    cbar.set_label("Фитнес", fontsize=11)
    fig.tight_layout()
    fig.savefig(img_dir / "wer_der_scatter.png", dpi=150)
    plt.close(fig)
    print(f"Saved {img_dir / 'wer_der_scatter.png'}")
|
||||
|
||||
|
||||
def plot_model_frequency(all_configs: list[dict], img_dir: Path):
    """Plot how often each model appears among the top configs by fitness.

    Takes the 20 highest-fitness configs (or all of them if there are fewer)
    and draws two horizontal bar charts, one for transcription models and one
    for diarization models. The figure is saved as ``model_frequency.png`` in
    ``img_dir``.
    """
    top_n = min(20, len(all_configs))
    ranked = sorted(all_configs, key=lambda c: c["fitness"], reverse=True)

    t_counts: dict[str, int] = {}
    d_counts: dict[str, int] = {}
    for entry in ranked[:top_n]:
        tm = entry["config"]["transcription_model"]
        dm = entry["config"]["diarization_model"]
        t_counts[tm] = t_counts.get(tm, 0) + 1
        d_counts[dm] = d_counts.get(dm, 0) + 1

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 4.5))

    # (axis, counts, bar colour, title) for the left and right panels.
    panels = (
        (ax1, t_counts, "steelblue", "Модели транскрибации"),
        (ax2, d_counts, "coral", "Модели диаризации"),
    )
    for axis, counts, colour, title in panels:
        names = sorted(counts, key=counts.get, reverse=True)
        labels = [MODEL_DISPLAY.get(n, n) for n in names]
        values = [counts[n] for n in names]
        axis.barh(labels, values, color=colour)
        axis.set_xlabel(f"Количество в топ-{top_n}", fontsize=11)
        axis.set_title(title, fontsize=12)
        # Most frequent model on top.
        axis.invert_yaxis()

    fig.tight_layout()
    fig.savefig(img_dir / "model_frequency.png", dpi=150)
    plt.close(fig)
    print(f"Saved {img_dir / 'model_frequency.png'}")
|
||||
|
||||
|
||||
# Script entry point: render all plots from the saved history.
if __name__ == "__main__":
    main()
|
||||
153
ga/run_pipeline.py
Normal file
153
ga/run_pipeline.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Pipeline evaluation adapter.
|
||||
|
||||
Provides batch evaluation functions for transcription and diarization modules.
|
||||
Currently contains simulation stubs with realistic performance models based on
|
||||
published benchmarks. Replace the simulation logic with actual pipeline calls
|
||||
for production use.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
|
||||
# --- Transcription simulation tables ----------------------------------------
# Base WER (%) per transcription model, before parameter adjustments.
TRANSCRIPTION_BASE_WER: dict[str, float] = {
    "whisper-large-v3": 7.8,
    "whisper-medium": 13.5,
    "faster-whisper-large-v3": 7.6,
    "gigaam-ctc": 6.8,
    "gigaam-rnnt": 5.4,
}

# Base processing time (minutes) per transcription model.
TRANSCRIPTION_BASE_TIME: dict[str, float] = {
    "whisper-large-v3": 4.2,
    "whisper-medium": 2.8,
    "faster-whisper-large-v3": 2.2,
    "gigaam-ctc": 1.5,
    "gigaam-rnnt": 3.5,
}

# Models for which the beam-size tables below are applied.
WHISPER_MODELS = {
    "whisper-large-v3",
    "whisper-medium",
    "faster-whisper-large-v3",
}

# Beam size -> value added to WER / factor applied to time (Whisper only).
BEAM_SIZE_WER_DELTA = {1: 1.2, 3: 0.4, 5: 0.0, 7: -0.1, 10: -0.15}
BEAM_SIZE_TIME_FACTOR = {1: 0.6, 3: 0.8, 5: 1.0, 7: 1.15, 10: 1.4}

# VAD threshold -> value added to WER.
VAD_WER_DELTA = {0.3: 0.8, 0.4: 0.2, 0.5: 0.0, 0.6: 0.3, 0.7: 1.0}

# --- Diarization simulation tables ------------------------------------------
# Base DER (%) per diarization model, before parameter adjustments.
DIARIZATION_BASE_DER: dict[str, float] = {
    "pyannote-3.1": 24.0,
    "pyannote-community-1": 20.5,
    "sortformer": 18.8,
}

# Base processing time (minutes) per diarization model.
DIARIZATION_BASE_TIME: dict[str, float] = {
    "pyannote-3.1": 2.5,
    "pyannote-community-1": 2.8,
    "sortformer": 3.8,
}

# Parameter value -> value added to DER.
MIN_SPEECH_DER_DELTA = {0.25: 1.5, 0.5: 0.0, 0.75: 0.3, 1.0: 1.2, 1.5: 3.0}
CLUSTERING_DER_DELTA = {0.3: 3.0, 0.45: 0.8, 0.6: 0.0, 0.75: 0.5, 0.9: 2.5}
VAD_DER_DELTA = {0.3: 1.0, 0.4: 0.3, 0.5: 0.0, 0.6: 0.5, 0.7: 1.5}
|
||||
|
||||
|
||||
def _deterministic_noise(seed_str: str, amplitude: float = 0.3) -> float:
|
||||
h = int(hashlib.md5(seed_str.encode()).hexdigest(), 16)
|
||||
return (h % 10000) / 10000 * 2 * amplitude - amplitude
|
||||
|
||||
|
||||
def evaluate_transcription_batch(
    model_name: str,
    configs: list[dict],
    audio_paths: list[str],
) -> list[dict]:
    """Simulate transcription quality and time for configs sharing one model.

    The result for each config is computed from the module-level lookup
    tables plus a deterministic noise term; no audio is processed.

    Args:
        model_name: name of the transcription model
        configs: list of dicts, each with keys ``beam_size``, ``vad_threshold``
        audio_paths: paths to audio files (unused in simulation)

    Returns:
        list of dicts with ``wer`` (%) and ``time`` (minutes), one per config
        and in the same order
    """
    base_wer = TRANSCRIPTION_BASE_WER[model_name]
    base_time = TRANSCRIPTION_BASE_TIME[model_name]
    is_whisper = model_name in WHISPER_MODELS

    def simulate(cfg: dict) -> dict:
        beam = cfg["beam_size"]
        vad = cfg["vad_threshold"]

        # Word error rate: base value plus per-parameter adjustments.
        wer = base_wer
        if is_whisper:
            wer = wer + BEAM_SIZE_WER_DELTA[beam]
        wer = wer + VAD_WER_DELTA[vad]
        # Extra penalty for Whisper with a wide beam at an extreme VAD value.
        if is_whisper and beam >= 7 and vad in (0.3, 0.7):
            wer = wer + 0.4
        wer = max(1.0, wer + _deterministic_noise(f"t_{model_name}_{beam}_{vad}"))

        # Processing time: the beam factor applies to Whisper models only.
        elapsed = base_time
        if is_whisper:
            elapsed = elapsed * BEAM_SIZE_TIME_FACTOR[beam]
        elapsed = elapsed + _deterministic_noise(f"tt_{model_name}_{beam}_{vad}", 0.1)
        elapsed = max(0.5, elapsed)

        return {"wer": round(wer, 2), "time": round(elapsed, 2)}

    return [simulate(cfg) for cfg in configs]
|
||||
|
||||
|
||||
def evaluate_diarization_batch(
    model_name: str,
    configs: list[dict],
    audio_paths: list[str],
) -> list[dict]:
    """Simulate diarization quality and time for configs sharing one model.

    The result for each config is computed from the module-level lookup
    tables plus a deterministic noise term; no audio is processed.

    Args:
        model_name: name of the diarization model
        configs: list of dicts with ``min_speech_duration``,
            ``clustering_threshold``, ``vad_threshold``
        audio_paths: paths to audio files (unused in simulation)

    Returns:
        list of dicts with ``der`` (%) and ``time`` (minutes), one per config
        and in the same order
    """
    base_der = DIARIZATION_BASE_DER[model_name]
    base_time = DIARIZATION_BASE_TIME[model_name]

    def simulate(cfg: dict) -> dict:
        msd = cfg["min_speech_duration"]
        ct = cfg["clustering_threshold"]
        vad = cfg["vad_threshold"]

        # Diarization error rate: base value plus per-parameter adjustments.
        der = base_der
        der = der + MIN_SPEECH_DER_DELTA[msd]
        der = der + CLUSTERING_DER_DELTA[ct]
        der = der + VAD_DER_DELTA[vad]

        # Interaction penalties for extreme parameter combinations.
        if vad <= 0.3 and msd <= 0.25:
            der = der + 1.2
        if ct >= 0.9 and msd >= 1.5:
            der = der + 0.8

        der = max(5.0, der + _deterministic_noise(f"d_{model_name}_{msd}_{ct}_{vad}"))

        elapsed = base_time + _deterministic_noise(
            f"dt_{model_name}_{msd}_{ct}_{vad}", 0.15
        )
        elapsed = max(0.5, elapsed)

        return {"der": round(der, 2), "time": round(elapsed, 2)}

    return [simulate(cfg) for cfg in configs]
|
||||
Reference in New Issue
Block a user