# digiPal/models/stt_processor.py
import torch
import numpy as np
import librosa
from typing import Union
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor

class KyutaiSTTProcessor:
    """Processor for the Kyutai speech-to-text model."""

    def __init__(self, device: str = "cuda"):
        self.device = device if torch.cuda.is_available() else "cpu"
        self.model = None
        self.processor = None
        self.model_id = "kyutai/stt-2.6b-en"  # English-only model for better accuracy

        # Audio processing parameters
        self.sample_rate = 16000
        self.chunk_length_s = 30  # Process in 30-second chunks
        self.max_duration = 120   # Cap input at 2 minutes of audio

    def load_model(self):
        """Lazily load the STT model and its processor."""
        if self.model is None:
            try:
                # Load processor and model
                self.processor = AutoProcessor.from_pretrained(self.model_id)

                # Use half precision on GPU to reduce VRAM usage
                torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
                self.model = AutoModelForSpeechSeq2Seq.from_pretrained(
                    self.model_id,
                    torch_dtype=torch_dtype,
                    low_cpu_mem_usage=True,
                    use_safetensors=True
                )
                self.model.to(self.device)

                # Whisper-style generation settings; models that do not
                # use these attributes simply ignore them
                self.model.generation_config.language = "english"
                self.model.generation_config.task = "transcribe"
                self.model.generation_config.forced_decoder_ids = None
            except Exception as e:
                print(f"Failed to load STT model: {e}")
                raise

    def preprocess_audio(self, audio_path: str) -> np.ndarray:
        """Load and preprocess an audio file for transcription."""
        try:
            # Load audio as mono, preserving the original sample rate
            audio, sr = librosa.load(audio_path, sr=None, mono=True)

            # Resample to the model's expected rate if necessary
            if sr != self.sample_rate:
                audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)

            # Truncate to the maximum supported duration
            max_samples = self.max_duration * self.sample_rate
            if len(audio) > max_samples:
                audio = audio[:max_samples]

            # Peak-normalize; the epsilon guards against division by zero
            audio = audio / (np.max(np.abs(audio)) + 1e-7)

            return audio
        except Exception as e:
            print(f"Error preprocessing audio: {e}")
            raise

    def transcribe(self, audio_input: Union[str, np.ndarray]) -> str:
        """Transcribe audio to text."""
        try:
            # Load model if not already loaded
            self.load_model()

            # Accept either a file path or a preprocessed waveform
            if isinstance(audio_input, str):
                audio = self.preprocess_audio(audio_input)
            else:
                audio = audio_input

            # Extract features, then match the model's device and dtype
            # (features default to float32, but the model may be float16)
            inputs = self.processor(
                audio,
                sampling_rate=self.sample_rate,
                return_tensors="pt"
            )
            input_features = inputs["input_features"].to(
                self.device, dtype=self.model.dtype
            )

            # Generate transcription
            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_features,
                    max_new_tokens=128,
                    do_sample=False,
                    num_beams=1  # Greedy decoding for speed
                )

            # Decode transcription
            transcription = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )[0]

            # Clean up transcription
            return self._clean_transcription(transcription)
        except Exception as e:
            print(f"Transcription error: {e}")
            # Fall back to a default description on error
            return "Create a unique digital monster companion"
    def _clean_transcription(self, text: str) -> str:
        """Clean up transcription output."""
        # Collapse extra whitespace
        text = " ".join(text.split())

        # Capitalize the first character
        if text and text[0].islower():
            text = text[0].upper() + text[1:]

        # Add a terminal period if missing
        if text and text[-1] not in ".!?":
            text += "."

        return text

    def transcribe_streaming(self, audio_stream):
        """Streaming transcription (reserved for future implementation)."""
        # This would handle real-time audio streams; for now it is a stub
        raise NotImplementedError("Streaming transcription not yet implemented")
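
    # One possible streaming shape (an assumption, not implemented here):
    # buffer incoming samples and, roughly every chunk_length_s seconds,
    # call transcribe() on the accumulated chunk to emit partial text.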

    def to(self, device: str):
        """Move the model to the specified device."""
        self.device = device
        if self.model is not None:
            self.model.to(device)

    def __del__(self):
        """Release model resources when the object is destroyed."""
        # Use getattr: __del__ can run after a failed __init__, and module
        # globals like torch may already be gone at interpreter shutdown
        if getattr(self, "model", None) is not None:
            del self.model
        if getattr(self, "processor", None) is not None:
            del self.processor
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass
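

if __name__ == "__main__":
    # Minimal usage sketch / smoke test. "sample.wav" is an illustrative
    # path, not a file shipped with this repository.
    stt = KyutaiSTTProcessor(device="cuda")
    print(stt.transcribe("sample.wav"))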