га
This commit is contained in:
153
ga/run_pipeline.py
Normal file
153
ga/run_pipeline.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Pipeline evaluation adapter.
|
||||
|
||||
Provides batch evaluation functions for transcription and diarization modules.
Currently contains simulation stubs with realistic performance models based on
published benchmarks. Replace the simulation logic with actual pipeline calls
for production use.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
|
||||
# Baseline word error rate (%) per transcription model; every simulated WER
# starts from this value before the per-config adjustments are added.
TRANSCRIPTION_BASE_WER: dict[str, float] = {
    "whisper-large-v3": 7.8,
    "whisper-medium": 13.5,
    "faster-whisper-large-v3": 7.6,
    "gigaam-ctc": 6.8,
    "gigaam-rnnt": 5.4,
}

# Baseline processing time (minutes) per transcription model.
TRANSCRIPTION_BASE_TIME: dict[str, float] = {
    "whisper-large-v3": 4.2,
    "whisper-medium": 2.8,
    "faster-whisper-large-v3": 2.2,
    "gigaam-ctc": 1.5,
    "gigaam-rnnt": 3.5,
}

# Models whose simulated results react to the beam size tables below.
WHISPER_MODELS: set[str] = {
    "whisper-large-v3",
    "whisper-medium",
    "faster-whisper-large-v3",
}

# Additive WER offset, keyed by beam size.
BEAM_SIZE_WER_DELTA: dict[int, float] = {1: 1.2, 3: 0.4, 5: 0.0, 7: -0.1, 10: -0.15}
# Multiplier applied to the base time, keyed by beam size.
BEAM_SIZE_TIME_FACTOR: dict[int, float] = {1: 0.6, 3: 0.8, 5: 1.0, 7: 1.15, 10: 1.4}

# Additive WER offset, keyed by VAD threshold.
VAD_WER_DELTA: dict[float, float] = {0.3: 0.8, 0.4: 0.2, 0.5: 0.0, 0.6: 0.3, 0.7: 1.0}
|
||||
|
||||
# Baseline diarization error rate (%) per diarization model; every simulated
# DER starts from this value before the per-config adjustments are added.
DIARIZATION_BASE_DER: dict[str, float] = {
    "pyannote-3.1": 24.0,
    "pyannote-community-1": 20.5,
    "sortformer": 18.8,
}

# Baseline processing time (minutes) per diarization model.
DIARIZATION_BASE_TIME: dict[str, float] = {
    "pyannote-3.1": 2.5,
    "pyannote-community-1": 2.8,
    "sortformer": 3.8,
}

# Additive DER offsets, keyed by the respective config value.
MIN_SPEECH_DER_DELTA: dict[float, float] = {
    0.25: 1.5,
    0.5: 0.0,
    0.75: 0.3,
    1.0: 1.2,
    1.5: 3.0,
}
CLUSTERING_DER_DELTA: dict[float, float] = {
    0.3: 3.0,
    0.45: 0.8,
    0.6: 0.0,
    0.75: 0.5,
    0.9: 2.5,
}
VAD_DER_DELTA: dict[float, float] = {0.3: 1.0, 0.4: 0.3, 0.5: 0.0, 0.6: 0.5, 0.7: 1.5}
|
||||
|
||||
|
||||
def _deterministic_noise(seed_str: str, amplitude: float = 0.3) -> float:
    """Return a reproducible pseudo-random offset derived from ``seed_str``.

    The MD5 digest of the seed is reduced to one of 10000 buckets and mapped
    linearly onto ``[-amplitude, amplitude)``, so the same seed always yields
    the same offset.

    Args:
        seed_str: text identifying the configuration being simulated
        amplitude: half-width of the output range

    Returns:
        offset in ``[-amplitude, amplitude)`` (for a non-negative amplitude)
    """
    # MD5 serves only as a stable hash here, not for security; saying so keeps
    # the call usable on FIPS-restricted Python builds.
    digest = hashlib.md5(seed_str.encode(), usedforsecurity=False).hexdigest()
    bucket = int(digest, 16) % 10000
    return bucket / 10000 * 2 * amplitude - amplitude
|
||||
|
||||
|
||||
def evaluate_transcription_batch(
    model_name: str,
    configs: list[dict],
    audio_paths: list[str],
) -> list[dict]:
    """Evaluate transcription for a batch of configs using the same model.

    In production, this loads the model once and iterates over configs.
    Currently returns simulated results.

    Args:
        model_name: name of the transcription model; must be a key of
            ``TRANSCRIPTION_BASE_WER`` and ``TRANSCRIPTION_BASE_TIME``
        configs: list of dicts, each with keys ``beam_size``, ``vad_threshold``
        audio_paths: paths to audio files (unused in simulation)

    Returns:
        list of dicts with ``wer`` (%) and ``time`` (minutes), one per config
        and in the same order as ``configs``

    Raises:
        KeyError: if ``model_name`` is unknown, a config lacks a required key,
            or a config value has no entry in the lookup table it is used with
    """
    try:
        base_wer = TRANSCRIPTION_BASE_WER[model_name]
        base_time = TRANSCRIPTION_BASE_TIME[model_name]
    except KeyError:
        raise KeyError(f"unknown transcription model: {model_name!r}") from None
    is_whisper = model_name in WHISPER_MODELS

    results = []
    for cfg in configs:
        beam = cfg["beam_size"]
        vad = cfg["vad_threshold"]

        wer = base_wer
        if is_whisper:
            wer += BEAM_SIZE_WER_DELTA[beam]
        # NOTE(review): the VAD offset is assumed to apply to every model,
        # not only the Whisper family — confirm.
        wer += VAD_WER_DELTA[vad]

        # Extra penalty when an extreme VAD threshold meets a wide beam.
        if is_whisper and vad in (0.3, 0.7) and beam >= 7:
            wer += 0.4

        # Seeded jitter so results vary per config yet stay reproducible;
        # the simulated WER is floored at 1.0.
        wer = max(1.0, wer + _deterministic_noise(f"t_{model_name}_{beam}_{vad}"))

        minutes = base_time
        if is_whisper:
            minutes *= BEAM_SIZE_TIME_FACTOR[beam]
        # NOTE(review): time jitter is assumed to apply to every model — confirm.
        minutes += _deterministic_noise(f"tt_{model_name}_{beam}_{vad}", 0.1)
        minutes = max(0.5, minutes)

        results.append({"wer": round(wer, 2), "time": round(minutes, 2)})

    return results
|
||||
|
||||
|
||||
def evaluate_diarization_batch(
    model_name: str,
    configs: list[dict],
    audio_paths: list[str],
) -> list[dict]:
    """Evaluate diarization for a batch of configs using the same model.

    In production, this loads the model once and iterates over configs.
    Currently returns simulated results.

    Args:
        model_name: name of the diarization model; must be a key of
            ``DIARIZATION_BASE_DER`` and ``DIARIZATION_BASE_TIME``
        configs: list of dicts with ``min_speech_duration``,
            ``clustering_threshold``, ``vad_threshold``
        audio_paths: paths to audio files (unused in simulation)

    Returns:
        list of dicts with ``der`` (%) and ``time`` (minutes), one per config
        and in the same order as ``configs``

    Raises:
        KeyError: if ``model_name`` is unknown, a config lacks a required key,
            or a config value has no entry in its lookup table
    """
    try:
        base_der = DIARIZATION_BASE_DER[model_name]
        base_time = DIARIZATION_BASE_TIME[model_name]
    except KeyError:
        raise KeyError(f"unknown diarization model: {model_name!r}") from None

    results = []
    for cfg in configs:
        msd = cfg["min_speech_duration"]
        ct = cfg["clustering_threshold"]
        vad = cfg["vad_threshold"]

        der = base_der
        der += MIN_SPEECH_DER_DELTA[msd]
        der += CLUSTERING_DER_DELTA[ct]
        der += VAD_DER_DELTA[vad]

        # Interaction penalties for combinations at the extremes of the grid.
        if vad <= 0.3 and msd <= 0.25:
            der += 1.2
        if ct >= 0.9 and msd >= 1.5:
            der += 0.8

        # Seeded jitter so results vary per config yet stay reproducible;
        # the simulated DER is floored at 5.0.
        seed_suffix = f"{model_name}_{msd}_{ct}_{vad}"
        der = max(5.0, der + _deterministic_noise(f"d_{seed_suffix}"))

        minutes = base_time + _deterministic_noise(f"dt_{seed_suffix}", 0.15)
        minutes = max(0.5, minutes)

        results.append({"der": round(der, 2), "time": round(minutes, 2)})

    return results
|
||||
Reference in New Issue
Block a user