suprimedev committed on
Commit
802512e
·
verified ·
1 Parent(s): 4e79e7e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +404 -0
app.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import itertools
import json
import os
import shutil
import subprocess
import tarfile
import uuid
from pathlib import Path

import gradio as gr
import numpy as np
import requests
import sherpa_onnx
import soundfile as sf
13
+
14
# Available TTS voices. Each row is:
#   [model_id, download_url, display_name, source_page_url]
# - download_url is either a Hugging Face directory (/tree/main/ or
#   /resolve/main/) holding loose model.onnx + tokens.txt files, or a
#   GitHub release .tar.bz2 archive.
# - display_name (Persian label with an emoji) is what the UI shows; it is
#   also used as the on-disk cache directory name under ./models.
MODELS = [
    ['mms fa', 'https://huggingface.co/willwade/mms-tts-multilingual-models-onnx/resolve/main/fas', "🌠 راد", 'https://huggingface.co/facebook/mms-tts-fas'],
    ['coqui-vits-female1-karim23657', 'https://huggingface.co/karim23657/persian-tts-vits/tree/main/persian-tts-female1-vits-coqui', "🌺 نگار", 'https://huggingface.co/Kamtera/persian-tts-female1-vits'],
    ['coqui-vits-male1-karim23657', 'https://huggingface.co/karim23657/persian-tts-vits/tree/main/persian-tts-male1-vits-coqui', "🌟 آرش", 'https://huggingface.co/Kamtera/persian-tts-male1-vits'],
    ['coqui-vits-male-karim23657', 'https://huggingface.co/karim23657/persian-tts-vits/tree/main/male-male-coqui-vits', "🦁 کیان", 'https://huggingface.co/Kamtera/persian-tts-male-vits'],
    ['coqui-vits-female-karim23657', 'https://huggingface.co/karim23657/persian-tts-vits/tree/main/female-female-coqui-vits', "🌷 مهتاب", 'https://huggingface.co/Kamtera/persian-tts-female-vits'],
    ['coqui-vits-female-GPTInformal-karim23657', 'https://huggingface.co/karim23657/persian-tts-vits/tree/main/female-GPTInformal-coqui-vits', "🌼 شیوا", 'https://huggingface.co/karim23657/persian-tts-female-GPTInformal-Persian-vits'],
    ['coqui-vits-male-SmartGitiCorp', 'https://huggingface.co/karim23657/persian-tts-vits/tree/main/male-SmartGitiCorp-coqui-vits', "🚀 بهمن", 'https://huggingface.co/SmartGitiCorp/persian_tts_vits'],
    ['vits-piper-fa-ganji', 'https://huggingface.co/karim23657/persian-tts-vits/tree/main/vits-piper-fa-ganji', "🚀 برنا", 'https://huggingface.co/SadeghK/persian-text-to-speech'],
    ['vits-piper-fa-ganji-adabi', 'https://huggingface.co/karim23657/persian-tts-vits/tree/main/vits-piper-fa-ganji-adabi', "🚀 برنا-1", 'https://huggingface.co/SadeghK/persian-text-to-speech'],
    ['vits-piper-fa-gyro-medium', 'https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-fa_IR-gyro-medium.tar.bz2', "💧 نیما", 'https://huggingface.co/gyroing/Persian-Piper-Model-gyro'],
    ['piper-fa-amir-medium', 'https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-fa_IR-amir-medium.tar.bz2', "⚡️ آریا", 'https://huggingface.co/SadeghK/persian-text-to-speech'],
    ['vits-mimic3-fa-haaniye_low', 'https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-mimic3-fa-haaniye_low.tar.bz2', "🌹 ریما", 'https://github.com/MycroftAI/mimic3'],
    ['vits-piper-fa_en-rezahedayatfar-ibrahimwalk-medium', 'https://github.com/k2-fsa/sherpa-onnx/releases/download/tts-models/vits-piper-fa_en-rezahedayatfar-ibrahimwalk-medium.tar.bz2', "🌠 پیام", 'https://huggingface.co/mah92/persian-english-piper-tts-model'],
]
30
+
31
def _save_stream(chunks, total_size, out_path):
    """Write an iterable of byte *chunks* to *out_path*, printing coarse
    download progress (roughly every 10%) when *total_size* (bytes) is known."""
    downloaded = 0
    print(f"Total size: {total_size / (1024*1024):.1f} MB")
    with open(out_path, "wb") as f:
        for chunk in chunks:
            if chunk:
                f.write(chunk)
                downloaded += len(chunk)
                if total_size > 0:
                    percent = int((downloaded / total_size) * 100)
                    if percent % 10 == 0:
                        print(f" {percent}%", end="", flush=True)


def download_and_extract_model(url, destination):
    """Download the model files for one TTS voice into *destination*.

    Two layouts are supported:
    - Hugging Face URLs (``/tree/main/`` or ``/resolve/main/`` directories):
      the loose ``model.onnx`` and ``tokens.txt`` files are fetched directly.
      (The previous tar.bz2 fallback for HF URLs was unreachable because the
      layout flag was hard-coded; every HF voice in MODELS uses loose files.)
    - Anything else (GitHub release assets): treated as a ``.tar.bz2``
      archive, downloaded, extracted into *destination*, then removed.

    Raises:
        Exception: on any non-200 HTTP response, or when the server returns
            a Git LFS pointer instead of real file content.
    """
    print(f"Downloading from URL: {url}")
    print(f"Destination: {destination}")

    # Ensure the target directory exists before any file is written.
    os.makedirs(destination, exist_ok=True)

    if "huggingface.co" in url:
        # Directory listings use /tree/main/; raw files use /resolve/main/.
        base_url = url.replace("/tree/main/", "/resolve/main/")

        model_url = f"{base_url}/model.onnx"
        tokens_url = f"{base_url}/tokens.txt"

        # Download model.onnx with streamed progress reporting.
        print("Downloading model.onnx...")
        model_path = os.path.join(destination, "model.onnx")
        response = requests.get(model_url, stream=True)
        if response.status_code != 200:
            raise Exception(f"Failed to download model from {model_url}. Status code: {response.status_code}")
        _save_stream(response.iter_content(chunk_size=8192),
                     int(response.headers.get('content-length', 0)),
                     model_path)
        print("\nModel download complete")

        # tokens.txt is small; fetch it in one go.
        print("Downloading tokens.txt...")
        tokens_path = os.path.join(destination, "tokens.txt")
        response = requests.get(tokens_url, stream=True)
        if response.status_code != 200:
            raise Exception(f"Failed to download tokens from {tokens_url}. Status code: {response.status_code}")
        with open(tokens_path, "wb") as f:
            f.write(response.content)
        print("Tokens download complete")

        return

    # Non-Hugging Face URLs point at tar.bz2 archives. Use a single request
    # (the original issued two identical GETs and touched .content on a
    # streamed response, buffering the whole archive in memory just to
    # inspect its first bytes).
    response = requests.get(url, stream=True)
    if response.status_code != 200:
        raise Exception(f"Failed to download model from {url}. Status code: {response.status_code}")

    # Peek at the first chunk only to detect a Git LFS pointer file.
    body = response.iter_content(chunk_size=8192)
    first_chunk = next(body, b'')
    head = first_chunk[:100].decode('utf-8', errors='ignore')
    if head.startswith('version https://git-lfs.github.com/spec/v1'):
        raise Exception(f"Received Git LFS pointer instead of file content from {url}")

    tar_path = os.path.join(destination, "model.tar.bz2")
    print("Downloading model archive...")
    # Re-chain the peeked chunk ahead of the remaining stream.
    _save_stream(itertools.chain([first_chunk], body),
                 int(response.headers.get('content-length', 0)),
                 tar_path)
    print("\nDownload complete")

    # Extract the archive, then delete it to save disk space.
    print(f"Extracting {tar_path} to {destination}")
    try:
        with tarfile.open(tar_path, "r:bz2") as tar:
            tar.extractall(path=destination)
        os.remove(tar_path)
        print("Extraction complete")
    except Exception as e:
        print(f"Error during extraction: {str(e)}")
        raise

    # Log the resulting layout to ease debugging of model lookup.
    print("Contents of destination directory:")
    for root, dirs, files in os.walk(destination):
        print(f"\nDirectory: {root}")
        if dirs:
            print(" Subdirectories:", dirs)
        if files:
            print(" Files:", files)
141
+
142
def find_model_files(model_dir):
    """Locate TTS model files under *model_dir*, searched recursively.

    Returns:
        dict: ``{'model': <path to .onnx>, 'tokens': <path to tokens.txt>}``
        ('tokens' only when present). Returns ``{}`` when no ``.onnx`` file
        is found. If several candidates exist, the last one encountered by
        ``os.walk`` wins (matching the original behavior).

    Note: the original version carried lexicon.txt handling guarded by a
    hard-coded ``is_mms = True`` flag, so it never executed; that dead code
    has been removed without changing behavior.
    """
    model_files = {}

    for root, _, files in os.walk(model_dir):
        for file in files:
            file_path = os.path.join(root, file)
            if file.endswith('.onnx'):
                model_files['model'] = file_path
            elif file == 'tokens.txt':
                model_files['tokens'] = file_path

    # A usable result requires at least the ONNX model itself.
    return model_files if 'model' in model_files else {}
174
+
175
def generate_audio(text, model_info):
    """Synthesize speech for *text* with the voice named *model_info*.

    Args:
        text: Persian text to synthesize.
        model_info: Display name of the voice (third column of MODELS);
            also used as the cache directory name under ./models.

    Returns:
        tuple: (float32 numpy array of peak-normalized samples, sample rate).

    Raises:
        ValueError: unknown voice, missing model files, or empty synthesis
            output. Other exceptions from the TTS engine propagate as-is.
    """
    try:
        # Models are cached under ./models/<display name>.
        model_dir = os.path.join("./models", model_info)

        print(f"\nLooking for model in: {model_dir}")

        # First use of a voice: download its files.
        if not os.path.exists(model_dir):
            print(f"Model directory doesn't exist, downloading {model_info}...")
            os.makedirs(model_dir, exist_ok=True)
            model_url = None
            # MODELS rows are [id, download_url, display_name, source_url].
            for model in MODELS:
                if model_info == model[2]:
                    model_url = model[1]
                    break
            if not model_url:
                raise ValueError(f"Model {model_info} not found in the model list")

            download_and_extract_model(model_url, model_dir)

        # Debug log of the downloaded layout.
        print(f"Contents of {model_dir}:")
        for item in os.listdir(model_dir):
            item_path = os.path.join(model_dir, item)
            if os.path.isdir(item_path):
                print(f" Directory: {item}")
                print(f" Contents: {os.listdir(item_path)}")
            else:
                print(f" File: {item}")

        # Find and validate model files.
        model_files = find_model_files(model_dir)
        if not model_files or 'model' not in model_files:
            raise ValueError(f"Could not find required model files in {model_dir}")

        print("\nFound model files:")
        print(f"Model: {model_files['model']}")
        print(f"Tokens: {model_files.get('tokens', 'Not found')}")
        print(f"Lexicon: {model_files.get('lexicon', 'Not required for MMS')}\n")

        # NOTE(review): model_dir ends with the Persian display name, which
        # never contains the substring "mms" (the MMS voice is labeled
        # "🌠 راد"), so this flag appears to always be False in practice and
        # every voice takes the non-MMS branch — confirm this is intended.
        is_mms = 'mms' in os.path.basename(model_dir).lower()

        # Build the VITS config; the sherpa_onnx constructor takes the
        # arguments positionally in the order annotated below.
        if is_mms:
            if 'tokens' not in model_files or not os.path.exists(model_files['tokens']):
                raise ValueError("tokens.txt is required for MMS models")

            # MMS models use tokens.txt and no lexicon.
            vits_config = sherpa_onnx.OfflineTtsVitsModelConfig(
                model_files['model'],   # model
                '',                     # lexicon
                model_files['tokens'],  # tokens
                '',                     # data_dir
                '',                     # dict_dir
                0.667,                  # noise_scale
                0.8,                    # noise_scale_w
                1.0                     # length_scale
            )
        else:
            # Non-MMS models may use lexicon.txt and espeak-ng data.
            if 'tokens' not in model_files or not os.path.exists(model_files['tokens']):
                raise ValueError("tokens.txt is required for VITS models")

            # Prefer an espeak-ng-data directory shipped next to the model;
            # otherwise fall back to a relative default path.
            espeak_data = os.path.join(os.path.dirname(model_files['model']), 'espeak-ng-data')
            data_dir = espeak_data if os.path.exists(espeak_data) else 'espeak-ng-data'

            # Use the lexicon only when the file actually exists on disk.
            lexicon = model_files.get('lexicon', '') if os.path.exists(model_files.get('lexicon', '')) else ''

            vits_config = sherpa_onnx.OfflineTtsVitsModelConfig(
                model_files['model'],   # model
                lexicon,                # lexicon
                model_files['tokens'],  # tokens
                data_dir,               # data_dir
                '',                     # dict_dir
                0.667,                  # noise_scale
                0.8,                    # noise_scale_w
                1.0                     # length_scale
            )

        # Wrap the VITS config in the generic offline-TTS model config.
        model_config = sherpa_onnx.OfflineTtsModelConfig()
        model_config.vits = vits_config

        # Top-level TTS configuration.
        config = sherpa_onnx.OfflineTtsConfig(
            model=model_config,
            max_num_sentences=2
        )

        # NOTE(review): the engine (and the ONNX model) is rebuilt on every
        # call — consider caching one engine per voice.
        tts = sherpa_onnx.OfflineTts(config)

        # Run synthesis.
        audio_data = tts.generate(text)

        # Ensure we got actual samples back.
        if audio_data is None or len(audio_data.samples) == 0:
            raise ValueError("Failed to generate audio - no data generated")

        # Convert to a numpy array and peak-normalize to [-1, 1].
        audio_array = np.array(audio_data.samples, dtype=np.float32)
        if np.any(audio_array):  # reject an all-zero (silent) result
            audio_array = audio_array / np.abs(audio_array).max()
        else:
            raise ValueError("Generated audio is empty")

        return (audio_array, audio_data.sample_rate)

    except Exception as e:
        error_msg = str(e)
        # Make out-of-vocabulary / token-conversion errors more readable.
        if "out of vocabulary" in error_msg.lower() or "token" in error_msg.lower():
            error_msg = f"Text contains unsupported characters: {error_msg}"
        print(f"Error generating audio: {error_msg}")
        # NOTE(review): the bare `raise` re-raises the ORIGINAL exception;
        # the rewritten error_msg above only affects the log line, not what
        # callers see.
        raise
295
+
296
def tts_interface(selected_model, text):
    """Gradio callback: synthesize Persian speech for *text* with the chosen voice.

    Returns a pair ``(audio_file_path, status_message)``; on failure the
    first element is None and the status carries a Persian error message.
    """
    try:
        # Guard clause: nothing to synthesize.
        if not text.strip():
            return None, "لطفا متنی برای تبدیل به گفتار وارد کنید"

        # Keep the untouched input around for reference.
        original_text = text

        try:
            # The dropdown value is the voice's display name.
            voice_name = selected_model

            samples, rate = generate_audio(text, voice_name)

            # Unique output name so concurrent requests never collide.
            out_path = f"tts_output_{uuid.uuid4()}.wav"
            sf.write(out_path, samples, samplerate=rate, subtype="PCM_16")

            # Resolve the voice's source page ("" when not found).
            model_url = next(
                (entry[3] for entry in MODELS if entry[2] == selected_model),
                ""
            )

            status = f"مدل: {selected_model}\nمنبع مدل: {model_url}\nمتن: '{text}'"
            return out_path, status

        except ValueError as err:
            # Known failure modes get a friendlier message.
            message = str(err)
            if "cannot process some words" in message.lower():
                return None, message
            return None, f"خطا: {message}"

    except Exception as err:
        # Last-resort handler: log and report the raw error text.
        print(f"Error in TTS generation: {str(err)}")
        return None, f"خطا: {str(err)}"
338
+
339
def create_gradio_interface():
    """Build and return the Gradio Blocks UI for Persian text-to-speech.

    Layout: a text area + voice dropdown + button on the left, audio and
    status outputs on the right, plus clickable examples underneath.
    """
    # The dropdown shows each voice's display name (MODELS column 3).
    voice_options = [model[2] for model in MODELS]

    with gr.Blocks(title="تبدیل متن به گفتار فارسی", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # تبدیل متن به گفتار فارسی
        با استفاده از مدل‌های مختلف متن را به گفتار تبدیل کنید
        """)

        with gr.Row():
            with gr.Column():
                text_input = gr.TextArea(
                    label="متن فارسی",
                    placeholder="متن خود را اینجا وارد کنید...",
                    lines=5
                )

                voice_dropdown = gr.Dropdown(
                    label="صدا",
                    choices=voice_options,
                    value=voice_options[0]
                )

                generate_button = gr.Button("تبدیل به گفتار")

            with gr.Column():
                audio_output = gr.Audio(
                    label="خروجی صوتی",
                    interactive=False
                )

                status_output = gr.Textbox(
                    label="وضعیت",
                    interactive=False
                )

        generate_button.click(
            fn=tts_interface,
            inputs=[voice_dropdown, text_input],
            outputs=[audio_output, status_output]
        )

        # BUG FIX: tts_interface takes (selected_model, text), but the
        # examples previously supplied inputs=[text_input, voice_dropdown]
        # with rows ordered [text, voice] — running an example swapped the
        # arguments relative to the click handler. Rows and inputs now use
        # the same (voice, text) order as the handler above.
        gr.Examples(
            examples=[
                [voice_options[0], "سلام. این یک نمونه متن برای نمایش سیستم تبدیل متن به گفتار فارسی است."],
                [voice_options[1], "تبدیل متن به گفتار یکی از کاربردهای مهم پردازش زبان طبیعی است."],
                [voice_options[5], "این پروژه از مدل‌های متنوعی برای تولید صدای طبیعی استفاده می‌کند."]
            ],
            inputs=[voice_dropdown, text_input],
            outputs=[audio_output, status_output],
            fn=tts_interface,
            cache_examples=False
        )

    return demo
397
+
398
if __name__ == "__main__":
    # Ensure the model cache directory exists before the UI starts.
    os.makedirs("models", exist_ok=True)

    # Build and launch the Gradio interface; 0.0.0.0:7860 is the
    # conventional binding for Hugging Face Spaces / containerized deploys.
    demo = create_gradio_interface()
    demo.launch(server_name="0.0.0.0", server_port=7860)