import os
import tempfile
import time
import asyncio
from typing import List, Dict
from concurrent.futures import ThreadPoolExecutor

import torch
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq, pipeline
import librosa
import numpy as np
import gc

# Initialize thread pool for background processing
thread_pool = ThreadPoolExecutor(max_workers=2)

# Environment and model configuration
MODEL_NAME = "nyrahealth/CrisperWhisper"
BATCH_SIZE = 8
FILE_LIMIT_MB = 30
FILE_EXTENSIONS = [".mp3", ".wav", ".m4a", ".ogg", ".flac"]

# Initialize FastAPI app
app = FastAPI(
    title="Speech to Text API",
    description="API for transcribing audio files using the CrisperWhisper model",
    version="1.0.0"
)

# Add CORS support
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Response models
class TranscriptionChunk(BaseModel):
    timestamp: List[float]
    text: str

class TranscriptionResponse(BaseModel):
    text: str
    chunks: List[TranscriptionChunk]

# Setup device and precision
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
print(f"Using device: {device}")

# Load model, processor, and ASR pipeline at startup
@app.on_event("startup")
async def load_model():
    global asr_pipeline
    print("Loading model and processor...")
    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_NAME, torch_dtype=torch_dtype)
    # Whisper-family models produce word-level timestamps through the ASR
    # pipeline when return_timestamps="word" is set
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        chunk_length_s=30,
        batch_size=BATCH_SIZE,
        return_timestamps="word",
        torch_dtype=torch_dtype,
        device=device,
    )
    print("Model loaded successfully!")

def load_audio(file_path: str) -> tuple:
    """Load audio file efficiently"""
    try:
        # Load at the native sampling rate, then resample to 16 kHz only
        # when the source is not already 16 kHz
        audio_array, orig_sr = librosa.load(file_path, sr=None, mono=True)
        
        # Resample only if needed
        if orig_sr != 16000:
            audio_array = librosa.resample(audio_array, orig_sr=orig_sr, target_sr=16000)
            sampling_rate = 16000
        else:
            sampling_rate = orig_sr
            
        # Convert to float32 if needed
        if audio_array.dtype != np.float32:
            audio_array = audio_array.astype(np.float32)
            
        return audio_array, sampling_rate
        
    except Exception as e:
        print(f"Error loading audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error loading audio: {str(e)}")

def process_audio_file(file_path: str) -> Dict:
    """Process audio file and return transcription with word-level timestamps"""
    try:
        # Load audio file efficiently as 16 kHz mono
        audio_array, sampling_rate = load_audio(file_path)

        # Run the ASR pipeline; return_timestamps="word" (configured at
        # startup) yields one chunk per word with (start, end) in seconds
        output = asr_pipeline({"raw": audio_array, "sampling_rate": sampling_rate})

        # Convert pipeline chunks into the response format
        words_with_timestamps = []
        for chunk in output.get("chunks", []):
            start, end = chunk["timestamp"]
            # The final chunk can occasionally arrive without an end time
            end = end if end is not None else start
            words_with_timestamps.append({
                "text": chunk["text"].strip(),
                "timestamp": [round(start, 2), round(end, 2)],
            })

        # Create final response format
        response_data = {
            "text": output["text"].strip(),
            "chunks": words_with_timestamps,
        }

        # Manual garbage collection to free memory between requests
        del output
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

        return response_data

    except HTTPException:
        # load_audio already raises a well-formed HTTP error; pass it through
        raise
    except Exception as e:
        print(f"Error processing audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error processing audio: {str(e)}")

async def process_in_background(file_path: str):
    """Process audio file in a background thread to prevent blocking"""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(thread_pool, process_audio_file, file_path)
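
# Matching shutdown hook so the worker threads are released cleanly when the
# app stops (companion to the startup loader above)
@app.on_event("shutdown")
async def shutdown_event():
    thread_pool.shutdown(wait=False)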

@app.post("/transcribe", response_model=TranscriptionResponse)
async def transcribe_audio(file: UploadFile = File(...)):
    """
    Transcribe an audio file to text with timestamps for each word.
    
    Accepts .mp3, .wav, .m4a, .ogg or .flac files up to 30MB.
    """
    start_time = time.time()
    
    # Validate file extension
    file_ext = os.path.splitext(file.filename)[1].lower()
    if file_ext not in FILE_EXTENSIONS:
        raise HTTPException(
            status_code=400, 
            detail=f"Unsupported file format. Supported formats: {', '.join(FILE_EXTENSIONS)}"
        )
    
    # Read the upload and enforce the size limit before touching the
    # filesystem, so an oversized file never leaves a stray temp file behind
    content = await file.read()
    if len(content) > FILE_LIMIT_MB * 1024 * 1024:
        raise HTTPException(
            status_code=400,
            detail=f"File too large. Maximum size: {FILE_LIMIT_MB}MB"
        )

    # Write the upload to a temp file for processing
    with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as temp_file:
        temp_file.write(content)
        temp_file_path = temp_file.name
    
    try:
        # Process the audio file in background to prevent blocking
        result = await process_in_background(temp_file_path)
        processing_time = time.time() - start_time
        print(f"Processing completed in {processing_time:.2f} seconds")
        
        # Return the dict directly so FastAPI validates it against the
        # declared TranscriptionResponse model
        return result
    
    finally:
        # Clean up the temp file
        if os.path.exists(temp_file_path):
            try:
                os.unlink(temp_file_path)
            except Exception as e:
                print(f"Error deleting temp file: {e}")
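
# Example request (hypothetical file name; assumes the server is running on
# the default port 7860 set below):
#
#   curl -X POST "http://localhost:7860/transcribe" -F "file=@sample.wav"
#
# The response is JSON shaped like:
#
#   {"text": "...", "chunks": [{"timestamp": [0.12, 0.48], "text": "hello"}]}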

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy"}

# Simple root endpoint that shows API is running
@app.get("/")
async def root():
    return {
        "message": "Speech-to-Text API is running",
        "endpoints": {
            "transcribe": "/transcribe (POST)",
            "health": "/health (GET)",
            "docs": "/docs (GET)"
        },
        "model": MODEL_NAME,
        "device": device
    }

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port)