# Challenge_Task/app/accent_classifier.py
# Commit 3859913 by abdullah0101: "Add full application code and deps"
"""Lightweight Wav2Vec‑based English accent ID."""
from functools import lru_cache
from typing import Dict
import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForAudioClassification
import numpy as np
from .config import BASE_MODEL_ID, SEGMENT_SECONDS
class AccentClassifier:
def __init__(self, device: str | None = None):
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.processor = AutoProcessor.from_pretrained(BASE_MODEL_ID)
self.model = AutoModelForAudioClassification.from_pretrained(BASE_MODEL_ID).to(self.device)
self.labels = self.model.config.id2label
@torch.inference_mode()
def classify(self, wav: str) -> Dict[str, str | int]:
wav_arr, sr = torchaudio.load(wav)
wav_arr = wav_arr.squeeze(0).numpy()[: sr * SEGMENT_SECONDS]
inp = self.processor(wav_arr, sampling_rate=sr, return_tensors="pt")
inp = {k: v.to(self.device) for k, v in inp.items()}
logits = self.model(**inp).logits[0]
probs = torch.softmax(logits, dim=-1).cpu().numpy()
idx = int(np.argmax(probs))
return {"accent": self.labels[idx], "confidence": int(probs[idx] * 100)}
@lru_cache(maxsize=1)
def get_classifier() -> AccentClassifier:
    """Return a shared :class:`AccentClassifier`, building it on the first call.

    ``lru_cache`` memoizes this zero-argument call, so later calls return the
    same instance without reloading the model. Per the ``functools`` docs,
    concurrent first calls may each construct one before the cache is filled.
    """
    classifier = AccentClassifier()
    return classifier