# Captured from a Hugging Face Space page (Space status at capture time: Sleeping).
| import os | |
| import torch | |
| import torchaudio | |
| import spaces # Import spaces module for Zero-GPU | |
| from transformers import WhisperProcessor, WhisperForConditionalGeneration | |
# Short size name -> Hugging Face Hub checkpoint id.
WHISPER_MODEL_SIZES = {
    "tiny": "openai/whisper-tiny",
    "base": "openai/whisper-base",
    "small": "openai/whisper-small",
    "medium": "openai/whisper-medium",
    "large": "openai/whisper-large-v3",
}

# Module-level model/processor slots. They start empty and are filled in by
# transcribe_audio() the first time it runs.
whisper_model = whisper_processor = None

# Ensure the "transcriptions" directory exists (relative to the current
# working directory); a no-op when it is already there.
os.makedirs("transcriptions", exist_ok=True)
# NOTE(review): `spaces` is imported for Zero-GPU, but no `@spaces.GPU`
# decorator is applied to transcribe_audio. Confirm that the decorator
# supports coroutine functions before adding it.

# Sampling rate the audio is resampled to before feature extraction.
_WHISPER_SAMPLE_RATE = 16000


def _ensure_whisper_loaded(model_id, model_size):
    """Make the module-level model/processor pair correspond to *model_id*.

    Does nothing when the cached pair already matches. Otherwise both objects
    are loaded into locals first and only published to the globals once both
    loads succeeded, so a failure halfway through cannot leave a processor of
    one checkpoint paired with the model of another.
    """
    global whisper_model, whisper_processor

    cache_is_current = (
        whisper_model is not None
        and whisper_processor is not None
        and whisper_model.config.name_or_path == model_id
    )
    if cache_is_current:
        return

    print(f"Loading Whisper {model_size} model...")
    processor = WhisperProcessor.from_pretrained(model_id)
    model = WhisperForConditionalGeneration.from_pretrained(model_id)
    # Swap both globals together, after everything that can raise has run.
    whisper_processor, whisper_model = processor, model
    print(f"Model loaded on device: {whisper_model.device}")


def _load_audio_16k_mono(audio_file_path):
    """Load an audio file and return it as a 1-D, 16 kHz, mono tensor."""
    waveform, sample_rate = torchaudio.load(audio_file_path)

    # Down-mix multi-channel audio by averaging the channels.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if sample_rate != _WHISPER_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(sample_rate, _WHISPER_SAMPLE_RATE)
        waveform = resampler(waveform)

    # Drop only the channel axis; a bare squeeze() would also collapse a
    # one-sample clip into a 0-d tensor.
    return waveform.squeeze(0)


async def transcribe_audio(audio_file_path, model_size="base", language="en") -> str:
    """Transcribe an audio file with a Whisper checkpoint.

    The model and processor are cached in module-level globals and reloaded
    only when a different checkpoint is requested.

    Args:
        audio_file_path: Audio source handed to ``torchaudio.load``.
        model_size: Key of ``WHISPER_MODEL_SIZES`` (case-insensitive). Unknown
            or empty values fall back to ``'base'``.
        language: Language passed to the processor to build the forced decoder
            prompt for the "transcribe" task. A falsy value sends no forced
            prompt, so ``generate`` runs with the model's own defaults.

    Returns:
        The decoded text of the first (only) sequence, or ``""`` if anything
        fails; the error is printed rather than raised.

    NOTE(review): the processor is called with its default arguments; Whisper's
    feature extractor is expected to pad/truncate to a fixed 30-second window,
    so longer recordings may be cut off. Confirm before relying on this for
    long audio.
    NOTE(review): the body is synchronous and blocks the event loop while it
    runs; the coroutine signature is kept for existing callers.
    """
    try:
        size_key = (model_size or "base").lower()
        model_id = WHISPER_MODEL_SIZES.get(size_key, WHISPER_MODEL_SIZES['base'])
        _ensure_whisper_loaded(model_id, model_size)

        waveform = _load_audio_16k_mono(audio_file_path)

        input_features = whisper_processor(
            waveform.numpy(),
            sampling_rate=_WHISPER_SAMPLE_RATE,
            return_tensors="pt",
        ).input_features
        # Keep the inputs on the same device/dtype as the weights so generate()
        # also works when the model is not on CPU in float32.
        input_features = input_features.to(
            device=whisper_model.device, dtype=whisper_model.dtype
        )

        generation_kwargs = {}
        if language:
            generation_kwargs["forced_decoder_ids"] = whisper_processor.get_decoder_prompt_ids(
                language=language, task="transcribe"
            )

        with torch.no_grad():
            predicted_ids = whisper_model.generate(input_features, **generation_kwargs)

        transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)
        return transcription[0]
    except Exception as e:
        # Deliberate best-effort boundary: report the problem and return an
        # empty transcript instead of propagating to the caller.
        print(f"Error during transcription: {str(e)}")
        return ""