import gradio as gr
import os
import traceback
import torch
import gc
from huggingface_hub import hf_hub_download
import shutil
import spaces

try:
    from config import MODEL_REPO_ID, MODEL_FILES, LOCAL_MODEL_PATH
except ImportError:
    MODEL_REPO_ID = "ramimu/chatterbox-voice-cloning-model"
    LOCAL_MODEL_PATH = "./chatterbox_model_files"
    MODEL_FILES = ["s3gen.pt", "t3_cfg.pt", "ve.pt", "tokenizer.json"]

try:
    from chatterbox.tts import ChatterboxTTS
    chatterbox_available = True
    print("Chatterbox TTS imported successfully")
except ImportError as e:
    print(f"Failed to import ChatterboxTTS: {e}")
    chatterbox_available = False

# Global model variable - loaded lazily inside the GPU-decorated function
model = None
model_loaded = False

# Text length limits for the model
MAX_CHARS_PER_GENERATION = 1000  # Safe limit for a single generation
MAX_CHARS_TOTAL = 5000           # Maximum accepted via the API


def download_model_files():
    """Download model files with error handling."""
    print(f"Checking for model files in {LOCAL_MODEL_PATH}...")
    os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)
    for filename in MODEL_FILES:
        local_path = os.path.join(LOCAL_MODEL_PATH, filename)
        if not os.path.exists(local_path):
            print(f"Downloading {filename} from {MODEL_REPO_ID}...")
            try:
                downloaded_path = hf_hub_download(
                    repo_id=MODEL_REPO_ID,
                    filename=filename,
                    cache_dir="./cache",
                    force_download=False
                )
                shutil.copy2(downloaded_path, local_path)
                print(f"✓ Downloaded and copied {filename}")
            except Exception as e:
                print(f"✗ Failed to download {filename}: {e}")
                raise e
        else:
            print(f"✓ {filename} already exists locally")
    print("All model files are ready!")


def load_model_on_gpu():
    """Load the model inside the GPU context - only called within a @spaces.GPU decorated function."""
    global model, model_loaded

    if model_loaded and model is not None:
        return True

    if not chatterbox_available:
        print("ERROR: Chatterbox TTS library not available")
        return False

    try:
        print("Loading model inside GPU context...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading model on device: {device}")

        # Try different loading methods, from most to least preferred
        try:
            model = ChatterboxTTS.from_local(LOCAL_MODEL_PATH, device)
            print("✓ Model loaded successfully using from_local method.")
        except Exception as e1:
            print(f"from_local failed: {e1}")
            try:
                model = ChatterboxTTS.from_pretrained(device)
                print("✓ Model loaded successfully with from_pretrained.")
            except Exception as e2:
                print(f"from_pretrained failed: {e2}")
                model = load_model_manually(device)

        if model and hasattr(model, 'to'):
            model = model.to(device)
        if model and hasattr(model, 'eval'):
            model.eval()

        model_loaded = True
        print("✓ Model loaded successfully in GPU context")
        return True
    except Exception as e:
        print(f"ERROR: Failed to load model in GPU context: {e}")
        traceback.print_exc()
        model = None
        model_loaded = False
        return False


def load_model_manually(device):
    """Manual model loading with proper error handling."""
    import pathlib
    import json

    model_path = pathlib.Path(LOCAL_MODEL_PATH)
    print("Manual loading with correct constructor signature...")

    s3gen_path = model_path / "s3gen.pt"
    ve_path = model_path / "ve.pt"
    tokenizer_path = model_path / "tokenizer.json"
    t3_cfg_path = model_path / "t3_cfg.pt"

    s3gen = torch.load(s3gen_path, map_location='cpu')
    ve = torch.load(ve_path, map_location='cpu')
    t3_cfg = torch.load(t3_cfg_path, map_location='cpu')

    with open(tokenizer_path, 'r') as f:
        tokenizer_data = json.load(f)

    try:
        from chatterbox.models.tokenizers.tokenizer import EnTokenizer
        tokenizer = EnTokenizer.from_dict(tokenizer_data)
    except Exception:
        tokenizer = tokenizer_data

    model = ChatterboxTTS(
        t3=t3_cfg,
        s3gen=s3gen,
        ve=ve,
        tokenizer=tokenizer,
        device=device
    )
    print("✓ Model loaded successfully with manual constructor.")
    return model


def cleanup_gpu_memory():
    """Clean up GPU memory - only call within the GPU context."""
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
    except Exception as e:
        print(f"Warning: GPU cleanup failed: {e}")


def truncate_text_safely(text, max_chars=MAX_CHARS_PER_GENERATION):
    """Truncate text to a safe length while preserving sentence boundaries."""
    if len(text) <= max_chars:
        return text, False

    # Find the last sentence ending before the limit
    truncated = text[:max_chars]

    # Look for sentence endings
    for ending in ['. ', '! ', '? ']:
        last_sentence = truncated.rfind(ending)
        if last_sentence > max_chars * 0.7:  # Don't truncate too aggressively
            return text[:last_sentence + 1].strip(), True

    # Fall back to a word boundary
    last_space = truncated.rfind(' ')
    if last_space > max_chars * 0.8:
        return text[:last_space].strip(), True

    # Last resort: hard truncate
    return truncated.strip(), True


# Download model files during startup (CPU only)
if chatterbox_available:
    try:
        download_model_files()
        print("Model files downloaded. Model will be loaded on first GPU request.")
    except Exception as e:
        print(f"ERROR during model file download: {e}")


@spaces.GPU
def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """Main voice cloning function - runs on GPU."""
    global model, model_loaded

    # Input validation
    if not chatterbox_available:
        return None, "Error: Chatterbox TTS library not available. Please check installation."
    if not text_to_speak or text_to_speak.strip() == "":
        return None, "Error: Please enter some text to speak."
    if reference_audio_path is None:
        return None, "Error: Please upload a reference audio file (.wav or .mp3)."

    # Check text length and truncate if necessary
    original_length = len(text_to_speak)
    if original_length > MAX_CHARS_TOTAL:
        return None, f"Error: Text is too long ({original_length:,} characters). Maximum allowed is {MAX_CHARS_TOTAL:,} characters. Please use the chunked generation API for longer texts."

    # Truncate to a safe generation length
    text_to_use, was_truncated = truncate_text_safely(text_to_speak, MAX_CHARS_PER_GENERATION)

    try:
        # Load the model if it is not already loaded
        if not model_loaded:
            print("Loading model for the first time...")
            if not load_model_on_gpu():
                return None, "Error: Failed to load model. Please check the logs for details."

        if model is None:
            return None, "Error: Model not loaded. Please check the logs for details."
print(f"Processing request:") print(f" Original text length: {original_length:,} characters") print(f" Processing length: {len(text_to_use):,} characters") print(f" Truncated: {was_truncated}") print(f" Audio: '{reference_audio_path}'") print(f" Parameters: exag={exaggeration}, cfg={cfg_pace}, seed={random_seed}, temp={temperature}") # Clean GPU memory before generation cleanup_gpu_memory() # Set random seed if specified if random_seed > 0: torch.manual_seed(random_seed) if torch.cuda.is_available(): torch.cuda.manual_seed(random_seed) # Check CUDA availability and memory if torch.cuda.is_available(): print(f"CUDA memory before generation: {torch.cuda.memory_allocated() / 1024**2:.1f} MB") # Generate audio with error handling try: with torch.no_grad(): output_wav_data = model.generate( text=text_to_use, audio_prompt_path=reference_audio_path, exaggeration=exaggeration, cfg_weight=cfg_pace, temperature=temperature ) except RuntimeError as e: if "CUDA" in str(e) or "out of memory" in str(e) or "device-side assert" in str(e): print(f"CUDA error during generation: {e}") cleanup_gpu_memory() return None, f"CUDA error: Text may be too long for single generation. Try shorter text (under {MAX_CHARS_PER_GENERATION} characters) or use the chunked generation API for longer content." else: raise e # Get sample rate try: sample_rate = model.sr except: sample_rate = 24000 # Process output if isinstance(output_wav_data, str): result = output_wav_data else: import numpy as np if hasattr(output_wav_data, 'cpu'): output_wav_data = output_wav_data.cpu().numpy() if output_wav_data.ndim > 1: output_wav_data = output_wav_data.squeeze() result = (sample_rate, output_wav_data) # Clean up GPU memory after generation cleanup_gpu_memory() if torch.cuda.is_available(): print(f"CUDA memory after generation: {torch.cuda.memory_allocated() / 1024**2:.1f} MB") print("✓ Audio generated successfully") # Prepare success message success_msg = "Success: Audio generated successfully!" if was_truncated: success_msg += f" Note: Text was truncated from {original_length:,} to {len(text_to_use):,} characters for optimal generation. Use the chunked generation API for longer texts." return result, success_msg except Exception as e: print(f"ERROR during audio generation: {e}") traceback.print_exc() # Clean up on error try: cleanup_gpu_memory() except: pass # Provide specific error messages error_msg = str(e) if "CUDA" in error_msg or "device-side assert" in error_msg: return None, f"CUDA error: {error_msg}. Try shorter text (under {MAX_CHARS_PER_GENERATION} characters) or use the chunked generation API." elif "out of memory" in error_msg: return None, f"GPU memory error: {error_msg}. Please try with shorter text." else: return None, f"Error during audio generation: {error_msg}. Check logs for more details." 
def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """API wrapper function."""
    import requests
    import tempfile
    import os
    import base64

    temp_audio_path = None
    try:
        # Handle different audio input formats
        if reference_audio_url.startswith('data:audio'):
            header, encoded = reference_audio_url.split(',', 1)
            audio_data = base64.b64decode(encoded)
            ext = '.mp3' if 'mp3' in header else '.wav'
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
                temp_file.write(audio_data)
                temp_audio_path = temp_file.name
        elif reference_audio_url.startswith('http'):
            response = requests.get(reference_audio_url, timeout=30)
            response.raise_for_status()
            ext = '.mp3' if reference_audio_url.endswith('.mp3') else '.wav'
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
                temp_file.write(response.content)
                temp_audio_path = temp_file.name
        else:
            temp_audio_path = reference_audio_url

        # Call the GPU function
        audio_output, status = clone_voice(text_to_speak, temp_audio_path, exaggeration, cfg_pace, random_seed, temperature)
        return audio_output, status
    except Exception as e:
        print(f"API Error: {e}")
        return None, f"API Error: {str(e)}"
    finally:
        # Clean up the temporary file
        if temp_audio_path and temp_audio_path != reference_audio_url:
            try:
                os.unlink(temp_audio_path)
            except Exception:
                pass


def main():
    print("Starting Advanced Gradio interface...")

    with gr.Blocks(title="🎙️ Advanced Chatterbox Voice Cloning") as demo:
        gr.Markdown("# 🎙️ Advanced Chatterbox Voice Cloning")
        gr.Markdown("Clone any voice using advanced AI technology with fine-tuned controls.")

        # Add warning about text length
        gr.Markdown(f"""
**⚠️ Text Length Limits:**
- **Single Generation**: Up to {MAX_CHARS_PER_GENERATION:,} characters (optimal quality)
- **API Maximum**: Up to {MAX_CHARS_TOTAL:,} characters (may be truncated)
- **For longer texts**: Use the chunked generation API in your application
""")

        with gr.Row():
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label=f"Text to Speak (max {MAX_CHARS_TOTAL:,} characters)",
                    placeholder="Enter the text you want the cloned voice to say...",
                    lines=5,
                    max_lines=10
                )
                audio_input = gr.Audio(
                    type="filepath",
                    label="Reference Audio (Upload a short .wav or .mp3 clip)",
                    sources=["upload", "microphone"]
                )

                with gr.Accordion("🔧 Advanced Settings", open=False):
                    with gr.Row():
                        exaggeration_input = gr.Slider(
                            minimum=0.25,
                            maximum=1.0,
                            value=0.6,
                            step=0.05,
                            label="Exaggeration",
                            info="Controls voice characteristic emphasis"
                        )
                        cfg_pace_input = gr.Slider(
                            minimum=0.2,
                            maximum=1.0,
                            value=0.3,
                            step=0.05,
                            label="CFG/Pace",
                            info="Classifier-free guidance weight"
                        )
                    with gr.Row():
                        seed_input = gr.Number(
                            value=0,
                            label="Random Seed",
                            info="Set to 0 for random results",
                            precision=0
                        )
                        temperature_input = gr.Slider(
                            minimum=0.05,
                            maximum=2.0,
                            value=0.6,
                            step=0.05,
                            label="Temperature",
                            info="Controls randomness in generation"
                        )

                generate_btn = gr.Button("🎵 Generate Voice Clone", variant="primary", size="lg")

            with gr.Column(scale=1):
                audio_output = gr.Audio(label="Generated Audio", type="numpy")
                status_output = gr.Textbox(label="Status", lines=3)

        # Connect the interface
        generate_btn.click(
            fn=clone_voice_api,
            inputs=[text_input, audio_input, exaggeration_input, cfg_pace_input, seed_input, temperature_input],
            outputs=[audio_output, status_output],
            api_name="predict"
        )

        # API endpoint for external calls
        def clone_voice_base64_api(text_to_speak, reference_audio_b64, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
            return clone_voice_api(text_to_speak, reference_audio_b64, exaggeration, cfg_pace, random_seed, temperature)

        # Hidden API interface
        with gr.Row(visible=False):
            api_text_input = gr.Textbox()
            api_audio_input = gr.Textbox()
            api_exaggeration_input = gr.Slider(minimum=0.25, maximum=1.0, value=0.6)
            api_cfg_pace_input = gr.Slider(minimum=0.2, maximum=1.0, value=0.3)
            api_seed_input = gr.Number(value=0, precision=0)
            api_temperature_input = gr.Slider(minimum=0.05, maximum=2.0, value=0.6)
            api_audio_output = gr.Audio(type="numpy")
            api_status_output = gr.Textbox()
            api_btn = gr.Button()

        api_btn.click(
            fn=clone_voice_base64_api,
            inputs=[api_text_input, api_audio_input, api_exaggeration_input, api_cfg_pace_input, api_seed_input, api_temperature_input],
            outputs=[api_audio_output, api_status_output],
            api_name="clone_voice"
        )

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        quiet=False,
        share=False
    )


if __name__ == "__main__":
    main()
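
# Example client usage (sketch, kept as comments so nothing runs at import time).
# Once the Space is live, the hidden /clone_voice endpoint can be called with
# gradio_client. The Space id below is a placeholder, and the data-URL prefix must
# match the audio format being sent.
#
#   from gradio_client import Client
#   import base64
#
#   client = Client("<your-username>/<your-space>")  # placeholder Space id
#   with open("reference.wav", "rb") as f:
#       audio_b64 = "data:audio/wav;base64," + base64.b64encode(f.read()).decode()
#   audio, status = client.predict(
#       "Hello from a cloned voice.",   # text_to_speak
#       audio_b64,                      # reference audio as a base64 data URL
#       0.6, 0.3, 0, 0.6,               # exaggeration, cfg_pace, random_seed, temperature
#       api_name="/clone_voice",
#   )
#   print(status)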