import torch
import numpy as np
from typing import Union

import librosa
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor


class KyutaiSTTProcessor:
    """Processor for the Kyutai speech-to-text model."""

    def __init__(self, device: str = "cuda"):
        self.device = device if torch.cuda.is_available() else "cpu"
        # fp16 on GPU keeps VRAM usage low; fp32 on CPU
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.model = None
        self.processor = None
        self.model_id = "kyutai/stt-2.6b-en"  # English-only model for better accuracy

        # Audio processing parameters
        self.sample_rate = 16000
        self.chunk_length_s = 30  # Process in 30-second chunks
        self.max_duration = 120   # Maximum 2 minutes of audio

    def load_model(self):
        """Lazily load the STT model on first use."""
        if self.model is None:
            try:
                # Load processor and model
                self.processor = AutoProcessor.from_pretrained(self.model_id)

                self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                    self.model_id,
                    torch_dtype=self.torch_dtype,
                    low_cpu_mem_usage=True,
                    use_safetensors=True,
                )
                self.model.to(self.device)

                # Generation settings for plain English transcription
                self.model.generation_config.language = "english"
                self.model.generation_config.task = "transcribe"
                self.model.generation_config.forced_decoder_ids = None

            except Exception as e:
                print(f"Failed to load STT model: {e}")
                raise

    def preprocess_audio(self, audio_path: str) -> np.ndarray:
        """Load and preprocess an audio file for transcription."""
        try:
            # Load audio as mono, keeping the native sample rate
            audio, sr = librosa.load(audio_path, sr=None, mono=True)

            # Resample if necessary
            if sr != self.sample_rate:
                audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)

            # Truncate to the maximum allowed duration
            max_samples = self.max_duration * self.sample_rate
            if len(audio) > max_samples:
                audio = audio[:max_samples]

            # Peak-normalize to [-1, 1]; the epsilon guards against silent input
            audio = audio / (np.max(np.abs(audio)) + 1e-7)

            return audio

        except Exception as e:
            print(f"Error preprocessing audio: {e}")
            raise

    def transcribe(self, audio_input: Union[str, np.ndarray]) -> str:
        """Transcribe audio to text."""
        try:
            # Load model if not already loaded
            self.load_model()

            # Accept either a file path or a preprocessed waveform
            if isinstance(audio_input, str):
                audio = self.preprocess_audio(audio_input)
            else:
                audio = audio_input

            # Convert the waveform to model input features
            inputs = self.processor(
                audio,
                sampling_rate=self.sample_rate,
                return_tensors="pt",
            )
            # Match the model's device and dtype (fp16 on GPU) to avoid mismatches
            input_features = inputs["input_features"].to(self.device, self.torch_dtype)

            # Generate transcription
            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_features,
                    max_new_tokens=128,
                    do_sample=False,
                    num_beams=1,  # Greedy decoding for speed
                )

            # Decode token ids back to text
            transcription = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )[0]

            return self._clean_transcription(transcription)

        except Exception as e:
            print(f"Transcription error: {e}")
            # Fall back to a default description on error
            return "Create a unique digital monster companion"

    def _clean_transcription(self, text: str) -> str:
        """Clean up transcription output."""
        # Collapse repeated whitespace
        text = " ".join(text.split())

        # Capitalize the first letter
        if text and text[0].islower():
            text = text[0].upper() + text[1:]

        # Add terminal punctuation if missing
        if text and text[-1] not in ".!?":
            text += "."
        return text

    def transcribe_streaming(self, audio_stream):
        """Streaming transcription (not yet implemented)."""
        # This would handle real-time audio streams; raise until it exists
        raise NotImplementedError("Streaming transcription not yet implemented")

    def to(self, device: str):
        """Move the model to the specified device."""
        self.device = device
        if self.model is not None:
            self.model.to(device)
        return self

    def __del__(self):
        """Release model resources when the object is destroyed."""
        if getattr(self, "model", None) is not None:
            del self.model
        if getattr(self, "processor", None) is not None:
            del self.processor
        try:
            torch.cuda.empty_cache()
        except Exception:
            # torch may already be torn down during interpreter shutdown
            pass
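

# A minimal usage sketch (an illustration added here, not part of the original
# module): "prompt.wav" is a placeholder path, and the first transcribe() call
# downloads the model weights from the Hugging Face Hub.
if __name__ == "__main__":
    stt = KyutaiSTTProcessor(device="cuda")  # falls back to CPU if CUDA is absent
    print(stt.transcribe("prompt.wav"))

    # A raw waveform also works, as long as it is 16 kHz mono float32
    waveform = np.zeros(16000, dtype=np.float32)  # one second of silence
    print(stt.transcribe(waveform))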