#!/usr/bin/env python3
"""Genetic algorithm for optimizing a meeting transcription+diarization pipeline.

Searches over a mixed discrete configuration space of transcription and
diarization models and their parameters. Evaluation results are cached on disk
per pipeline module (transcription / diarization), and uncached configurations
are batched by model so that ``run_pipeline`` receives one batch call per model
per generation, minimizing redundant computation.
"""

import json
import random
from collections import defaultdict
from dataclasses import dataclass, replace
from pathlib import Path

import run_pipeline

# ---------------------------------------------------------------------------
# Configuration space
# ---------------------------------------------------------------------------

TRANSCRIPTION_MODELS = [
    "whisper-large-v3",
    "whisper-medium",
    "faster-whisper-large-v3",
    "gigaam-ctc",
    "gigaam-rnnt",
]
BEAM_SIZES = [1, 3, 5, 7, 10]
VAD_THRESHOLDS = [0.3, 0.4, 0.5, 0.6, 0.7]
DIARIZATION_MODELS = ["pyannote-3.1", "pyannote-community-1", "sortformer"]
MIN_SPEECH_DURATIONS = [0.25, 0.5, 0.75, 1.0, 1.5]
CLUSTERING_THRESHOLDS = [0.3, 0.45, 0.6, 0.75, 0.9]

# Models for which beam_size is meaningful. For every other transcription
# model the beam is normalized to 1 in the cache key, so configs that differ
# only in an unused beam_size share a single cache entry.
WHISPER_MODELS = {"whisper-large-v3", "whisper-medium", "faster-whisper-large-v3"}

# Ordered (gene name, allowed values) pairs. A chromosome stores one index
# into each value list, in this order.
GENES: list[tuple[str, list]] = [
    ("transcription_model", TRANSCRIPTION_MODELS),
    ("beam_size", BEAM_SIZES),
    ("vad_threshold", VAD_THRESHOLDS),
    ("diarization_model", DIARIZATION_MODELS),
    ("min_speech_duration", MIN_SPEECH_DURATIONS),
    ("clustering_threshold", CLUSTERING_THRESHOLDS),
]

# ---------------------------------------------------------------------------
# GA hyper-parameters
# ---------------------------------------------------------------------------

POPULATION_SIZE = 15
NUM_GENERATIONS = 25
TOURNAMENT_SIZE = 3
MUTATION_PROB = 0.15  # per-gene probability of mutation
ELITE_COUNT = 2  # best individuals copied unchanged into the next generation

ALPHA = 0.4  # WER weight
BETA = 0.4  # DER weight
GAMMA = 0.2  # time weight


# ---------------------------------------------------------------------------
# Chromosome
# ---------------------------------------------------------------------------


@dataclass
class Chromosome:
    """One candidate configuration, encoded as indices into the ``GENES`` lists."""

    genes: list[int]
    # Evaluation results; None until schedule_evaluations() fills them in.
    fitness: float | None = None
    wer: float | None = None
    der: float | None = None
    time_min: float | None = None

    def to_config(self) -> dict:
        """Decode the gene indices into a ``{gene_name: value}`` dict."""
        return {
            name: values[self.genes[i]] for i, (name, values) in enumerate(GENES)
        }

    def transcription_key(self) -> tuple:
        """Return the transcription cache key ``(model, beam_size, vad_threshold)``.

        ``beam_size`` is forced to 1 for models outside ``WHISPER_MODELS``.
        """
        cfg = self.to_config()
        model = cfg["transcription_model"]
        beam = cfg["beam_size"] if model in WHISPER_MODELS else 1
        return (model, beam, cfg["vad_threshold"])

    def diarization_key(self) -> tuple:
        """Return the diarization cache key.

        The key is ``(model, min_speech_duration, clustering_threshold,
        vad_threshold)``.
        """
        cfg = self.to_config()
        return (
            cfg["diarization_model"],
            cfg["min_speech_duration"],
            cfg["clustering_threshold"],
            cfg["vad_threshold"],
        )

    def copy(self) -> "Chromosome":
        """Return a copy with the same genes and no evaluation results."""
        return Chromosome(genes=self.genes.copy())


# ---------------------------------------------------------------------------
# Module-level cache
# ---------------------------------------------------------------------------


class Cache:
    """JSON-backed cache of per-module evaluation results.

    Results live in ``<cache_dir>/transcription.json`` and
    ``<cache_dir>/diarization.json``. Keys are ``str(tuple_key)`` because JSON
    object keys must be strings.
    """

    _STAGES = ("transcription", "diarization")

    def __init__(self, cache_dir: Path):
        self.cache_dir = cache_dir
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.transcription: dict[str, dict] = {}
        self.diarization: dict[str, dict] = {}
        self._load()

    def _load(self):
        """Populate the in-memory dicts from any existing cache files."""
        for name in self._STAGES:
            path = self.cache_dir / f"{name}.json"
            if path.exists():
                setattr(self, name, json.loads(path.read_text(encoding="utf-8")))

    def save(self):
        """Write both caches to disk.

        Each file is first written to a temporary sibling and then renamed
        over the target. An interruption mid-write therefore cannot leave a
        truncated JSON file that would make ``_load`` fail on the next run.
        """
        for name in self._STAGES:
            path = self.cache_dir / f"{name}.json"
            tmp_path = self.cache_dir / f"{name}.json.tmp"
            tmp_path.write_text(
                json.dumps(getattr(self, name), indent=2), encoding="utf-8"
            )
            tmp_path.replace(path)

    def get_transcription(self, key: tuple) -> dict | None:
        return self.transcription.get(str(key))

    def set_transcription(self, key: tuple, result: dict):
        self.transcription[str(key)] = result

    def get_diarization(self, key: tuple) -> dict | None:
        return self.diarization.get(str(key))

    def set_diarization(self, key: tuple, result: dict):
        self.diarization[str(key)] = result


# ---------------------------------------------------------------------------
# GA operators
# ---------------------------------------------------------------------------


def random_chromosome() -> Chromosome:
    """Return a chromosome with a uniformly random index for every gene."""
    return Chromosome(genes=[random.randint(0, len(v) - 1) for _, v in GENES])


def tournament_select(population: list[Chromosome]) -> Chromosome:
    """Return the fittest of ``TOURNAMENT_SIZE`` distinct random individuals.

    The tournament size is capped at the population size, so a small
    population does not make ``random.sample`` raise ``ValueError``.
    """
    k = min(TOURNAMENT_SIZE, len(population))
    candidates = random.sample(population, k)
    return max(candidates, key=lambda c: c.fitness)


def crossover(p1: Chromosome, p2: Chromosome) -> Chromosome:
    """Uniform crossover: each gene is taken from either parent at random."""
    return Chromosome(
        genes=[random.choice([g1, g2]) for g1, g2 in zip(p1.genes, p2.genes)]
    )


def mutate(chrom: Chromosome) -> Chromosome:
    """Return a mutated copy of ``chrom`` (the input is not modified).

    Each gene mutates with probability ``MUTATION_PROB``. For genes with more
    than two values, a mutation is a +/-1 step 70% of the time. This keeps the
    search local over the ordered numeric lists; the step is clamped at the
    ends. Otherwise the gene is resampled uniformly.
    """
    genes = chrom.genes.copy()
    for i, (_, values) in enumerate(GENES):
        if random.random() < MUTATION_PROB:
            if len(values) > 2 and random.random() < 0.7:
                delta = random.choice([-1, 1])
                genes[i] = max(0, min(len(values) - 1, genes[i] + delta))
            else:
                genes[i] = random.randint(0, len(values) - 1)
    return Chromosome(genes=genes)


def compute_fitness(wer: float, der: float, time_min: float) -> float:
    """Return the negated weighted cost, so that higher fitness is better."""
    return -(ALPHA * wer + BETA * der + GAMMA * time_min)


# ---------------------------------------------------------------------------
# Batch scheduler
# ---------------------------------------------------------------------------


def _pair_results(
    items: list[tuple[tuple, dict]], results, stage: str, model: str
) -> list[tuple[tuple, dict]]:
    """Pair each ``(key, config)`` item with its batch result.

    Raises a clear error on a count mismatch instead of silently truncating.
    """
    results = list(results)
    if len(results) != len(items):
        raise RuntimeError(
            f"{stage} batch for model {model!r} returned {len(results)} "
            f"results for {len(items)} configs"
        )
    return [(key, result) for (key, _), result in zip(items, results)]


def schedule_evaluations(
    population: list[Chromosome], cache: Cache, audio_paths: list[str]
) -> int:
    """Evaluate chromosomes using the cache and per-model batching.

    1. Collect the unique transcription and diarization configs that are not
       yet cached.
    2. Group them by model and make one ``run_pipeline`` batch call per model.
    3. Store the results in the cache, saving after each batch so completed
       work survives a later failure. Then set wer/der/time_min/fitness on
       every chromosome.

    Args:
        population: chromosomes to evaluate; updated in place.
        cache: result cache, extended with any new evaluations.
        audio_paths: forwarded unchanged to the ``run_pipeline`` batch calls.

    Returns:
        The number of new (uncached) module evaluations performed.

    Raises:
        RuntimeError: if a batch call returns a different number of results
            than the configs it was given.
    """
    uncached_t: dict[str, list[tuple[tuple, dict]]] = defaultdict(list)
    uncached_d: dict[str, list[tuple[tuple, dict]]] = defaultdict(list)
    seen_t: set[str] = set()
    seen_d: set[str] = set()

    for chrom in population:
        # Batch configs are unpacked from the cache keys themselves, so the
        # evaluated parameters can never drift from what the key records.
        t_key = chrom.transcription_key()
        t_key_s = str(t_key)
        if cache.get_transcription(t_key) is None and t_key_s not in seen_t:
            seen_t.add(t_key_s)
            t_model, beam, vad = t_key
            uncached_t[t_model].append(
                (t_key, {"beam_size": beam, "vad_threshold": vad})
            )

        d_key = chrom.diarization_key()
        d_key_s = str(d_key)
        if cache.get_diarization(d_key) is None and d_key_s not in seen_d:
            seen_d.add(d_key_s)
            d_model, min_speech, clustering, vad = d_key
            uncached_d[d_model].append(
                (
                    d_key,
                    {
                        "min_speech_duration": min_speech,
                        "clustering_threshold": clustering,
                        "vad_threshold": vad,
                    },
                )
            )

    new_evals = 0

    for model, items in uncached_t.items():
        results = run_pipeline.evaluate_transcription_batch(
            model, [cfg for _, cfg in items], audio_paths
        )
        for key, result in _pair_results(items, results, "transcription", model):
            cache.set_transcription(key, result)
        new_evals += len(items)
        cache.save()

    for model, items in uncached_d.items():
        results = run_pipeline.evaluate_diarization_batch(
            model, [cfg for _, cfg in items], audio_paths
        )
        for key, result in _pair_results(items, results, "diarization", model):
            cache.set_diarization(key, result)
        new_evals += len(items)
        cache.save()

    for chrom in population:
        t_res = cache.get_transcription(chrom.transcription_key())
        d_res = cache.get_diarization(chrom.diarization_key())
        chrom.wer = t_res["wer"]
        chrom.der = d_res["der"]
        chrom.time_min = t_res["time"] + d_res["time"]
        chrom.fitness = compute_fitness(chrom.wer, chrom.der, chrom.time_min)

    return new_evals


# ---------------------------------------------------------------------------
# Main GA loop
# ---------------------------------------------------------------------------


def run_ga(audio_paths: list[str] | None = None, seed: int = 42) -> list[dict]:
    """Run the genetic algorithm and return the per-generation history.

    The run proceeds as follows:

    * Seed the global ``random`` module with ``seed``.
    * Evaluate an initial random population.
    * For each generation, record statistics. Carry the ``ELITE_COUNT`` best
      individuals over unchanged, and fill the rest with mutated crossovers
      of tournament-selected parents.
    * On completion, write ``history.json`` (history plus every distinct
      chromosome seen) next to this file.
    * Print the top 5 configurations of the final population.
    """
    random.seed(seed)
    if audio_paths is None:
        # NOTE(review): an empty list is forwarded to run_pipeline as-is;
        # confirm the pipeline treats it as intended (e.g. a default dataset).
        audio_paths = []

    cache = Cache(Path(__file__).parent / "cache")
    history: list[dict] = []
    all_configs: list[dict] = []
    seen_genes: set[tuple[int, ...]] = set()
    total_evals = 0

    population = [random_chromosome() for _ in range(POPULATION_SIZE)]
    new_evals = schedule_evaluations(population, cache, audio_paths)
    total_evals += new_evals

    for gen in range(NUM_GENERATIONS):
        population.sort(key=lambda c: c.fitness, reverse=True)

        # Record each distinct genotype once, tagged with the generation in
        # which it first appeared.
        for chrom in population:
            key = tuple(chrom.genes)
            if key not in seen_genes:
                seen_genes.add(key)
                all_configs.append(
                    {
                        "config": chrom.to_config(),
                        "wer": chrom.wer,
                        "der": chrom.der,
                        "time": chrom.time_min,
                        "fitness": chrom.fitness,
                        "generation": gen,
                    }
                )

        best = population[0]
        mean_fit = sum(c.fitness for c in population) / len(population)
        history.append(
            {
                "generation": gen,
                "best_fitness": round(best.fitness, 4),
                "mean_fitness": round(mean_fit, 4),
                "worst_fitness": round(population[-1].fitness, 4),
                "best_config": best.to_config(),
                "best_wer": best.wer,
                "best_der": best.der,
                "best_time": best.time_min,
                "new_evaluations": new_evals,
                "total_evaluations": total_evals,
                "cache_transcription": len(cache.transcription),
                "cache_diarization": len(cache.diarization),
            }
        )
        print(
            f"Gen {gen:3d} | best={best.fitness:.3f} mean={mean_fit:.3f} | "
            f"WER={best.wer:.1f}% DER={best.der:.1f}% | "
            f"new={new_evals} cache_t={len(cache.transcription)} "
            f"cache_d={len(cache.diarization)}"
        )

        if gen == NUM_GENERATIONS - 1:
            break

        # Elites survive unchanged, keeping their evaluated metrics; the gene
        # list is copied so later mutation of one cannot affect the other.
        next_gen: list[Chromosome] = [
            replace(elite, genes=elite.genes.copy())
            for elite in population[:ELITE_COUNT]
        ]
        while len(next_gen) < POPULATION_SIZE:
            p1 = tournament_select(population)
            p2 = tournament_select(population)
            next_gen.append(mutate(crossover(p1, p2)))

        population = next_gen
        new_evals = schedule_evaluations(population, cache, audio_paths)
        total_evals += new_evals

    output = {"history": history, "all_configs": all_configs}
    out_path = Path(__file__).parent / "history.json"
    # ensure_ascii=False may emit non-ASCII text, so pin the file encoding.
    out_path.write_text(
        json.dumps(output, indent=2, ensure_ascii=False), encoding="utf-8"
    )
    print(f"\nResults saved to {out_path}")

    population.sort(key=lambda c: c.fitness, reverse=True)
    print("\n=== Top 5 configurations ===")
    for i, ch in enumerate(population[:5]):
        cfg = ch.to_config()
        print(
            f"\n#{i + 1}: fitness={ch.fitness:.3f} "
            f"WER={ch.wer:.2f}% DER={ch.der:.2f}% time={ch.time_min:.2f}min"
        )
        for k, v in cfg.items():
            print(f" {k}: {v}")

    return history


if __name__ == "__main__":
    run_ga()