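"""Hugging Face-based audio transcription helpers.

Provides the HFTranscriber class, which wraps wav2vec2, Whisper, and SpeechT5
speech-to-text models, plus a transcribe_with_hf convenience function for
transcribing audio files.
"""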
import os
import torch
import numpy as np
import librosa
from typing import Dict, Any
import warnings

# Streamlit is optional: it is only used to read secrets and show sidebar
# messages when this module runs inside a Streamlit app.
try:
    import streamlit as st
except ImportError:
    st = None
from transformers import (
AutoModelForCTC,
AutoProcessor,
pipeline,
SpeechT5Processor,
SpeechT5ForSpeechToText,
SpeechT5HifiGan
)
# Suppress specific warnings
warnings.filterwarnings("ignore", message=".*gradient_checkpointing*.")
warnings.filterwarnings("ignore", message="Using the model-agnostic default `max_length`")
warnings.filterwarnings("ignore", message="You are using the default legacy behaviour")
class HFTranscriber:
def __init__(self, model_name: str = "facebook/wav2vec2-base-960h"):
"""
Initialize the Hugging Face transcriber with a pre-trained model.
Args:
model_name (str): Name of the Hugging Face model to use for transcription.
Supported models:
- "facebook/wav2vec2-base-960h" (default)
- "openai/whisper-small"
- "microsoft/speecht5_asr"
"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model_name = model_name
self.processor = None
self.model = None
self.vocoder = None
self.is_speecht5 = "speecht5" in model_name.lower()
self.is_whisper = "whisper" in model_name.lower()
self._load_model()
def _load_model(self):
"""Load the model and processor based on the model type with authentication."""
try:
# Try to get Hugging Face token from environment
            # Try to get a Hugging Face token from the environment first,
            # then from Streamlit secrets when running inside a Streamlit app
            hf_token = os.environ.get('HUGGINGFACE_TOKEN') or os.environ.get('HF_TOKEN')
            if not hf_token and st is not None:
                try:
                    hf_token = st.secrets.get('HUGGINGFACE_TOKEN') or st.secrets.get('HF_TOKEN')
                except Exception:
                    hf_token = None
            if not hf_token and st is not None:
                st.sidebar.error(
                    "No Hugging Face token found. Using public access (rate limited). "
                    "Please add it to your environment variables as HUGGINGFACE_TOKEN or HF_TOKEN."
                )
            # Configure model loading parameters. Only `token` is passed:
            # `use_auth_token` is deprecated in recent transformers releases and
            # `device` is not a from_pretrained argument (models are moved to
            # the device explicitly after loading).
            load_kwargs = {'token': hf_token, 'local_files_only': False}
            # Remove None values so public models load without a token
            load_kwargs = {k: v for k, v in load_kwargs.items() if v is not None}
if self.is_speecht5:
# Load SpeechT5 model and processor
self.processor = SpeechT5Processor.from_pretrained(
self.model_name,
**load_kwargs
)
self.model = SpeechT5ForSpeechToText.from_pretrained(
self.model_name,
**load_kwargs
)
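                # The HiFi-GAN vocoder below is only needed for speech synthesis;
                # it is loaded for completeness but is not used by the ASR path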
self.vocoder = SpeechT5HifiGan.from_pretrained(
"microsoft/speecht5_hifigan",
**load_kwargs
)
self.model = self.model.to(self.device)
self.vocoder = self.vocoder.to(self.device)
self.model.eval()
self.vocoder.eval()
elif self.is_whisper:
# For whisper, we'll use the pipeline with the token
self.model = pipeline(
"automatic-speech-recognition",
model=self.model_name,
token=hf_token, # Pass token directly
device=0 if self.device == "cuda" else -1
)
self.processor = None # Not needed when using pipeline
else:
# Load wav2vec2 model and processor
self.processor = AutoProcessor.from_pretrained(
self.model_name,
**load_kwargs
)
self.model = AutoModelForCTC.from_pretrained(
self.model_name,
**load_kwargs
)
self.model = self.model.to(self.device)
self.model.eval()
except Exception as e:
error_msg = str(e)
if "401" in error_msg or "401" in str(e.__cause__):
raise Exception(
"Authentication failed. Please check your Hugging Face token.\n"
"1. Get your token from https://huggingface.co/settings/tokens\n"
"2. Add it to your environment variables as HUGGINGFACE_TOKEN"
) from e
elif "404" in error_msg:
raise Exception(
f"Model {self.model_name} not found. Please check the model name."
) from e
else:
raise Exception(
f"Failed to load model {self.model_name}: {error_msg}"
) from e
def transcribe_audio(self, audio_array: np.ndarray, sample_rate: int) -> Dict[str, Any]:
"""
Transcribe audio data to text using the loaded Hugging Face model.
Args:
audio_array (np.ndarray): Audio data as a numpy array
sample_rate (int): Sample rate of the audio data
        Returns:
            dict: Dictionary containing 'text' and 'model'
"""
try:
if self.is_speecht5:
return self._transcribe_speecht5(audio_array, sample_rate)
elif self.is_whisper:
return self._transcribe_whisper(audio_array, sample_rate)
else:
return self._transcribe_wav2vec2(audio_array, sample_rate)
except Exception as e:
raise Exception(f"Transcription failed: {str(e)}") from e
    def _transcribe_speecht5(self, audio_array: np.ndarray, sample_rate: int) -> Dict[str, Any]:
        """Transcribe audio using SpeechT5 model."""
        # SpeechT5 expects 16 kHz input; resample if needed
        if sample_rate != 16000:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16000)
            sample_rate = 16000
        inputs = self.processor(
            audio=audio_array,
            sampling_rate=sample_rate,
            return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            # Speaker embeddings are only used for speech synthesis, so they
            # are not passed to the speech-to-text generate call
            outputs = self.model.generate(
                input_values=inputs.input_values,
                return_dict_in_generate=True
            )
# Decode the predicted text
predicted_text = self.processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
return {
'text': predicted_text,
'model': self.model_name
}
def _transcribe_whisper(self, audio_array: np.ndarray, sample_rate: int) -> Dict[str, Any]:
"""Transcribe audio using Whisper model."""
result = self.model({
"raw": audio_array,
"sampling_rate": sample_rate
})
return {
'text': result['text'],
'model': self.model_name
}
def _transcribe_wav2vec2(self, audio_array: np.ndarray, sample_rate: int) -> Dict[str, Any]:
"""Transcribe audio using wav2vec2 model."""
# Resample if needed
if sample_rate != 16000:
audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16000)
sample_rate = 16000
# Process the audio
inputs = self.processor(
audio_array,
sampling_rate=sample_rate,
return_tensors="pt",
padding=True
).to(self.device)
with torch.no_grad():
logits = self.model(inputs.input_values).logits
# Get the predicted token ids
predicted_ids = torch.argmax(logits, dim=-1)
# Decode the token ids to text
transcription = self.processor.batch_decode(predicted_ids)[0]
return {
'text': transcription,
'model': self.model_name
}
def transcribe_with_hf(audio_path: str, model_name: str = "openai/whisper-tiny") -> Dict[str, Any]:
"""
Convenience function to transcribe audio using a Hugging Face model.
Args:
audio_path (str): Path to the audio file
model_name (str): Name of the Hugging Face model to use
Returns:
dict: Dictionary containing transcription results
"""
    try:
        transcriber = HFTranscriber(model_name=model_name)
        # transcribe_audio expects a raw waveform and sample rate, so load the
        # file first rather than passing the path directly
        audio_array, sample_rate = librosa.load(audio_path, sr=16000)
        return transcriber.transcribe_audio(audio_array, sample_rate)
except Exception as e:
return {
'error': str(e),
'model': model_name
}
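

if __name__ == "__main__":
    # Minimal usage sketch. "sample.wav" is a placeholder path and the model
    # name is only a suggestion; this assumes the audio file exists locally
    # and that the model can be downloaded from the Hugging Face Hub.
    result = transcribe_with_hf("sample.wav", model_name="openai/whisper-tiny")
    if 'error' in result:
        print(f"Transcription failed: {result['error']}")
    else:
        print(f"[{result['model']}] {result['text']}")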