import os
import torch
import numpy as np
import librosa
import warnings
import streamlit as st
from typing import Dict, Any
from transformers import (
    AutoModelForCTC,
    AutoProcessor,
    pipeline,
    SpeechT5Processor,
    SpeechT5ForSpeechToText,
    SpeechT5HifiGan
)

# Suppress specific warnings
warnings.filterwarnings("ignore", message=".*gradient_checkpointing.*")
warnings.filterwarnings("ignore", message="Using the model-agnostic default `max_length`")
warnings.filterwarnings("ignore", message="You are using the default legacy behaviour")
class HFTranscriber:
    def __init__(self, model_name: str = "facebook/wav2vec2-base-960h"):
        """
        Initialize the Hugging Face transcriber with a pre-trained model.

        Args:
            model_name (str): Name of the Hugging Face model to use for transcription.
                Supported models:
                - "facebook/wav2vec2-base-960h" (default)
                - "openai/whisper-small"
                - "microsoft/speecht5_asr"
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name
        self.processor = None
        self.model = None
        self.vocoder = None
        self.is_speecht5 = "speecht5" in model_name.lower()
        self.is_whisper = "whisper" in model_name.lower()
        self._load_model()
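
    # Construction sketch (hedged: assumes the checkpoints listed in the
    # docstring are reachable on the Hub; weights download on first use):
    #
    #   asr = HFTranscriber("facebook/wav2vec2-base-960h")  # CTC branch
    #   asr = HFTranscriber("openai/whisper-small")         # Whisper pipeline branch
    #   asr = HFTranscriber("microsoft/speecht5_asr")       # SpeechT5 branch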
    def _load_model(self):
        """Load the model and processor for the selected model type, with authentication."""
        try:
            # Look for a Hugging Face token in the environment first, then in
            # Streamlit secrets (when running as a Streamlit app).
            hf_token = os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN")
            if not hf_token:
                try:
                    hf_token = st.secrets.get("HUGGINGFACE_TOKEN") or st.secrets.get("HF_TOKEN")
                except Exception:
                    hf_token = None
            if not hf_token:
                st.sidebar.warning(
                    "No Hugging Face token found; using public access (rate limited). "
                    "Add a token to your environment variables as HUGGINGFACE_TOKEN or HF_TOKEN."
                )
            # Keyword arguments shared by the from_pretrained() calls below.
            # Only the current `token` argument is passed (`use_auth_token` is
            # deprecated); device placement is handled via .to(self.device).
            load_kwargs = {'token': hf_token, 'local_files_only': False}
            # Remove None values so a missing token falls back to public access
            load_kwargs = {k: v for k, v in load_kwargs.items() if v is not None}
            if self.is_speecht5:
                # Load SpeechT5 model and processor
                self.processor = SpeechT5Processor.from_pretrained(
                    self.model_name,
                    **load_kwargs
                )
                self.model = SpeechT5ForSpeechToText.from_pretrained(
                    self.model_name,
                    **load_kwargs
                )
                # The HiFi-GAN vocoder is only used if the app also synthesizes
                # speech; it plays no part in transcription.
                self.vocoder = SpeechT5HifiGan.from_pretrained(
                    "microsoft/speecht5_hifigan",
                    **load_kwargs
                )
                self.model = self.model.to(self.device)
                self.vocoder = self.vocoder.to(self.device)
                self.model.eval()
                self.vocoder.eval()
            elif self.is_whisper:
                # For Whisper, use the ASR pipeline, which bundles the
                # processor and handles decoding internally.
                self.model = pipeline(
                    "automatic-speech-recognition",
                    model=self.model_name,
                    token=hf_token,  # Pass token directly
                    device=0 if self.device == "cuda" else -1
                )
                self.processor = None  # Not needed when using the pipeline
            else:
                # Load wav2vec2 model and processor
                self.processor = AutoProcessor.from_pretrained(
                    self.model_name,
                    **load_kwargs
                )
                self.model = AutoModelForCTC.from_pretrained(
                    self.model_name,
                    **load_kwargs
                )
                self.model = self.model.to(self.device)
                self.model.eval()
        except Exception as e:
            error_msg = str(e)
            if "401" in error_msg or (e.__cause__ is not None and "401" in str(e.__cause__)):
                raise Exception(
                    "Authentication failed. Please check your Hugging Face token.\n"
                    "1. Get your token from https://huggingface.co/settings/tokens\n"
                    "2. Add it to your environment variables as HUGGINGFACE_TOKEN"
                ) from e
            elif "404" in error_msg:
                raise Exception(
                    f"Model {self.model_name} not found. Please check the model name."
                ) from e
            else:
                raise Exception(
                    f"Failed to load model {self.model_name}: {error_msg}"
                ) from e
    def transcribe_audio(self, audio_array: np.ndarray, sample_rate: int) -> Dict[str, Any]:
        """
        Transcribe audio data to text using the loaded Hugging Face model.

        Args:
            audio_array (np.ndarray): Audio data as a numpy array
            sample_rate (int): Sample rate of the audio data

        Returns:
            dict: Dictionary containing the transcribed 'text' and the 'model' name
        """
        try:
            if self.is_speecht5:
                return self._transcribe_speecht5(audio_array, sample_rate)
            elif self.is_whisper:
                return self._transcribe_whisper(audio_array, sample_rate)
            else:
                return self._transcribe_wav2vec2(audio_array, sample_rate)
        except Exception as e:
            raise Exception(f"Transcription failed: {str(e)}") from e
    def _transcribe_speecht5(self, audio_array: np.ndarray, sample_rate: int) -> Dict[str, Any]:
        """Transcribe audio using a SpeechT5 ASR model."""
        # SpeechT5 expects 16 kHz input; resample if needed
        if sample_rate != 16000:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16000)
            sample_rate = 16000
        inputs = self.processor(
            audio=audio_array,
            sampling_rate=sample_rate,
            return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            # Speaker embeddings are only needed for synthesis, not recognition
            outputs = self.model.generate(
                input_values=inputs.input_values,
                return_dict_in_generate=True
            )
        # Decode the predicted text
        predicted_text = self.processor.batch_decode(outputs.sequences, skip_special_tokens=True)[0]
        return {
            'text': predicted_text,
            'model': self.model_name
        }
    def _transcribe_whisper(self, audio_array: np.ndarray, sample_rate: int) -> Dict[str, Any]:
        """Transcribe audio using a Whisper model."""
        # The ASR pipeline resamples raw input to the model's expected rate internally
        result = self.model({
            "raw": audio_array,
            "sampling_rate": sample_rate
        })
        return {
            'text': result['text'],
            'model': self.model_name
        }
    def _transcribe_wav2vec2(self, audio_array: np.ndarray, sample_rate: int) -> Dict[str, Any]:
        """Transcribe audio using a wav2vec2 CTC model."""
        # Resample if needed; wav2vec2-base-960h was trained on 16 kHz audio
        if sample_rate != 16000:
            audio_array = librosa.resample(audio_array, orig_sr=sample_rate, target_sr=16000)
            sample_rate = 16000
        # Process the audio
        inputs = self.processor(
            audio_array,
            sampling_rate=sample_rate,
            return_tensors="pt",
            padding=True
        ).to(self.device)
        with torch.no_grad():
            logits = self.model(inputs.input_values).logits
        # Get the predicted token ids
        predicted_ids = torch.argmax(logits, dim=-1)
        # Decode the token ids to text
        transcription = self.processor.batch_decode(predicted_ids)[0]
        return {
            'text': transcription,
            'model': self.model_name
        }
def transcribe_with_hf(audio_path: str, model_name: str = "openai/whisper-tiny") -> Dict[str, Any]:
    """
    Convenience function to transcribe an audio file using a Hugging Face model.

    Args:
        audio_path (str): Path to the audio file
        model_name (str): Name of the Hugging Face model to use

    Returns:
        dict: Dictionary containing transcription results, or an 'error' key on failure
    """
    try:
        transcriber = HFTranscriber(model_name=model_name)
        # Load the file as a mono float32 array at its native sample rate;
        # the transcriber resamples to 16 kHz where the model requires it
        audio_array, sample_rate = librosa.load(audio_path, sr=None, mono=True)
        return transcriber.transcribe_audio(audio_array, sample_rate)
    except Exception as e:
        return {
            'error': str(e),
            'model': model_name
        }
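
if __name__ == "__main__":
    # Smoke-test sketch: "sample.wav" is a hypothetical local file, not part
    # of this project; substitute any recording on disk.
    result = transcribe_with_hf("sample.wav", model_name="openai/whisper-tiny")
    if "error" in result:
        print(f"Transcription failed: {result['error']}")
    else:
        print(f"[{result['model']}] {result['text']}")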