File size: 1,313 Bytes
3859913
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""Lightweight Wav2Vec‑based English accent ID."""
from functools import lru_cache
from typing import Dict

import torch
import torchaudio
from transformers import AutoProcessor, AutoModelForAudioClassification
import numpy as np

from .config import BASE_MODEL_ID, SEGMENT_SECONDS

class AccentClassifier:
    """Accent identifier backed by a pretrained audio-classification model.

    Both the processor and the model are loaded from ``BASE_MODEL_ID``; the
    model is moved to ``device`` once at construction time so that repeated
    calls to :meth:`classify` only pay for inference.
    """

    def __init__(self, device: str | None = None):
        """Load the processor and model.

        Args:
            device: Torch device string. When omitted (or empty), CUDA is
                used if ``torch.cuda.is_available()`` reports it, else CPU.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = AutoProcessor.from_pretrained(BASE_MODEL_ID)
        self.model = AutoModelForAudioClassification.from_pretrained(BASE_MODEL_ID).to(self.device)
        # Mapping from class index to human-readable label, taken from the model config.
        self.labels = self.model.config.id2label

    def _target_sample_rate(self, fallback: int) -> int:
        """Return the sample rate the processor expects, or ``fallback``.

        ``AutoProcessor`` may hand back either a composite processor (which
        exposes ``feature_extractor``) or a bare feature extractor, so both
        shapes are probed. If neither advertises a ``sampling_rate`` the
        file's own rate is kept, i.e. no resampling happens.
        """
        extractor = getattr(self.processor, "feature_extractor", self.processor)
        rate = getattr(extractor, "sampling_rate", None)
        return int(rate) if rate else fallback

    @torch.inference_mode()
    def classify(self, wav: str) -> Dict[str, str | int]:
        """Classify the accent of the first ``SEGMENT_SECONDS`` of an audio file.

        Args:
            wav: Path to an audio file readable by ``torchaudio.load``.

        Returns:
            A dict with ``"accent"`` (the predicted label) and
            ``"confidence"`` (the softmax probability of that label as an
            integer percentage, truncated toward zero).

        Raises:
            ValueError: If the file contains no audio samples.
        """
        waveform, sr = torchaudio.load(wav)  # (channels, frames)
        if waveform.numel() == 0:
            raise ValueError(f"no audio samples in {wav!r}")

        # Mix down to mono. For single-channel input this equals squeeze(0);
        # for multi-channel input squeeze(0) would have left a 2-D array and
        # the slice below would have cut channels instead of samples.
        mono = waveform.mean(dim=0)

        # Trim at the native rate first so resampling only touches what is kept.
        # int() keeps the slice bound valid if SEGMENT_SECONDS is a float.
        mono = mono[: int(sr * SEGMENT_SECONDS)]

        # Feature extractors reject audio whose sampling rate differs from the
        # one the checkpoint was trained with, so convert when needed.
        target_sr = self._target_sample_rate(sr)
        if sr != target_sr:
            mono = torchaudio.functional.resample(mono, sr, target_sr)

        inp = self.processor(mono.numpy(), sampling_rate=target_sr, return_tensors="pt")
        inp = {k: v.to(self.device) for k, v in inp.items()}
        logits = self.model(**inp).logits[0]  # single-item batch
        probs = torch.softmax(logits, dim=-1).cpu().numpy()
        idx = int(np.argmax(probs))
        return {"accent": self.labels[idx], "confidence": int(probs[idx] * 100)}

@lru_cache(maxsize=1)
def get_classifier() -> AccentClassifier:
    """Return the shared :class:`AccentClassifier`, constructing it on first use.

    The function takes no arguments and is wrapped in ``lru_cache``, so the
    first call builds the classifier and every later call returns that same
    object. Call ``get_classifier.cache_clear()`` to force a rebuild.
    """
    classifier = AccentClassifier()
    return classifier