# src/transcription.py
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import torch
import torchaudio


class SpeechEncoder:
    def __init__(
        self,
        model_name: str = "jonatasgrosman/wav2vec2-large-xlsr-53-french",
        # model_name: str = "alec228/audio-sentiment/tree/main/wav2vec2",
        cache_dir: str = "./models",
    ):
        # Processor that preprocesses the raw audio
        self.processor = Wav2Vec2Processor.from_pretrained(
            model_name, cache_dir=cache_dir
        )
        # Base model (no CTC head), used as a feature extractor
        self.model = Wav2Vec2Model.from_pretrained(
            model_name, cache_dir=cache_dir
        )

    def extract_features(self, audio_path: str) -> torch.Tensor:
        """
        Load an audio file, resample it to 16 kHz, convert it to mono,
        and return the embedding averaged over the time dimension.
        """
        # 1. Load the audio
        waveform, sample_rate = torchaudio.load(audio_path)
        # 2. Resample if needed (wav2vec2 expects 16 kHz input)
        if sample_rate != 16000:
            waveform = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=16000
            )(waveform)
        # 3. Downmix to mono by averaging the channels
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # 4. Preprocess for the model
        inputs = self.processor(
            waveform.squeeze().numpy(),
            sampling_rate=16000,
            return_tensors="pt",
            padding=True
        )
        # 5. Forward pass without gradient tracking
        with torch.no_grad():
            outputs = self.model(**inputs)
        # 6. Mean-pool the embeddings over the time dimension
        return outputs.last_hidden_state.mean(dim=1)  # shape: [batch, hidden_size]
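

# Minimal usage sketch, assuming an audio file exists on disk; the path
# "sample.wav" is illustrative, not part of the repo. For the large
# XLSR-53 model the hidden size is 1024, so a single file yields a
# [1, 1024] tensor.
if __name__ == "__main__":
    encoder = SpeechEncoder()
    embedding = encoder.extract_features("sample.wav")
    print(embedding.shape)  # e.g. torch.Size([1, 1024])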