ramimu committed
Commit 74dbc75 · verified
Parent(s): 91d6893

Update app.py

Files changed (1)
  1. app.py +154 -55
app.py CHANGED
@@ -22,27 +22,50 @@ except ImportError as e:
     print(f"Failed to import ChatterboxTTS: {e}")
     chatterbox_available = False
 
+# Global model variable - will be loaded inside GPU function
 model = None
+model_loaded = False
 
-def cleanup_gpu_memory():
-    """Clean up GPU memory to prevent CUDA errors."""
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.synchronize()
-        gc.collect()
+def download_model_files():
+    """Download model files with error handling."""
+    print(f"Checking for model files in {LOCAL_MODEL_PATH}...")
+    os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)
+
+    for filename in MODEL_FILES:
+        local_path = os.path.join(LOCAL_MODEL_PATH, filename)
+        if not os.path.exists(local_path):
+            print(f"Downloading {filename} from {MODEL_REPO_ID}...")
+            try:
+                downloaded_path = hf_hub_download(
+                    repo_id=MODEL_REPO_ID,
+                    filename=filename,
+                    cache_dir="./cache",
+                    force_download=False
+                )
+                shutil.copy2(downloaded_path, local_path)
+                print(f"✓ Downloaded and copied {filename}")
+            except Exception as e:
+                print(f"✗ Failed to download {filename}: {e}")
+                raise e
+        else:
+            print(f"✓ {filename} already exists locally")
+    print("All model files are ready!")
 
-def safe_load_model():
-    """Safely load the model with proper error handling."""
-    global model
+def load_model_on_gpu():
+    """Load model inside GPU context - only called within @spaces.GPU decorated function."""
+    global model, model_loaded
+
+    if model_loaded and model is not None:
+        return True
 
     if not chatterbox_available:
         print("ERROR: Chatterbox TTS library not available")
         return False
 
     try:
-        # Clean up any existing GPU memory
-        cleanup_gpu_memory()
+        print("Loading model inside GPU context...")
 
+        # Now we can safely use CUDA operations
         device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f"Loading model on device: {device}")
 
@@ -65,16 +88,16 @@ def safe_load_model():
         model = model.to(device)
         if model and hasattr(model, 'eval'):
             model.eval()
-
-        # Clean up after loading
-        cleanup_gpu_memory()
+
+        model_loaded = True
+        print("✓ Model loaded successfully in GPU context")
         return True
 
     except Exception as e:
-        print(f"ERROR: Failed to load model: {e}")
+        print(f"ERROR: Failed to load model in GPU context: {e}")
         traceback.print_exc()
         model = None
-        cleanup_gpu_memory()
+        model_loaded = False
         return False
 
 def load_model_manually(device):
@@ -85,7 +108,7 @@ def load_model_manually(device):
     model_path = pathlib.Path(LOCAL_MODEL_PATH)
     print("Manual loading with correct constructor signature...")
 
-    # Load components to CPU first
+    # Load components to CPU first, then move to device
     s3gen_path = model_path / "s3gen.pt"
     ve_path = model_path / "ve.pt"
     tokenizer_path = model_path / "tokenizer.json"
@@ -116,54 +139,46 @@
     print("✓ Model loaded successfully with manual constructor.")
     return model
 
-def download_model_files():
-    """Download model files with error handling."""
-    print(f"Checking for model files in {LOCAL_MODEL_PATH}...")
-    os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)
-
-    for filename in MODEL_FILES:
-        local_path = os.path.join(LOCAL_MODEL_PATH, filename)
-        if not os.path.exists(local_path):
-            print(f"Downloading {filename} from {MODEL_REPO_ID}...")
-            try:
-                downloaded_path = hf_hub_download(
-                    repo_id=MODEL_REPO_ID,
-                    filename=filename,
-                    cache_dir="./cache",
-                    force_download=False
-                )
-                shutil.copy2(downloaded_path, local_path)
-                print(f"✓ Downloaded and copied {filename}")
-            except Exception as e:
-                print(f"✗ Failed to download {filename}: {e}")
-                raise e
-        else:
-            print(f"✓ {filename} already exists locally")
-    print("All model files are ready!")
+def cleanup_gpu_memory():
+    """Clean up GPU memory - only call within GPU context."""
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+        gc.collect()
 
-# Initialize model
+# Download model files during startup (CPU only)
 if chatterbox_available:
     try:
         download_model_files()
-        safe_load_model()
+        print("Model files downloaded. Model will be loaded on first GPU request.")
     except Exception as e:
-        print(f"ERROR during initialization: {e}")
+        print(f"ERROR during model file download: {e}")
 
 @spaces.GPU
 def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
-    """Main voice cloning function with improved error handling."""
+    """Main voice cloning function - runs on GPU."""
+    global model, model_loaded
 
     # Input validation
     if not chatterbox_available:
         return None, "Error: Chatterbox TTS library not available. Please check installation."
-    if model is None:
-        return None, "Error: Model not loaded. Please check the logs for details."
+
     if not text_to_speak or text_to_speak.strip() == "":
         return None, "Error: Please enter some text to speak."
+
     if reference_audio_path is None:
         return None, "Error: Please upload a reference audio file (.wav or .mp3)."
 
     try:
+        # Load model if not already loaded (inside GPU context)
+        if not model_loaded:
+            print("Loading model for the first time...")
+            if not load_model_on_gpu():
+                return None, "Error: Failed to load model. Please check the logs for details."
+
+        if model is None:
+            return None, "Error: Model not loaded. Please check the logs for details."
+
        print(f"Processing request:")
        print(f" Text length: {len(text_to_speak)} characters")
        print(f" Audio: '{reference_audio_path}'")
@@ -178,13 +193,13 @@ def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=
         if torch.cuda.is_available():
             torch.cuda.manual_seed(random_seed)
 
-        # Check CUDA availability before generation
+        # Check CUDA availability and memory
         if torch.cuda.is_available():
             print(f"CUDA memory before generation: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")
 
         # Generate audio with error handling
         try:
-            with torch.no_grad():  # Disable gradient computation
+            with torch.no_grad():  # Disable gradient computation to save memory
                 output_wav_data = model.generate(
                     text=text_to_speak,
                     audio_prompt_path=reference_audio_path,
@@ -209,6 +224,7 @@ def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=
                     print("✓ Recovery successful after memory cleanup")
                 except Exception as retry_error:
                     print(f"✗ Recovery failed: {retry_error}")
+                    cleanup_gpu_memory()
                 return None, f"CUDA error: {str(e)}. GPU memory issue - please try again in a moment."
             else:
                 raise e
@@ -244,7 +260,10 @@ def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=
         traceback.print_exc()
 
         # Clean up on error
-        cleanup_gpu_memory()
+        try:
+            cleanup_gpu_memory()
+        except:
+            pass
 
         # Provide specific error messages
         error_msg = str(e)
@@ -256,7 +275,7 @@ def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=
     return None, f"Error during audio generation: {error_msg}. Check logs for more details."
 
 def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
-    """API wrapper with improved error handling."""
+    """API wrapper function - this will call the GPU function."""
     import requests
     import tempfile
     import os
@@ -282,7 +301,7 @@ def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pa
         else:
             temp_audio_path = reference_audio_url
 
-        # Generate audio
+        # Call the GPU function
         audio_output, status = clone_voice(text_to_speak, temp_audio_path, exaggeration, cfg_pace, random_seed, temperature)
 
         return audio_output, status
@@ -298,11 +317,91 @@ def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pa
             except:
                 pass
 
-# Rest of your Gradio interface code remains the same...
+# Your existing Gradio interface code goes here...
 def main():
     print("Starting Advanced Gradio interface...")
-    # Your existing Gradio interface code here
-    pass
+
+    # Your existing Gradio interface code
+    with gr.Blocks(title="🎙️ Advanced Chatterbox Voice Cloning") as demo:
+        gr.Markdown("# 🎙️ Advanced Chatterbox Voice Cloning")
+        gr.Markdown("Clone any voice using advanced AI technology with fine-tuned controls.")
+
+        with gr.Row():
+            with gr.Column(scale=2):
+                text_input = gr.Textbox(
+                    label="Text to Speak",
+                    placeholder="Enter the text you want the cloned voice to say...",
+                    lines=3
+                )
+                audio_input = gr.Audio(
+                    type="filepath",
+                    label="Reference Audio (Upload a short .wav or .mp3 clip)",
+                    sources=["upload", "microphone"]
+                )
+
+                with gr.Accordion("🔧 Advanced Settings", open=False):
+                    with gr.Row():
+                        exaggeration_input = gr.Slider(
+                            minimum=0.25, maximum=1.0, value=0.6, step=0.05,
+                            label="Exaggeration", info="Controls voice characteristic emphasis"
+                        )
+                        cfg_pace_input = gr.Slider(
+                            minimum=0.2, maximum=1.0, value=0.3, step=0.05,
+                            label="CFG/Pace", info="Classifier-free guidance weight"
+                        )
+                    with gr.Row():
+                        seed_input = gr.Number(
+                            value=0, label="Random Seed", info="Set to 0 for random results", precision=0
+                        )
+                        temperature_input = gr.Slider(
+                            minimum=0.05, maximum=2.0, value=0.6, step=0.05,
+                            label="Temperature", info="Controls randomness in generation"
+                        )
+
+                generate_btn = gr.Button("🎵 Generate Voice Clone", variant="primary", size="lg")
+
+            with gr.Column(scale=1):
+                audio_output = gr.Audio(label="Generated Audio", type="numpy")
+                status_output = gr.Textbox(label="Status", lines=2)
+
+        # Connect the interface
+        generate_btn.click(
+            fn=clone_voice_api,
+            inputs=[text_input, audio_input, exaggeration_input, cfg_pace_input, seed_input, temperature_input],
+            outputs=[audio_output, status_output],
+            api_name="predict"
+        )
+
+        # API endpoint for external calls
+        def clone_voice_base64_api(text_to_speak, reference_audio_b64, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
+            return clone_voice_api(text_to_speak, reference_audio_b64, exaggeration, cfg_pace, random_seed, temperature)
+
+        # Hidden API interface
+        with gr.Row(visible=False):
+            api_text_input = gr.Textbox()
+            api_audio_input = gr.Textbox()
+            api_exaggeration_input = gr.Slider(minimum=0.25, maximum=1.0, value=0.6)
+            api_cfg_pace_input = gr.Slider(minimum=0.2, maximum=1.0, value=0.3)
+            api_seed_input = gr.Number(value=0, precision=0)
+            api_temperature_input = gr.Slider(minimum=0.05, maximum=2.0, value=0.6)
+            api_audio_output = gr.Audio(type="numpy")
+            api_status_output = gr.Textbox()
+            api_btn = gr.Button()
+
+        api_btn.click(
+            fn=clone_voice_base64_api,
+            inputs=[api_text_input, api_audio_input, api_exaggeration_input, api_cfg_pace_input, api_seed_input, api_temperature_input],
+            outputs=[api_audio_output, api_status_output],
+            api_name="clone_voice"
+        )
+
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        show_error=True,
+        quiet=False,
+        share=False
+    )
 
 if __name__ == "__main__":
     main()
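The substantive change in this commit is moving model loading out of module import and into the first call of the `@spaces.GPU` function: on ZeroGPU Spaces, CUDA is only attached while a decorated function runs, so loading to GPU at import time fails. Boiled down, the pattern is the sketch below; `TinyModel` and `generate` are illustrative stand-ins, not code from app.py, and it assumes the `spaces` package (preinstalled on Spaces) is available.

    import torch
    import spaces  # Hugging Face Spaces SDK; provides the @spaces.GPU decorator

    model = None
    model_loaded = False

    class TinyModel(torch.nn.Module):
        """Stand-in for ChatterboxTTS so the sketch runs on its own."""
        def forward(self, x):
            return x

    def load_model_on_gpu():
        """Lazy-load the model; repeated calls are cheap no-ops."""
        global model, model_loaded
        if model_loaded and model is not None:
            return True
        # Safe here: this is only called from inside the GPU context below.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = TinyModel().to(device).eval()
        model_loaded = True
        return True

    @spaces.GPU  # on ZeroGPU, CUDA is usable only inside this call
    def generate(x):
        if not load_model_on_gpu():  # first call pays the load cost
            raise RuntimeError("Model failed to load")
        with torch.no_grad():  # inference only, no gradients
            return model(x)

The first request after a cold start is therefore slow (download already happened at startup, but the GPU load happens here); subsequent requests reuse the cached global.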
 
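Because the hidden row registers `api_name="clone_voice"` with six inputs (text, audio reference as a string, exaggeration, cfg_pace, seed, temperature), external callers can reach it with `gradio_client`. A minimal sketch, assuming the Space is public; the Space id below is a placeholder, since the real one is not named in this diff:

    from gradio_client import Client

    # Placeholder Space id - substitute the actual ramimu/... Space name.
    client = Client("ramimu/chatterbox-voice-cloning")

    audio_path, status = client.predict(
        "Hello from a cloned voice!",         # text_to_speak
        "https://example.com/reference.wav",  # reference_audio_url, fetched server-side via requests
        0.6,                                  # exaggeration
        0.3,                                  # cfg_pace
        0,                                    # random_seed
        0.6,                                  # temperature
        api_name="/clone_voice",
    )
    print(status, audio_path)

With two output components, `predict` returns a tuple; `gradio_client` downloads file/audio outputs locally, so `audio_path` is a path on the caller's machine.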