# evaluator.py
import os
from collections import defaultdict

import Levenshtein
import numpy as np
from datasets import load_dataset
from rouge_score import rouge_scorer
from sacrebleu.metrics import BLEU, CHRF
from tqdm import tqdm
from transformers import pipeline
from transformers.models.whisper.english_normalizer import BasicTextNormalizer
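
# Assumed dependencies (not pinned in the original): numpy, datasets,
# transformers, sacrebleu, rouge-score, tqdm, and the rapidfuzz-backed
# `Levenshtein` package (needed for edit distance over word lists below).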


def calculate_metrics(reference: str, prediction: str) -> dict:
    """
    Compute a suite of translation / generation metrics:
      - BLEU
      - chrF
      - CER (character error rate)
      - WER (word error rate)
      - length ratio
      - ROUGE-1 & ROUGE-L
      - a combined quality_score
    """
    # BLEU (sentence-level; effective_order avoids zero scores on short segments)
    bleu = BLEU(effective_order=True)
    bleu_score = bleu.sentence_score(prediction, [reference]).score

    # chrF, rescaled from 0-100 to 0-1
    chrf = CHRF()
    chrf_score = chrf.sentence_score(prediction, [reference]).score / 100.0

    # Character error rate: edit distance normalized by reference length
    cer = Levenshtein.distance(reference, prediction) / max(len(reference), 1)

    # Word error rate: edit distance over word lists (requires the
    # rapidfuzz-backed `Levenshtein` package, which accepts sequences,
    # not just strings)
    ref_words = reference.split()
    pred_words = prediction.split()
    wer = Levenshtein.distance(ref_words, pred_words) / max(len(ref_words), 1)

    # Length ratio (prediction characters per reference character)
    len_ratio = len(prediction) / max(len(reference), 1)

    # ROUGE F-measures; fall back to 0.0 if scoring fails
    try:
        scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
        rouge_scores = scorer.score(reference, prediction)
        rouge_1 = rouge_scores["rouge1"].fmeasure
        rouge_L = rouge_scores["rougeL"].fmeasure
    except Exception:
        rouge_1 = rouge_L = 0.0
    # Combined quality: unweighted mean of six components, each on a roughly
    # 0-1 scale (rouge_1 and rouge_L are always defined above, so no
    # four-metric fallback is needed)
    quality_score = (
        (bleu_score / 100)
        + chrf_score
        + (1 - cer)
        + (1 - wer)
        + rouge_1
        + rouge_L
    ) / 6
    return {
        "bleu": bleu_score,
        "chrf": chrf_score,
        "cer": cer,
        "wer": wer,
        "len_ratio": len_ratio,
        "rouge1": rouge_1,
        "rougeL": rouge_L,
        "quality_score": quality_score,
    }
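
# Illustrative usage of calculate_metrics (a sketch; exact values depend on
# the installed sacrebleu / rouge-score versions):
#
#     m = calculate_metrics(
#         reference="the cat sat on the mat",
#         prediction="the cat sat on a mat",
#     )
#     m["bleu"]           # 0-100 scale
#     m["chrf"]           # 0-1 scale
#     m["quality_score"]  # mean of six components; not strictly bounded,
#                         # since 1 - cer goes negative for very long outputs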

def evaluate_model(
    model_name: str,
    dataset_name: str,
    split: str = "test",
    text_field: str = "source",
    target_field: str = "target",
    task: str = "translation",  # or "automatic-speech-recognition", etc.
    device: int = 0,
) -> dict:
    """
    Load the dataset, run inference via a 🤗 pipeline, and compute metrics
    grouped by language pair (if present) plus overall averages.

    Returns a dict of shape:
        {
            "<src>_to_<tgt>": {<metric1>: val, ...},
            ...,
            "averages": {<metric1>: val, ...}
        }
    """
    # Get the Hugging Face token from the environment
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise ValueError(
            "Hugging Face token (HF_TOKEN) is not set. "
            "Please set it as an environment variable."
        )

    # 1) load the evaluation split (`token` replaces the deprecated
    #    `use_auth_token` argument in recent versions of datasets)
    ds = load_dataset(dataset_name, split=split, token=hf_token)

    # 2) build the inference pipeline
    nlp = pipeline(task, model=model_name, device=device)
    # 3) run inference, calling the pipeline once per example
    normalizer = BasicTextNormalizer()
    translations = []
    for ex in tqdm(ds, desc=f"Eval {model_name}"):
        src = ex[text_field]
        tgt = ex[target_field]
        out = nlp(src)[0]
        pred = out.get("translation_text", out.get("text", ""))
        translations.append({
            "source": src,
            "target": tgt,
            "prediction": pred,
            # Optional language metadata (column names follow the dataset schema):
            "source.language": ex.get("source.language", ""),
            "target.language": ex.get("target.language", ""),
        })
    # 4) group by language pair
    subsets = defaultdict(list)
    for ex in translations:
        key = (
            f"{ex['source.language']}_to_{ex['target.language']}"
            if ex["source.language"] and ex["target.language"]
            else "default"
        )
        subsets[key].append(ex)
    # 5) compute metrics per subset, normalizing both sides before scoring
    results = {}
    for subset, examples in subsets.items():
        agg = defaultdict(list)
        for ex in examples:
            ref = normalizer(ex["target"])
            pred = normalizer(ex["prediction"])
            m = calculate_metrics(ref, pred)
            for k, v in m.items():
                agg[k].append(v)
        # mean of each metric over the subset
        results[subset] = {k: float(np.mean(vs)) for k, vs in agg.items()}
    # 6) overall averages: an unweighted (macro) average across subsets,
    #    so every language pair contributes equally regardless of size
    all_metrics = list(results.values())
    avg = {}
    for k in all_metrics[0].keys():
        avg[k] = float(np.mean([m[k] for m in all_metrics]))
    results["averages"] = avg

    return results
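
# Reading the results (illustrative; "eng_to_lug" is a hypothetical key whose
# exact form depends on the language metadata present in the dataset):
#
#     results = evaluate_model("some/model", "some/dataset")
#     results["eng_to_lug"]["bleu"]         # mean sentence BLEU for one pair
#     results["averages"]["quality_score"]  # macro-average across all pairs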

if __name__ == "__main__":
    # simple smoke test (note: the Hub's "wmt19" dataset expects a
    # language-pair config such as "de-en" and nests text in a
    # "translation" column, so text_field / target_field may need
    # adjusting for that schema)
    import json

    out = evaluate_model(
        model_name="facebook/wmt19-en-de",
        dataset_name="wmt19",
        split="test",
    )
    print(json.dumps(out, indent=2))
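
# To run the smoke test from a shell (assumes HF_TOKEN grants read access
# to the dataset):
#
#     HF_TOKEN=hf_... python evaluator.py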