import asyncio
import gc
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List

import librosa
import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
# Initialize thread pool for background processing
thread_pool = ThreadPoolExecutor(max_workers=2)

# Environment and model configuration
MODEL_NAME = "nyrahealth/CrisperWhisper"
BATCH_SIZE = 8
FILE_LIMIT_MB = 30
FILE_EXTENSIONS = [".mp3", ".wav", ".m4a", ".ogg", ".flac"]
# Initialize FastAPI app
app = FastAPI(
    title="Speech to Text API",
    description="API for transcribing audio files using the CrisperWhisper model",
    version="1.0.0"
)

# Add CORS support
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Response models
class TranscriptionChunk(BaseModel):
    timestamp: List[float]
    text: str


class TranscriptionResponse(BaseModel):
    text: str
    chunks: List[TranscriptionChunk]
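
# Illustrative shape of a successful response (example values only; actual
# word boundaries depend on the model output):
#
#   {
#     "text": "hello world",
#     "chunks": [
#       {"timestamp": [0.0, 0.42], "text": "hello"},
#       {"timestamp": [0.48, 0.91], "text": "world"}
#     ]
#   }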
# Setup device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load model and processor at startup
@app.on_event("startup")
async def load_model():
    global processor, model
    print("Loading model and processor...")
    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME)
    model.to(device)
    print("Model loaded successfully!")
def load_audio(file_path: str) -> tuple:
    """Load an audio file efficiently, returning (array, sampling_rate)."""
    try:
        # Load with sr=None to keep the original sampling rate,
        # then resample only if the file is not already 16 kHz
        audio_array, orig_sr = librosa.load(file_path, sr=None, mono=True)
        if orig_sr != 16000:
            audio_array = librosa.resample(audio_array, orig_sr=orig_sr, target_sr=16000)
            sampling_rate = 16000
        else:
            sampling_rate = orig_sr
        # Convert to float32 if needed
        if audio_array.dtype != np.float32:
            audio_array = audio_array.astype(np.float32)
        return audio_array, sampling_rate
    except Exception as e:
        print(f"Error loading audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error loading audio: {str(e)}")
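
# Quick sanity check for load_audio (a sketch; assumes a local sample.wav
# exists). Whatever the source format, the result should be a mono float32
# array at 16 kHz:
#
#   audio, sr = load_audio("sample.wav")
#   assert sr == 16000 and audio.dtype == np.float32 and audio.ndim == 1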
def process_audio_file(file_path: str) -> Dict:
    """Process audio file and return transcription with timestamps"""
    try:
        # Load audio file efficiently
        audio_array, sampling_rate = load_audio(file_path)
        # Process with model
        inputs = processor(audio_array, sampling_rate=sampling_rate, return_tensors="pt")
        inputs = {key: value.to(device) for key, value in inputs.items()}
        # Generate transcription with word timestamps
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                return_timestamps=True,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=256 if len(audio_array) < 160000 else 512,  # adjust based on audio length
                num_beams=1,  # use greedy decoding for speed
            )
        # Extract timestamps and words
        result = processor.decode(outputs.sequences[0], skip_special_tokens=False, output_word_offsets=True)
        words_with_timestamps = []
        for word in result.word_offsets:
            words_with_timestamps.append({
                "text": word["word"].strip(),
                "timestamp": [
                    round(word["start_offset"] / sampling_rate, 2),
                    round(word["end_offset"] / sampling_rate, 2)
                ]
            })
        # Create final response format
        response_data = {
            "text": processor.decode(outputs.sequences[0], skip_special_tokens=True),
            "chunks": words_with_timestamps
        }
        # Manual garbage collection to free memory
        del inputs, outputs, result
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
        return response_data
    except Exception as e:
        print(f"Error processing audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")
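
# Notes on process_audio_file: the 160000-sample threshold in generate()
# corresponds to 10 s of audio at 16 kHz (shorter clips get a smaller
# max_new_tokens budget), and start_offset/end_offset are divided by the
# sampling rate on the assumption that the tokenizer reports offsets in samples.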
async def process_in_background(file_path: str):
    """Process audio file in a background thread to prevent blocking"""
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(thread_pool, process_audio_file, file_path)
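
# Note: max_workers=2 on the thread pool above caps concurrent transcriptions;
# raising it only helps if the device can serve more simultaneous generate()
# calls without running out of memory.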
@app.post("/transcribe", response_model=TranscriptionResponse)
async def transcribe_audio(file: UploadFile = File(...)):
    """
    Transcribe an audio file to text with timestamps for each word.
    Accepts .mp3, .wav, .m4a, .ogg or .flac files up to 30MB.
    """
    start_time = time.time()
    # Validate file extension
    file_ext = os.path.splitext(file.filename)[1].lower()
    if file_ext not in FILE_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file format. Supported formats: {', '.join(FILE_EXTENSIONS)}"
        )
    # Read the upload and check its size before touching the filesystem,
    # so an oversized file never leaves an orphaned temp file behind
    content = await file.read()
    if len(content) > FILE_LIMIT_MB * 1024 * 1024:
        raise HTTPException(
            status_code=400,
            detail=f"File too large. Maximum size: {FILE_LIMIT_MB}MB"
        )
    # Write to a temp file (delete=False so the worker thread can read it back)
    with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
        temp_file.write(content)
        temp_file_path = temp_file.name
    try:
        # Process the audio file in background to prevent blocking
        result = await process_in_background(temp_file_path)
        processing_time = time.time() - start_time
        print(f"Processing completed in {processing_time:.2f} seconds")
        return JSONResponse(content=result)
    finally:
        # Clean up the temp file
        if os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except Exception as e:
                print(f"Error deleting temp file: {e}")
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy"}
# Simple root endpoint that shows the API is running
@app.get("/")
async def root():
    return {
        "message": "Speech-to-Text API is running",
        "endpoints": {
            "transcribe": "/transcribe (POST)",
            "health": "/health (GET)",
            "docs": "/docs (GET)"
        },
        "model": MODEL_NAME,
        "device": device
    }
if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port)
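
# Example Python client (a sketch, not part of the app; assumes the `requests`
# package is installed and a local sample.wav exists):
#
#   import requests
#   with open("sample.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/transcribe",
#           files={"file": ("sample.wav", f, "audio/wav")},
#       )
#   resp.raise_for_status()
#   print(resp.json()["text"])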