NitinBot001 committed
Commit
6c3cd1d · verified · 1 Parent(s): 489f2f3

Update main.py

Files changed (1)
  1. main.py +183 -23
main.py CHANGED
@@ -1,32 +1,192 @@
- from fastapi import FastAPI, UploadFile, File
- from fastapi.responses import JSONResponse
- import whisperx
- import torch
  import tempfile
- import shutil
  import os

- app = FastAPI()

- # Load model globally to avoid reloading for every request
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model = whisperx.load_model("medium", device)

- @app.post("/transcribe")
- async def transcribe_audio(file: UploadFile = File(...)):
-     try:
-         # Save uploaded audio to temp file
-         with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
-             shutil.copyfileobj(file.file, tmp)
-             temp_audio_path = tmp.name

-         # Load and process audio
-         audio = whisperx.load_audio(temp_audio_path)
-         result = model.transcribe(audio, batch_size=16, return_word_timestamps=True)

-         # Clean up temp file
-         os.remove(temp_audio_path)

-         return JSONResponse(content=result)
      except Exception as e:
-         return JSONResponse(status_code=500, content={"error": str(e)})
+ from flask import Flask, request, jsonify
+ import whisper
  import tempfile
  import os
+ from werkzeug.utils import secure_filename
+ import logging
+ from datetime import datetime
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ app = Flask(__name__)
+
+ # Configuration
+ app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024 # 100MB max file size
+ ALLOWED_EXTENSIONS = {'wav', 'mp3', 'mp4', 'mpeg', 'mpga', 'm4a', 'webm', 'flac'}
+
+ # Load Whisper model (you can change the model size: tiny, base, small, medium, large)
+ MODEL_SIZE = "base" # Change this to your preferred model size
+ logger.info(f"Loading Whisper model: {MODEL_SIZE}")
+ model = whisper.load_model(MODEL_SIZE)
+ logger.info("Whisper model loaded successfully")
+
+ def allowed_file(filename):
+     """Check if the file extension is allowed"""
+     return '.' in filename and \
+         filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+
+ def format_timestamp(seconds):
+     """Convert seconds to HH:MM:SS.mmm format"""
+     hours = int(seconds // 3600)
+     minutes = int((seconds % 3600) // 60)
+     secs = seconds % 60
+     return f"{hours:02d}:{minutes:02d}:{secs:06.3f}"
+
+ @app.route('/', methods=['GET'])
+ def health_check():
+     """Health check endpoint"""
+     return jsonify({
+         "status": "healthy",
+         "message": "Whisper Transcription API is running",
+         "model": MODEL_SIZE,
+         "timestamp": datetime.now().isoformat()
+     })
+
+ @app.route('/transcribe', methods=['POST'])
+ def transcribe_audio():
+     """
+     Transcribe audio file and return word-level timestamps
+
+     Expected form data:
+     - audio_file: The audio file to transcribe
+     - language (optional): Language code (e.g., 'en', 'es', 'fr')
+     - task (optional): 'transcribe' or 'translate' (default: transcribe)
+     """
+     try:
+         # Check if audio file is present
+         if 'audio_file' not in request.files:
+             return jsonify({'error': 'No audio file provided'}), 400
+
+         file = request.files['audio_file']
+
+         if file.filename == '':
+             return jsonify({'error': 'No file selected'}), 400
+
+         if not allowed_file(file.filename):
+             return jsonify({
+                 'error': f'File type not allowed. Supported formats: {", ".join(ALLOWED_EXTENSIONS)}'
+             }), 400
+
+         # Get optional parameters
+         language = request.form.get('language', None)
+         task = request.form.get('task', 'transcribe')
+
+         if task not in ['transcribe', 'translate']:
+             return jsonify({'error': 'Task must be either "transcribe" or "translate"'}), 400
+
+         # Save uploaded file temporarily
+         with tempfile.NamedTemporaryFile(delete=False, suffix=f".{file.filename.rsplit('.', 1)[1].lower()}") as tmp_file:
+             file.save(tmp_file.name)
+             temp_path = tmp_file.name
+
+         logger.info(f"Processing file: {file.filename}")
+
+         try:
+             # Transcribe with word-level timestamps
+             result = model.transcribe(
+                 temp_path,
+                 language=language,
+                 task=task,
+                 word_timestamps=True,
+                 verbose=False
+             )
+
+             # Extract word-level data
+             word_segments = []
+             for segment in result.get("segments", []):
+                 if "words" in segment:
+                     for word_data in segment["words"]:
+                         word_segments.append({
+                             "word": word_data.get("word", "").strip(),
+                             "start": word_data.get("start", 0),
+                             "end": word_data.get("end", 0),
+                             "start_formatted": format_timestamp(word_data.get("start", 0)),
+                             "end_formatted": format_timestamp(word_data.get("end", 0)),
+                             "confidence": word_data.get("probability", 0)
+                         })
+
+             # Prepare response
+             response_data = {
+                 "success": True,
+                 "filename": secure_filename(file.filename),
+                 "language": result.get("language", "unknown"),
+                 "task": task,
+                 "duration": result.get("segments", [{}])[-1].get("end", 0) if result.get("segments") else 0,
+                 "text": result.get("text", ""),
+                 "word_count": len(word_segments),
+                 "segments": result.get("segments", []),
+                 "words": word_segments,
+                 "model_used": MODEL_SIZE,
+                 "processing_time": None # You can add timing if needed
+             }
+
+             logger.info(f"Successfully transcribed {len(word_segments)} words from {file.filename}")
+             return jsonify(response_data)
+
+         except Exception as e:
+             logger.error(f"Transcription error: {str(e)}")
+             return jsonify({'error': f'Transcription failed: {str(e)}'}), 500
+
+         finally:
+             # Clean up temporary file
+             if os.path.exists(temp_path):
+                 os.unlink(temp_path)
+
      except Exception as e:
+         logger.error(f"API error: {str(e)}")
+         return jsonify({'error': f'Server error: {str(e)}'}), 500
+
+ @app.route('/models', methods=['GET'])
+ def available_models():
+     """Get information about available Whisper models"""
+     models_info = {
+         "current_model": MODEL_SIZE,
+         "available_models": {
+             "tiny": {"size": "~39 MB", "speed": "~32x", "accuracy": "lowest"},
+             "base": {"size": "~74 MB", "speed": "~16x", "accuracy": "low"},
+             "small": {"size": "~244 MB", "speed": "~6x", "accuracy": "medium"},
+             "medium": {"size": "~769 MB", "speed": "~2x", "accuracy": "high"},
+             "large": {"size": "~1550 MB", "speed": "~1x", "accuracy": "highest"}
+         },
+         "supported_languages": [
+             "en", "zh", "de", "es", "ru", "ko", "fr", "ja", "pt", "tr", "pl", "ca", "nl",
+             "ar", "sv", "it", "id", "hi", "fi", "vi", "he", "uk", "el", "ms", "cs", "ro",
+             "da", "hu", "ta", "no", "th", "ur", "hr", "bg", "lt", "la", "mi", "ml", "cy",
+             "sk", "te", "fa", "lv", "bn", "sr", "az", "sl", "kn", "et", "mk", "br", "eu",
+             "is", "hy", "ne", "mn", "bs", "kk", "sq", "sw", "gl", "mr", "pa", "si", "km",
+             "sn", "yo", "so", "af", "oc", "ka", "be", "tg", "sd", "gu", "am", "yi", "lo",
+             "uz", "fo", "ht", "ps", "tk", "nn", "mt", "sa", "lb", "my", "bo", "tl", "mg",
+             "as", "tt", "haw", "ln", "ha", "ba", "jw", "su"
+         ]
+     }
+     return jsonify(models_info)
+
+ @app.errorhandler(413)
+ def too_large(e):
+     return jsonify({'error': 'File too large. Maximum size is 100MB'}), 413
+
+ @app.errorhandler(404)
+ def not_found(e):
+     return jsonify({'error': 'Endpoint not found'}), 404
+
+ @app.errorhandler(500)
+ def internal_error(e):
+     return jsonify({'error': 'Internal server error'}), 500
+
+ if __name__ == '__main__':
+     print(f"""
+     Whisper Transcription API Server
+     ================================
+     Model: {MODEL_SIZE}
+     Endpoints:
+     - GET  /           : Health check
+     - POST /transcribe : Transcribe audio file
+     - GET  /models     : Available models info
+
+     Supported formats: {', '.join(ALLOWED_EXTENSIONS)}
+     Max file size: 100MB
+     """)
+
+     app.run(debug=True, host='0.0.0.0', port=5000)
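
For reference, a minimal client sketch for the new /transcribe endpoint, assuming the server is running locally on port 5000 as configured above; the file name sample.wav and the use of the requests library are illustrative assumptions, not part of the commit.

# Hypothetical client for the /transcribe endpoint defined in main.py above.
# Assumes the Flask server is reachable at localhost:5000 and that
# "sample.wav" is a placeholder for a real audio file on disk.
import requests

with open("sample.wav", "rb") as f:
    resp = requests.post(
        "http://localhost:5000/transcribe",
        files={"audio_file": f},                        # form field expected by the endpoint
        data={"language": "en", "task": "transcribe"},  # optional form fields
    )

resp.raise_for_status()
payload = resp.json()
print(payload["text"])
for w in payload["words"]:
    print(w["start_formatted"], w["end_formatted"], w["word"])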