"""Lightweight Wav2Vec‑based English accent ID."""

from functools import lru_cache
from typing import Dict

import numpy as np
import torch
import torchaudio
from transformers import AutoModelForAudioClassification, AutoProcessor

from .config import BASE_MODEL_ID, SEGMENT_SECONDS


class AccentClassifier:
    """Predict an accent label for a speech recording.

    Wraps the audio-classification checkpoint named by ``BASE_MODEL_ID`` together
    with its matching processor.
    """

    def __init__(self, device: str | None = None):
        """Load the processor and model for ``BASE_MODEL_ID``.

        Args:
            device: Torch device string to run the model on. Defaults to
                ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = AutoProcessor.from_pretrained(BASE_MODEL_ID)
        self.model = AutoModelForAudioClassification.from_pretrained(BASE_MODEL_ID).to(self.device)
        # Class index -> label string, taken from the checkpoint config.
        self.labels = self.model.config.id2label
        # Sample rate the feature extractor was built for. AutoProcessor may return
        # either a processor wrapping a feature extractor or the extractor itself,
        # so look in both places; None means "unknown, do not resample".
        extractor = getattr(self.processor, "feature_extractor", self.processor)
        self.target_sr: int | None = getattr(extractor, "sampling_rate", None)

    @torch.inference_mode()
    def classify(self, wav: str) -> Dict[str, str | int]:
        """Classify the accent in the audio file at *wav*.

        Only the first ``SEGMENT_SECONDS`` seconds of audio are used.

        Args:
            wav: Path to an audio file readable by ``torchaudio.load``.

        Returns:
            ``{"accent": <label>, "confidence": <0-100>}``, where confidence is
            the top class's softmax probability times 100, truncated to an int.
        """
        waveform, sr = torchaudio.load(wav)  # shape: (channels, frames)
        # Downmix to mono. squeeze(0) alone leaves stereo as (2, frames), and the
        # slice below would then cut channels instead of samples.
        waveform = waveform.mean(dim=0)
        # Trim before resampling so we never resample audio we are going to drop.
        # int() keeps the slice valid if SEGMENT_SECONDS is configured as a float.
        waveform = waveform[: int(sr * SEGMENT_SECONDS)]
        target_sr = self.target_sr or sr
        if sr != target_sr:
            # The feature extractor expects its own fixed rate; feeding native-rate
            # audio is rejected or silently misinterpreted.
            waveform = torchaudio.functional.resample(waveform, sr, target_sr)
            sr = target_sr
        wav_arr = waveform.numpy()

        inp = self.processor(wav_arr, sampling_rate=sr, return_tensors="pt")
        inp = {k: v.to(self.device) for k, v in inp.items()}
        logits = self.model(**inp).logits[0]  # single-item batch
        probs = torch.softmax(logits, dim=-1).cpu().numpy()
        idx = int(np.argmax(probs))
        return {"accent": self.labels[idx], "confidence": int(probs[idx] * 100)}


@lru_cache(maxsize=1)
def get_classifier() -> AccentClassifier:
    """Return a cached ``AccentClassifier`` so the model is loaded only once."""
    return AccentClassifier()