Spaces:
Build error
Build error
"""Lightweight Wav2Vec‑based English accent ID.""" | |
from functools import lru_cache | |
from typing import Dict | |
import torch | |
import torchaudio | |
from transformers import AutoProcessor, AutoModelForAudioClassification | |
import numpy as np | |
from .config import BASE_MODEL_ID, SEGMENT_SECONDS | |
class AccentClassifier:
    """Classify the English accent of a speech recording.

    Wraps the pretrained audio-classification checkpoint named by
    ``BASE_MODEL_ID`` together with its matching processor. Both are loaded
    once, at construction time.
    """

    def __init__(self, device: str | None = None):
        """Load the processor and model.

        Args:
            device: Torch device string to run the model on. When omitted,
                ``"cuda"`` is used if available, otherwise ``"cpu"``.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = AutoProcessor.from_pretrained(BASE_MODEL_ID)
        self.model = AutoModelForAudioClassification.from_pretrained(
            BASE_MODEL_ID
        ).to(self.device)
        # Maps a class index (position in the logits) to its label.
        self.labels = self.model.config.id2label

    def _target_sample_rate(self) -> int | None:
        """Return the sample rate the processor declares, if it exposes one."""
        # AutoProcessor may hand back either a combined processor (which
        # carries a ``feature_extractor``) or the bare feature extractor.
        extractor = getattr(self.processor, "feature_extractor", self.processor)
        return getattr(extractor, "sampling_rate", None)

    def classify(self, wav: str) -> Dict[str, str | int]:
        """Predict the accent for the first ``SEGMENT_SECONDS`` of an audio file.

        Args:
            wav: Path to an audio file readable by ``torchaudio.load``.

        Returns:
            A dict with ``"accent"`` (the predicted label) and
            ``"confidence"`` (its softmax probability as a whole-number
            percentage, truncated).

        Raises:
            ValueError: If the file contains no audio samples.
        """
        waveform, sample_rate = torchaudio.load(wav)

        # torchaudio returns (channels, frames). Mix down to mono so that
        # multi-channel files yield a 1-D signal; for a single channel the
        # mean is the channel itself.
        if waveform.dim() > 1:
            waveform = waveform.mean(dim=0)

        # Keep only the leading segment. int() guards against a fractional
        # SEGMENT_SECONDS, which would otherwise be an invalid slice bound.
        waveform = waveform[: int(sample_rate * SEGMENT_SECONDS)]
        if waveform.numel() == 0:
            raise ValueError(f"No audio samples found in {wav!r}")

        # Bring the audio to the rate the processor expects; feature
        # extractors reject a mismatching ``sampling_rate``.
        target_rate = self._target_sample_rate()
        if target_rate is not None and sample_rate != target_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, target_rate
            )
            sample_rate = target_rate

        inputs = self.processor(
            waveform.numpy(), sampling_rate=sample_rate, return_tensors="pt"
        )
        inputs = {key: value.to(self.device) for key, value in inputs.items()}

        # Inference only: without no_grad the logits carry autograd history
        # and cannot be converted to NumPy.
        with torch.no_grad():
            logits = self.model(**inputs).logits[0]

        probs = torch.softmax(logits, dim=-1).cpu().numpy()
        idx = int(np.argmax(probs))
        return {"accent": self.labels[idx], "confidence": int(probs[idx] * 100)}
@lru_cache(maxsize=1)
def get_classifier() -> AccentClassifier:
    """Return the shared ``AccentClassifier`` instance.

    The classifier is built on the first call and cached, so repeated calls
    reuse the already-loaded processor and model instead of loading them
    again.
    """
    return AccentClassifier()