ramimu committed on
Commit
240a407
·
verified ·
1 Parent(s): 5cbfdab

Update app.py

Files changed (1)
  1. app.py +428 -13
app.py CHANGED
@@ -1,18 +1,433 @@
- #!/usr/bin/env python3
- """
- Chatterbox Voice Cloning - Hugging Face Space
- Main entry point for the application
- """
-
- import sys
+ import gradio as gr
  import os
+ import traceback # For detailed error logging
+ import torch
+ from huggingface_hub import hf_hub_download
+ import shutil
+
+ # Import configuration
+ try:
+     from config import MODEL_REPO_ID, MODEL_FILES, LOCAL_MODEL_PATH
+ except ImportError:
+     # Fallback configuration if config.py is not found
+     MODEL_REPO_ID = "ramimu/chatterbox-voice-cloning-model"
+     LOCAL_MODEL_PATH = "./chatterbox_model_files"
+     MODEL_FILES = ["s3gen.pt", "t3_cfg.pt", "ve.pt", "tokenizer.json"]
+
+ # Try importing chatterbox with better error handling
+ try:
+     from chatterbox.tts import ChatterboxTTS
+     chatterbox_available = True
+     print("Chatterbox TTS imported successfully")
+ except ImportError as e:
+     print(f"Failed to import ChatterboxTTS: {e}")
+     print("Trying alternative import...")
+     try:
+         import chatterbox
+         from chatterbox import ChatterboxTTS
+         chatterbox_available = True
+         print("Chatterbox TTS imported with alternative method")
+     except ImportError as e2:
+         print(f"Alternative import also failed: {e2}")
+         chatterbox_available = False
+
+ # --- Global Model Variable ---
+ model = None
+
+ def download_model_files():
+     """Download model files from Hugging Face Hub if they don't exist locally"""
+     print(f"Checking for model files in {LOCAL_MODEL_PATH}...")
+
+     # Create model directory if it doesn't exist
+     os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)
+
+     for filename in MODEL_FILES:
+         local_path = os.path.join(LOCAL_MODEL_PATH, filename)
+         if not os.path.exists(local_path):
+             print(f"Downloading {filename} from {MODEL_REPO_ID}...")
+             try:
+                 downloaded_path = hf_hub_download(
+                     repo_id=MODEL_REPO_ID,
+                     filename=filename,
+                     cache_dir="./cache",
+                     force_download=False # Use cache if available
+                 )
+                 # Copy to our local model path
+                 shutil.copy2(downloaded_path, local_path)
+                 print(f"✓ Downloaded and copied {filename}")
+             except Exception as e:
+                 print(f"✗ Failed to download {filename}: {e}")
+                 raise e
+         else:
+             print(f"✓ {filename} already exists locally")
+
+     print("All model files are ready!")
+
+ # --- Load the Model ---
+ if chatterbox_available:
+     print("Downloading model files from Hugging Face Hub...")
+     try:
+         download_model_files()
+     except Exception as e:
+         print(f"ERROR: Failed to download model files: {e}")
+         print("Model loading will fail without these files.")
+
+     print(f"Attempting to load Chatterbox model from local directory: {LOCAL_MODEL_PATH}")
+     if not os.path.exists(LOCAL_MODEL_PATH):
+         print(f"ERROR: Local model directory not found at {LOCAL_MODEL_PATH}")
+         print("Please ensure the model files were downloaded successfully.")
+     else:
+         print(f"Contents of {LOCAL_MODEL_PATH}: {os.listdir(LOCAL_MODEL_PATH)}")
+         try:
+             # Load the model from the specified local directory
+             # Set device to CPU or CUDA if available
+             device = "cuda" if torch.cuda.is_available() else "cpu"
+             print(f"Using device: {device}")
+
+             # The correct method signature is from_local(model_path, device)
+             # based on the error message showing from_local is called internally
+             try:
+                 model = ChatterboxTTS.from_local(LOCAL_MODEL_PATH, device)
+                 print("Chatterbox model loaded successfully using from_local method.")
+             except Exception as e1:
+                 print(f"from_local attempt failed: {e1}")
+                 try:
+                     # Try the corrected from_pretrained with proper parameter order
+                     # It seems from_pretrained expects (local_path, device) not (device=device)
+                     model = ChatterboxTTS.from_pretrained(LOCAL_MODEL_PATH, device)
+                     print("Chatterbox model loaded successfully with corrected from_pretrained.")
+                 except Exception as e2:
+                     print(f"Corrected from_pretrained failed: {e2}")
+                     try:
+                         # Try loading individual components manually
+                         import torch
+                         s3gen_path = os.path.join(LOCAL_MODEL_PATH, "s3gen.pt")
+                         ve_path = os.path.join(LOCAL_MODEL_PATH, "ve.pt")
+                         tokenizer_path = os.path.join(LOCAL_MODEL_PATH, "tokenizer.json")
+                         t3_cfg_path = os.path.join(LOCAL_MODEL_PATH, "t3_cfg.pt")
+
+                         print(f"Loading components manually...")
+                         print(f"  s3gen: {s3gen_path}")
+                         print(f"  ve: {ve_path}")
+                         print(f"  tokenizer: {tokenizer_path}")
+                         print(f"  t3_cfg: {t3_cfg_path}")
+
+                         # Load the components
+                         s3gen = torch.load(s3gen_path, map_location=device)
+                         ve = torch.load(ve_path, map_location=device)
+
+                         # Load tokenizer
+                         import json
+                         with open(tokenizer_path, 'r') as f:
+                             tokenizer = json.load(f)
+
+                         # Create model instance with loaded components
+                         model = ChatterboxTTS(s3gen, ve, tokenizer, device)
+                         print("Chatterbox model loaded successfully with manual component loading.")
+                     except Exception as e3:
+                         print(f"Manual loading failed: {e3}")
+                         raise e3
+
+         except Exception as e:
+             print(f"ERROR: Failed to load Chatterbox model from local directory: {e}")
+             print("Detailed error trace:")
+             traceback.print_exc() # Prints the full traceback to the Hugging Face Space logs
+             model = None # Ensure model is None if loading fails
+ else:
+     print("ERROR: Chatterbox TTS library not available")
+
+ def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
+     if not chatterbox_available:
+         return None, "Error: Chatterbox TTS library not available. Please check installation."
+     if model is None:
+         return None, "Error: Model not loaded. Please check the logs for details."
+     if not text_to_speak or text_to_speak.strip() == "":
+         return None, "Error: Please enter some text to speak."
+     if reference_audio_path is None:
+         return None, "Error: Please upload a reference audio file (.wav or .mp3)."
+
+     try:
+         print(f"Received request:")
+         print(f"  Text: '{text_to_speak}'")
+         print(f"  Audio: '{reference_audio_path}'")
+         print(f"  Exaggeration: {exaggeration}")
+         print(f"  CFG/Pace: {cfg_pace}")
+         print(f"  Random Seed: {random_seed}")
+         print(f"  Temperature: {temperature}")
+
+         # Set random seed if specified
+         if random_seed > 0:
+             import torch
+             torch.manual_seed(random_seed)
+             if torch.cuda.is_available():
+                 torch.cuda.manual_seed(random_seed)
+
+         # Use the correct ChatterboxTTS generate method signature with advanced parameters
+         output_wav_data = model.generate(
+             text=text_to_speak,
+             audio_prompt_path=reference_audio_path,
+             exaggeration=exaggeration, # Controls how much the voice characteristics are emphasized
+             cfg_weight=cfg_pace, # Classifier-free guidance weight (pace)
+             temperature=temperature # Controls randomness in generation
+         )
+
+         # Get the sample rate from the model
+         try:
+             sample_rate = model.sr # ChatterboxTTS uses 'sr' attribute
+         except:
+             sample_rate = 24000 # Default fallback
+
+         print(f"Audio generated successfully. Output data type: {type(output_wav_data)}, Sample rate: {sample_rate}")
+
+         # Handle different output formats
+         if isinstance(output_wav_data, str):
+             # If it's a file path, return the path
+             return output_wav_data, "Success: Audio generated successfully!"
+         else:
+             # If it's numpy array or tensor, return with sample rate
+             import numpy as np
+             if hasattr(output_wav_data, 'cpu'):
+                 # Convert tensor to numpy if needed
+                 output_wav_data = output_wav_data.cpu().numpy()
+
+             # Ensure it's the right shape for Gradio (1D array)
+             if output_wav_data.ndim > 1:
+                 output_wav_data = output_wav_data.squeeze()
+
+             return (sample_rate, output_wav_data), "Success: Audio generated successfully!"
+
+     except Exception as e:
+         print(f"ERROR: Failed during audio generation: {e}")
+         print("Detailed error trace for audio generation:")
+         traceback.print_exc() # Prints the full traceback
+         return None, f"Error during audio generation: {str(e)}. Check logs for more details."
+
+ # --- API Endpoint Function ---
+ def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
+     """
+     API version of clone_voice that accepts URL or base64 audio data
+     """
+     import requests
+     import tempfile
+     import os
+     import base64
+
+     # Handle different audio input formats
+     temp_audio_path = None
+     try:
+         if reference_audio_url.startswith('data:audio'):
+             # Handle base64 encoded audio
+             header, encoded = reference_audio_url.split(',', 1)
+             audio_data = base64.b64decode(encoded)
+
+             # Determine file extension from MIME type
+             if 'mp3' in header:
+                 ext = '.mp3'
+             elif 'wav' in header:
+                 ext = '.wav'
+             else:
+                 ext = '.wav' # Default
+
+             # Save to temporary file
+             with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
+                 temp_file.write(audio_data)
+                 temp_audio_path = temp_file.name
+
+         elif reference_audio_url.startswith('http'):
+             # Download audio from URL
+             response = requests.get(reference_audio_url)
+             response.raise_for_status()
+
+             # Determine extension from URL or content type
+             if reference_audio_url.endswith('.mp3'):
+                 ext = '.mp3'
+             elif reference_audio_url.endswith('.wav'):
+                 ext = '.wav'
+             else:
+                 ext = '.wav' # Default
+
+             with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
+                 temp_file.write(response.content)
+                 temp_audio_path = temp_file.name
+         else:
+             # Assume it's a local file path
+             temp_audio_path = reference_audio_url
+
+         # Call the main clone_voice function
+         audio_output, status = clone_voice(text_to_speak, temp_audio_path, exaggeration, cfg_pace, random_seed, temperature)
+
+         # Clean up temporary file if we created one
+         if temp_audio_path and temp_audio_path != reference_audio_url:
+             try:
+                 os.unlink(temp_audio_path)
+             except:
+                 pass
+
+         return audio_output, status
+
+     except Exception as e:
+         if temp_audio_path and temp_audio_path != reference_audio_url:
+             try:
+                 os.unlink(temp_audio_path)
+             except:
+                 pass
+         return None, f"API Error: {str(e)}"

- # Add the voice_cloning directory to the path
- sys.path.append(os.path.join(os.path.dirname(__file__), 'voice_cloning'))
+ # --- Define Gradio Interface ---
+ # --- Define Gradio Interface ---
+ with gr.Blocks(title="Advanced Chatterbox Voice Cloning", theme=gr.themes.Soft()) as iface:
+     gr.Markdown("# 🎙️ Advanced Chatterbox Voice Cloning")
+     gr.Markdown("Clone any voice using advanced AI technology with fine-tuned controls.")
+
+     with gr.Row():
+         with gr.Column(scale=2):
+             # Main inputs
+             text_input = gr.Textbox(
+                 label="Text to Speak",
+                 placeholder="Enter the text you want the cloned voice to say...",
+                 lines=3
+             )
+             audio_input = gr.Audio(
+                 type="filepath",
+                 label="Reference Audio (Upload a short .wav or .mp3 clip)",
+                 sources=["upload", "microphone"]
+             )
+
+             # Advanced controls in an accordion
+             with gr.Accordion("🔧 Advanced Settings", open=False):
+                 with gr.Row():
+                     exaggeration = gr.Slider(
+                         minimum=0.25,
+                         maximum=1.0,
+                         value=0.6,
+                         step=0.05,
+                         label="Exaggeration",
+                         info="Controls voice characteristic emphasis (0.5 = neutral, higher = more exaggerated)"
+                     )
+                     cfg_pace = gr.Slider(
+                         minimum=0.2,
+                         maximum=1.0,
+                         value=0.3,
+                         step=0.05,
+                         label="CFG/Pace",
+                         info="Classifier-free guidance weight (affects generation quality and pace)"
+                     )
+
+                 with gr.Row():
+                     random_seed = gr.Number(
+                         value=0,
+                         label="Random Seed",
+                         info="Set to 0 for random results, or use a specific number for reproducible outputs",
+                         precision=0
+                     )
+                     temperature = gr.Slider(
+                         minimum=0.05,
+                         maximum=2.0,
+                         value=0.6,
+                         step=0.05,
+                         label="Temperature",
+                         info="Controls randomness in generation (lower = more consistent, higher = more varied)"
+                     )
+
+             # Generate button
+             generate_btn = gr.Button("🎵 Generate Voice Clone", variant="primary", size="lg")
+
+         with gr.Column(scale=1):
+             # Outputs
+             audio_output = gr.Audio(
+                 label="Generated Audio",
+                 type="numpy",
+                 interactive=False
+             )
+             status_output = gr.Textbox(
+                 label="Status",
+                 interactive=False,
+                 lines=2
+             )
+
+     # API Information
+     with gr.Accordion("🔌 API Usage", open=False):
+         gr.Markdown("""
+         ### Using this as an API endpoint
+
+         You can use this Hugging Face Space as an API endpoint in your applications:
+
+         **Endpoint URL:** `https://your-username-voice-cloning.hf.space/api/predict`
+
+         **Example Python code:**
+         ```python
+         import requests
+         import base64
+
+         # Encode your audio file
+         with open("reference_audio.wav", "rb") as f:
+             audio_data = base64.b64encode(f.read()).decode()
+         audio_url = f"data:audio/wav;base64,{audio_data}"
+
+         # API request
+         response = requests.post(
+             "https://your-username-voice-cloning.hf.space/api/predict",
+             json={
+                 "data": [
+                     "Hello, this is my cloned voice!", # text
+                     audio_url, # reference audio (base64 or URL)
+                     0.6, # exaggeration
+                     0.3, # cfg_pace
+                     0, # random_seed
+                     0.6 # temperature
+                 ]
+             }
+         )
+         ```
+
+         **Parameters:**
+         - `text_to_speak`: Text to synthesize
+         - `reference_audio`: Base64 encoded audio or URL
+         - `exaggeration`: Voice emphasis (0.25-1.0, default: 0.6)
+         - `cfg_pace`: Generation guidance (0.2-1.0, default: 0.3)
+         - `random_seed`: Reproducibility seed (0 for random, default: 0)
+         - `temperature`: Generation randomness (0.05-2.0, default: 0.6)
+         """)
+
+     # Examples
+     with gr.Accordion("📝 Examples", open=False):
+         gr.Examples(
+             examples=[
+                 ["Hello, this is a test of the voice cloning system.", None, 0.5, 0.5, 0, 0.8],
+                 ["The quick brown fox jumps over the lazy dog.", None, 0.7, 0.3, 42, 0.6],
+                 ["Welcome to our AI voice cloning service. We hope you enjoy the experience!", None, 0.4, 0.7, 123, 1.0]
+             ],
+             inputs=[text_input, audio_input, exaggeration, cfg_pace, random_seed, temperature],
+             outputs=[audio_output, status_output],
+             fn=clone_voice,
+             cache_examples=False
+         )
+
+     # Connect the generate button
+     generate_btn.click(
+         fn=clone_voice,
+         inputs=[text_input, audio_input, exaggeration, cfg_pace, random_seed, temperature],
+         outputs=[audio_output, status_output],
+         api_name="clone_voice" # This enables API access
+     )

- # Import and run the main app
- from voice_cloning.app import *
+ # --- Launch the Gradio App ---
+ def main():
+     print("Starting Advanced Gradio interface...")
+     # Launch with specific configuration for API access and avoid manifest issues
+     iface.launch(
+         server_name="0.0.0.0", # Allow external connections
+         server_port=7860, # Explicit port
+         show_error=True, # Show detailed errors
+         quiet=False, # Show startup logs
+         favicon_path=None, # Disable favicon to avoid 404
+         share=False, # Set to True if you want a public link
+         auth=None, # Add authentication if needed: ("username", "password")
+         app_kwargs={
+             "docs_url": "/docs", # Enable API docs at /docs
+             "redoc_url": "/redoc" # Enable alternative docs at /redoc
+         }
+     )

  if __name__ == "__main__":
-     # Run the app
-     main()
+     main()
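
Because the `generate_btn.click(...)` call above sets `api_name="clone_voice"`, the Space can also be called through Gradio's client library rather than raw `requests`. The following is a minimal sketch, not part of this commit: the Space id `your-username/voice-cloning` is a placeholder, and `handle_file` assumes a recent `gradio_client` release.

```python
# Hypothetical client-side call to the /clone_voice endpoint exposed by api_name="clone_voice".
from gradio_client import Client, handle_file

client = Client("your-username/voice-cloning")  # placeholder Space id

# Inputs are passed in the same order as the Gradio inputs list:
# text, reference audio, exaggeration, cfg_pace, random_seed, temperature.
audio_path, status = client.predict(
    "Hello, this is my cloned voice!",
    handle_file("reference_audio.wav"),
    0.6,   # exaggeration
    0.3,   # cfg_pace
    0,     # random_seed
    0.6,   # temperature
    api_name="/clone_voice",
)

print(status)      # status message from the second output component
print(audio_path)  # local path to the downloaded generated audio
```

The two return values mirror the endpoint's `outputs=[audio_output, status_output]` ordering.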