karim23657 committed on
Commit 3e0bc8b · verified · 1 Parent(s): d434d9d

Create app.py

Files changed (1)
app.py +412 -0
app.py ADDED
@@ -0,0 +1,412 @@
import os
import tarfile

import gradio as gr
import numpy as np
import requests
import sherpa_onnx

# Each entry: [display name, download URL]
models = [
    ['mms fa', 'https://huggingface.co/willwade/mms-tts-multilingual-models-onnx/resolve/main/fas'],
    ['coqui-vits-female1-karim23657', 'https://huggingface.co/karim23657/persian-tts-vits/tree/main/persian-tts-female1-vits-coqui'],
    ['coqui-vits-male1-karim23657', 'https://huggingface.co/karim23657/persian-tts-vits/tree/main/persian-tts-male1-vits-coqui'],
    ['vits-piper-fa-gyro-medium', 'https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-fa_IR-gyro-medium.tar.bz2'],
    ['piper-fa-amir-medium', 'https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-fa_IR-amir-medium.tar.bz2'],
    ['vits-mimic3-fa-haaniye_low', 'https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-mimic3-fa-haaniye_low.tar.bz2'],
]

# Lookup tables used by the Gradio interface below
models_by_display = {name: {'id': name, 'name': name, 'url': url} for name, url in models}
dropdown_choices = list(models_by_display.keys())

def download_file(url, path):
    """Stream a file from `url` to `path` with simple progress output."""
    response = requests.get(url, stream=True)
    if response.status_code != 200:
        raise Exception(f"Failed to download {url}. Status code: {response.status_code}")

    total_size = int(response.headers.get('content-length', 0))
    downloaded = 0
    print(f"Total size: {total_size / (1024*1024):.1f} MB")
    with open(path, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:
                f.write(chunk)
                downloaded += len(chunk)
                if total_size > 0:
                    percent = int((downloaded / total_size) * 100)
                    if percent % 10 == 0:
                        print(f" {percent}%", end="", flush=True)
    print("\nDownload complete")


def download_and_extract_model(url, destination):
    """Download and extract the model files."""
    print(f"Downloading from URL: {url}")
    print(f"Destination: {destination}")
    os.makedirs(destination, exist_ok=True)

    if "huggingface.co" in url:
        # Replace /tree/main/ with /resolve/main/ for direct file downloads
        base_url = url.replace("/tree/main/", "/resolve/main/")

        # The Hugging Face-hosted models used by this app follow the MMS
        # layout: model.onnx and tokens.txt stored as plain files in the repo
        print("Downloading model.onnx...")
        download_file(f"{base_url}/model.onnx", os.path.join(destination, "model.onnx"))
        print("Downloading tokens.txt...")
        download_file(f"{base_url}/tokens.txt", os.path.join(destination, "tokens.txt"))
        return

    # Other models are stored as tar.bz2 archives
    tar_path = os.path.join(destination, "model.tar.bz2")
    print("Downloading model archive...")
    download_file(url, tar_path)

    # Reject Git LFS pointers, which are served when a URL points at the
    # pointer file instead of the real payload
    with open(tar_path, "rb") as f:
        if f.read(100).startswith(b"version https://git-lfs.github.com/spec/v1"):
            raise Exception(f"Received Git LFS pointer instead of file content from {url}")

    # Extract the tar.bz2 file
    print(f"Extracting {tar_path} to {destination}")
    try:
        with tarfile.open(tar_path, "r:bz2") as tar:
            tar.extractall(path=destination)
        os.remove(tar_path)
        print("Extraction complete")
    except Exception as e:
        print(f"Error during extraction: {str(e)}")
        raise

    print("Contents of destination directory:")
    for root, dirs, files in os.walk(destination):
        print(f"\nDirectory: {root}")
        if dirs:
            print("  Subdirectories:", dirs)
        if files:
            print("  Files:", files)

def dl_espeak_data():
    """Download and extract espeak-ng-data into this script's directory."""
    tar_path = 'espeak-ng-data.tar.bz2'
    print("Downloading espeak-ng-data archive...")
    download_file(
        'https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/espeak-ng-data.tar.bz2',
        tar_path)

    # Extract next to this script (a directory, not the script file itself)
    destination = os.path.dirname(os.path.abspath(__file__))
    print(f"Extracting {tar_path} to {destination}")
    try:
        with tarfile.open(tar_path, "r:bz2") as tar:
            tar.extractall(path=destination)
        os.remove(tar_path)
        print("Extraction complete")
    except Exception as e:
        print(f"Error during extraction: {str(e)}")
        raise

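# Note: dl_espeak_data() is not called anywhere in this file. The piper and
# mimic3 archives listed above typically bundle their own espeak-ng-data
# directory, so this helper is only needed for models that lack one.
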
def find_model_files(model_dir):
    """Find model files in the given directory and its subdirectories."""
    model_files = {}

    # MMS models ship only model.onnx + tokens.txt and need no lexicon
    is_mms = 'mms' in os.path.basename(model_dir).lower()

    for root, _, files in os.walk(model_dir):
        for file in files:
            file_path = os.path.join(root, file)

            # Model file
            if file.endswith('.onnx'):
                model_files['model'] = file_path

            # Tokens file
            elif file == 'tokens.txt':
                model_files['tokens'] = file_path

            # Lexicon file (only for non-MMS models)
            elif file == 'lexicon.txt' and not is_mms:
                model_files['lexicon'] = file_path

    # Create an empty lexicon file if needed (only for non-MMS models)
    if not is_mms and 'model' in model_files and 'lexicon' not in model_files:
        lexicon_path = os.path.join(os.path.dirname(model_files['model']), 'lexicon.txt')
        with open(lexicon_path, 'w', encoding='utf-8'):
            pass  # create an empty file
        model_files['lexicon'] = lexicon_path

    return model_files if 'model' in model_files else {}

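# A sketch of the expected result (hypothetical paths), e.g. after the
# 'mms fa' model has been downloaded:
#   find_model_files("./models/mms fa")
#   # -> {'model': './models/mms fa/model.onnx',
#   #     'tokens': './models/mms fa/tokens.txt'}
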
def generate_audio(text, model_info):
    """Generate audio from text using the specified model."""
    try:
        model_dir = os.path.join("./models", model_info['id'])

        print(f"\nLooking for model in: {model_dir}")

        # Download the model if it doesn't exist yet
        if not os.path.exists(model_dir):
            print(f"Model directory doesn't exist, downloading {model_info['id']}...")
            os.makedirs(model_dir, exist_ok=True)
            download_and_extract_model(model_info['url'], model_dir)

        print(f"Contents of {model_dir}:")
        for item in os.listdir(model_dir):
            item_path = os.path.join(model_dir, item)
            if os.path.isdir(item_path):
                print(f"  Directory: {item}")
                print(f"    Contents: {os.listdir(item_path)}")
            else:
                print(f"  File: {item}")

        # Find and validate model files
        model_files = find_model_files(model_dir)
        if not model_files or 'model' not in model_files:
            raise ValueError(f"Could not find required model files in {model_dir}")

        print("\nFound model files:")
        print(f"Model: {model_files['model']}")
        print(f"Tokens: {model_files.get('tokens', 'Not found')}")
        print(f"Lexicon: {model_files.get('lexicon', 'Not required for MMS')}\n")

        if 'tokens' not in model_files or not os.path.exists(model_files['tokens']):
            raise ValueError("tokens.txt is required")

        # MMS models use tokens.txt only; other VITS models may also need a
        # lexicon and espeak-ng data
        is_mms = 'mms' in os.path.basename(model_dir).lower()
        if is_mms:
            lexicon = ''
            data_dir = ''
        else:
            # Use espeak-ng-data shipped next to the model, if present
            espeak_data = os.path.join(os.path.dirname(model_files['model']), 'espeak-ng-data')
            data_dir = espeak_data if os.path.exists(espeak_data) else ''
            lexicon = model_files.get('lexicon', '') if os.path.exists(model_files.get('lexicon', '')) else ''

        # Positional args: model, lexicon, tokens, data_dir, dict_dir,
        # noise_scale, noise_scale_w, length_scale
        vits_config = sherpa_onnx.OfflineTtsVitsModelConfig(
            model_files['model'],
            lexicon,
            model_files['tokens'],
            data_dir,
            '',
            0.667,
            0.8,
            1.0
        )

        # Create the model config with VITS
        model_config = sherpa_onnx.OfflineTtsModelConfig()
        model_config.vits = vits_config

        # Create the TTS configuration
        config = sherpa_onnx.OfflineTtsConfig(
            model=model_config,
            max_num_sentences=2
        )

        # Initialize the TTS engine and synthesize
        tts = sherpa_onnx.OfflineTts(config)
        audio_data = tts.generate(text)

        # Ensure we have valid audio data
        if audio_data is None or len(audio_data.samples) == 0:
            raise ValueError("Failed to generate audio - no data generated")

        # Convert the samples to a numpy array and peak-normalize
        audio_array = np.array(audio_data.samples, dtype=np.float32)
        if np.any(audio_array):  # not all zeros
            audio_array = audio_array / np.abs(audio_array).max()
        else:
            raise ValueError("Generated audio is empty")

        # Return (samples, sample_rate); the caller reorders these for Gradio
        return (audio_array, audio_data.sample_rate)

    except Exception as e:
        error_msg = str(e)
        # Flag out-of-vocabulary / token conversion errors clearly
        if "out of vocabulary" in error_msg.lower() or "token" in error_msg.lower():
            error_msg = f"Text contains unsupported characters: {error_msg}"
        print(f"Error in TTS generation: {error_msg}")
        raise

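# Minimal direct-usage sketch (hypothetical, bypassing the Gradio UI).
# generate_audio returns float32 samples plus a sample rate, which could be
# written to disk with the soundfile package if it is installed:
#
#   samples, sr = generate_audio("سلام دنیا", models_by_display['piper-fa-amir-medium'])
#   import soundfile as sf
#   sf.write("out.wav", samples, sr)
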
def tts_interface(selected_model, text, status_output):
    # status_output is wired as an input by Gradio below but is unused here
    try:
        if not text.strip():
            return None, "Please enter some text"

        # Look up the model info from the display-name mapping
        model_info = models_by_display.get(selected_model)
        if not model_info:
            return None, "Please select a model"

        try:
            # Pull language info for the status line, if the model provides it
            lang_info = model_info.get('language', [{}])[0]
            lang_name = lang_info.get('language_name', 'Unknown')
            voice_name = model_info.get('name', model_info['id'])

            # Generate audio
            audio_data, sample_rate = generate_audio(text, model_info)

            final_status = f"Generated speech using {voice_name} ({lang_name})"
            final_status += f"\nText: '{text}'"

            return (sample_rate, audio_data), final_status
        except ValueError as e:
            # Known errors already carry user-friendly messages
            error_msg = str(e)
            if "cannot process some words" in error_msg.lower():
                return None, error_msg
            return None, f"Error: {error_msg}"

    except Exception as e:
        print(f"Error in TTS generation: {str(e)}")
        return None, f"Error: {str(e)}"

# Gradio interface
with gr.Blocks() as app:
    gr.Markdown("# Sherpa-ONNX متن به گفتار")  # "Sherpa-ONNX text to speech"
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Radio(
                choices=dropdown_choices,
                label="مدل",  # "Model"
                value=dropdown_choices[0] if dropdown_choices else None
            )
            text_input = gr.Textbox(
                label="متن",  # "Text"
                placeholder="متن را وارد کنید ...",  # "Enter the text ..."
                lines=3
            )

            with gr.Row():
                generate_btn = gr.Button("بگو")  # "Speak"
                stop_btn = gr.Button("توقف")  # "Stop"

        with gr.Column():
            audio_output = gr.Audio(
                label="گفتار",  # "Speech"
                type="numpy"
            )
            status_text = gr.Textbox(
                label="وضعیت",  # "Status"
                interactive=False
            )

    # Event handlers
    gen_event = generate_btn.click(
        fn=tts_interface,
        inputs=[model_dropdown, text_input, status_text],
        outputs=[audio_output, status_text]
    )

    # The stop button cancels an in-flight generation event
    stop_btn.click(
        fn=None,
        cancels=gen_event,
        queue=False
    )

app.launch()
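
# Running the script starts the Gradio server (on http://127.0.0.1:7860 by
# default): python app.py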