import gradio as gr
import os
import traceback
import torch
import gc
from huggingface_hub import hf_hub_download
import shutil
import spaces

try:
    from config import MODEL_REPO_ID, MODEL_FILES, LOCAL_MODEL_PATH
except ImportError:
    MODEL_REPO_ID = "ramimu/chatterbox-voice-cloning-model"
    LOCAL_MODEL_PATH = "./chatterbox_model_files"
    MODEL_FILES = ["s3gen.pt", "t3_cfg.pt", "ve.pt", "tokenizer.json"]

try:
    from chatterbox.tts import ChatterboxTTS
    chatterbox_available = True
    print("Chatterbox TTS imported successfully")
except ImportError as e:
    print(f"Failed to import ChatterboxTTS: {e}")
    chatterbox_available = False

# Global model variable - will be loaded inside the GPU function
model = None
model_loaded = False

# Text length limits for the model
MAX_CHARS_PER_GENERATION = 1000  # Safe limit for a single generation
MAX_CHARS_TOTAL = 5000           # Maximum accepted via the API
def download_model_files():
    """Download model files with error handling."""
    print(f"Checking for model files in {LOCAL_MODEL_PATH}...")
    os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)
    for filename in MODEL_FILES:
        local_path = os.path.join(LOCAL_MODEL_PATH, filename)
        if not os.path.exists(local_path):
            print(f"Downloading {filename} from {MODEL_REPO_ID}...")
            try:
                downloaded_path = hf_hub_download(
                    repo_id=MODEL_REPO_ID,
                    filename=filename,
                    cache_dir="./cache",
                    force_download=False
                )
                shutil.copy2(downloaded_path, local_path)
                print(f"✓ Downloaded and copied {filename}")
            except Exception as e:
                print(f"✗ Failed to download {filename}: {e}")
                raise
        else:
            print(f"✓ {filename} already exists locally")
    print("All model files are ready!")
def load_model_on_gpu():
    """Load the model inside the GPU context - only called within a @spaces.GPU decorated function."""
    global model, model_loaded
    if model_loaded and model is not None:
        return True
    if not chatterbox_available:
        print("ERROR: Chatterbox TTS library not available")
        return False
    try:
        print("Loading model inside GPU context...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading model on device: {device}")

        # Try different loading methods, from most to least preferred
        try:
            model = ChatterboxTTS.from_local(LOCAL_MODEL_PATH, device)
            print("✓ Model loaded successfully using the from_local method.")
        except Exception as e1:
            print(f"from_local failed: {e1}")
            try:
                model = ChatterboxTTS.from_pretrained(device)
                print("✓ Model loaded successfully with from_pretrained.")
            except Exception as e2:
                print(f"from_pretrained failed: {e2}")
                model = load_model_manually(device)

        if model and hasattr(model, 'to'):
            model = model.to(device)
        if model and hasattr(model, 'eval'):
            model.eval()
        model_loaded = True
        print("✓ Model loaded successfully in GPU context")
        return True
    except Exception as e:
        print(f"ERROR: Failed to load model in GPU context: {e}")
        traceback.print_exc()
        model = None
        model_loaded = False
        return False
def load_model_manually(device):
    """Manual model loading with proper error handling."""
    import pathlib
    import json

    model_path = pathlib.Path(LOCAL_MODEL_PATH)
    print("Manual loading with correct constructor signature...")

    s3gen_path = model_path / "s3gen.pt"
    ve_path = model_path / "ve.pt"
    tokenizer_path = model_path / "tokenizer.json"
    t3_cfg_path = model_path / "t3_cfg.pt"

    s3gen = torch.load(s3gen_path, map_location='cpu')
    ve = torch.load(ve_path, map_location='cpu')
    t3_cfg = torch.load(t3_cfg_path, map_location='cpu')
    with open(tokenizer_path, 'r') as f:
        tokenizer_data = json.load(f)

    try:
        from chatterbox.models.tokenizers.tokenizer import EnTokenizer
        tokenizer = EnTokenizer.from_dict(tokenizer_data)
    except Exception:
        tokenizer = tokenizer_data

    model = ChatterboxTTS(
        t3=t3_cfg,
        s3gen=s3gen,
        ve=ve,
        tokenizer=tokenizer,
        device=device
    )
    print("✓ Model loaded successfully with the manual constructor.")
    return model
def cleanup_gpu_memory():
    """Clean up GPU memory - only call within the GPU context."""
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
    except Exception as e:
        print(f"Warning: GPU cleanup failed: {e}")

def truncate_text_safely(text, max_chars=MAX_CHARS_PER_GENERATION):
    """Truncate text to a safe length while preserving sentence boundaries."""
    if len(text) <= max_chars:
        return text, False

    # Find the last sentence ending before the limit
    truncated = text[:max_chars]

    # Look for sentence endings
    for ending in ['. ', '! ', '? ']:
        last_sentence = truncated.rfind(ending)
        if last_sentence > max_chars * 0.7:  # Don't truncate too aggressively
            return text[:last_sentence + 1].strip(), True

    # Fall back to a word boundary
    last_space = truncated.rfind(' ')
    if last_space > max_chars * 0.8:
        return text[:last_space].strip(), True

    # Last resort: hard truncate
    return truncated.strip(), True
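
# Illustrative behaviour of truncate_text_safely (a sketch, not a test fixture):
# a 1,600-character input is cut back to the last ". " boundary that falls past
# 70% of the limit, so the result stays within MAX_CHARS_PER_GENERATION.
#   text, was_truncated = truncate_text_safely("This is a sentence. " * 80)
#   # was_truncated == True, text ends on a sentence boundary, len(text) <= 1000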
# Download model files during startup (CPU only)
if chatterbox_available:
    try:
        download_model_files()
        print("Model files downloaded. Model will be loaded on first GPU request.")
    except Exception as e:
        print(f"ERROR during model file download: {e}")
# ZeroGPU: the @spaces.GPU decorator requests a GPU for the duration of this call,
# so the model is loaded and run only inside this function.
@spaces.GPU
def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """Main voice cloning function - runs on GPU."""
    global model, model_loaded

    # Input validation
    if not chatterbox_available:
        return None, "Error: Chatterbox TTS library not available. Please check installation."
    if not text_to_speak or text_to_speak.strip() == "":
        return None, "Error: Please enter some text to speak."
    if reference_audio_path is None:
        return None, "Error: Please upload a reference audio file (.wav or .mp3)."

    # Check text length and reject anything over the API maximum
    original_length = len(text_to_speak)
    if original_length > MAX_CHARS_TOTAL:
        return None, f"Error: Text is too long ({original_length:,} characters). Maximum allowed is {MAX_CHARS_TOTAL:,} characters. Please use the chunked generation API for longer texts."

    # Truncate to a safe generation length
    text_to_use, was_truncated = truncate_text_safely(text_to_speak, MAX_CHARS_PER_GENERATION)

    try:
        # Load the model if it is not already loaded
        if not model_loaded:
            print("Loading model for the first time...")
            if not load_model_on_gpu():
                return None, "Error: Failed to load model. Please check the logs for details."
        if model is None:
            return None, "Error: Model not loaded. Please check the logs for details."
print(f"Processing request:") | |
print(f" Original text length: {original_length:,} characters") | |
print(f" Processing length: {len(text_to_use):,} characters") | |
print(f" Truncated: {was_truncated}") | |
print(f" Audio: '{reference_audio_path}'") | |
print(f" Parameters: exag={exaggeration}, cfg={cfg_pace}, seed={random_seed}, temp={temperature}") | |
# Clean GPU memory before generation | |
cleanup_gpu_memory() | |
# Set random seed if specified | |
if random_seed > 0: | |
torch.manual_seed(random_seed) | |
if torch.cuda.is_available(): | |
torch.cuda.manual_seed(random_seed) | |
# Check CUDA availability and memory | |
if torch.cuda.is_available(): | |
print(f"CUDA memory before generation: {torch.cuda.memory_allocated() / 1024**2:.1f} MB") | |
# Generate audio with error handling | |
try: | |
with torch.no_grad(): | |
output_wav_data = model.generate( | |
text=text_to_use, | |
audio_prompt_path=reference_audio_path, | |
exaggeration=exaggeration, | |
cfg_weight=cfg_pace, | |
temperature=temperature | |
) | |
except RuntimeError as e: | |
if "CUDA" in str(e) or "out of memory" in str(e) or "device-side assert" in str(e): | |
print(f"CUDA error during generation: {e}") | |
cleanup_gpu_memory() | |
return None, f"CUDA error: Text may be too long for single generation. Try shorter text (under {MAX_CHARS_PER_GENERATION} characters) or use the chunked generation API for longer content." | |
else: | |
raise e | |
# Get sample rate | |
try: | |
sample_rate = model.sr | |
except: | |
sample_rate = 24000 | |
# Process output | |
if isinstance(output_wav_data, str): | |
result = output_wav_data | |
else: | |
import numpy as np | |
if hasattr(output_wav_data, 'cpu'): | |
output_wav_data = output_wav_data.cpu().numpy() | |
if output_wav_data.ndim > 1: | |
output_wav_data = output_wav_data.squeeze() | |
result = (sample_rate, output_wav_data) | |
# Clean up GPU memory after generation | |
cleanup_gpu_memory() | |
if torch.cuda.is_available(): | |
print(f"CUDA memory after generation: {torch.cuda.memory_allocated() / 1024**2:.1f} MB") | |
print("β Audio generated successfully") | |
# Prepare success message | |
success_msg = "Success: Audio generated successfully!" | |
if was_truncated: | |
success_msg += f" Note: Text was truncated from {original_length:,} to {len(text_to_use):,} characters for optimal generation. Use the chunked generation API for longer texts." | |
return result, success_msg | |
except Exception as e: | |
print(f"ERROR during audio generation: {e}") | |
traceback.print_exc() | |
# Clean up on error | |
try: | |
cleanup_gpu_memory() | |
except: | |
pass | |
# Provide specific error messages | |
error_msg = str(e) | |
if "CUDA" in error_msg or "device-side assert" in error_msg: | |
return None, f"CUDA error: {error_msg}. Try shorter text (under {MAX_CHARS_PER_GENERATION} characters) or use the chunked generation API." | |
elif "out of memory" in error_msg: | |
return None, f"GPU memory error: {error_msg}. Please try with shorter text." | |
else: | |
return None, f"Error during audio generation: {error_msg}. Check logs for more details." | |
def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """API wrapper: accepts a base64 data URI, an http(s) URL, or a local file path as the reference audio."""
    import requests
    import tempfile
    import base64

    temp_audio_path = None
    try:
        # Handle the different audio input formats
        if reference_audio_url.startswith('data:audio'):
            header, encoded = reference_audio_url.split(',', 1)
            audio_data = base64.b64decode(encoded)
            ext = '.mp3' if 'mp3' in header else '.wav'
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
                temp_file.write(audio_data)
                temp_audio_path = temp_file.name
        elif reference_audio_url.startswith('http'):
            response = requests.get(reference_audio_url, timeout=30)
            response.raise_for_status()
            ext = '.mp3' if reference_audio_url.endswith('.mp3') else '.wav'
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
                temp_file.write(response.content)
                temp_audio_path = temp_file.name
        else:
            temp_audio_path = reference_audio_url

        # Call the GPU function
        audio_output, status = clone_voice(text_to_speak, temp_audio_path, exaggeration, cfg_pace, random_seed, temperature)
        return audio_output, status
    except Exception as e:
        print(f"API Error: {e}")
        return None, f"API Error: {str(e)}"
    finally:
        # Clean up the temporary file, if one was created
        if temp_audio_path and temp_audio_path != reference_audio_url:
            try:
                os.unlink(temp_audio_path)
            except OSError:
                pass
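
# Client-side sketch (illustrative, not part of this app): building the base64
# data-URI form of reference_audio_url that the first branch above decodes.
#   import base64
#   with open("reference.wav", "rb") as f:
#       reference_audio_url = "data:audio/wav;base64," + base64.b64encode(f.read()).decode()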
def main():
    print("Starting Advanced Gradio interface...")
    with gr.Blocks(title="🎙️ Advanced Chatterbox Voice Cloning") as demo:
        gr.Markdown("# 🎙️ Advanced Chatterbox Voice Cloning")
        gr.Markdown("Clone any voice using advanced AI technology with fine-tuned controls.")

        # Warn about text length limits
        gr.Markdown(f"""
        **⚠️ Text Length Limits:**
        - **Single Generation**: Up to {MAX_CHARS_PER_GENERATION:,} characters (optimal quality)
        - **API Maximum**: Up to {MAX_CHARS_TOTAL:,} characters (may be truncated)
        - **For longer texts**: Use the chunked generation API in your application
        """)

        with gr.Row():
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label=f"Text to Speak (max {MAX_CHARS_TOTAL:,} characters)",
                    placeholder="Enter the text you want the cloned voice to say...",
                    lines=5,
                    max_lines=10
                )
                audio_input = gr.Audio(
                    type="filepath",
                    label="Reference Audio (Upload a short .wav or .mp3 clip)",
                    sources=["upload", "microphone"]
                )
                with gr.Accordion("🔧 Advanced Settings", open=False):
                    with gr.Row():
                        exaggeration_input = gr.Slider(
                            minimum=0.25, maximum=1.0, value=0.6, step=0.05,
                            label="Exaggeration", info="Controls voice characteristic emphasis"
                        )
                        cfg_pace_input = gr.Slider(
                            minimum=0.2, maximum=1.0, value=0.3, step=0.05,
                            label="CFG/Pace", info="Classifier-free guidance weight"
                        )
                    with gr.Row():
                        seed_input = gr.Number(
                            value=0, label="Random Seed", info="Set to 0 for random results", precision=0
                        )
                        temperature_input = gr.Slider(
                            minimum=0.05, maximum=2.0, value=0.6, step=0.05,
                            label="Temperature", info="Controls randomness in generation"
                        )
                generate_btn = gr.Button("🎵 Generate Voice Clone", variant="primary", size="lg")
            with gr.Column(scale=1):
                audio_output = gr.Audio(label="Generated Audio", type="numpy")
                status_output = gr.Textbox(label="Status", lines=3)

        # Connect the interface
        generate_btn.click(
            fn=clone_voice_api,
            inputs=[text_input, audio_input, exaggeration_input, cfg_pace_input, seed_input, temperature_input],
            outputs=[audio_output, status_output],
            api_name="predict"
        )

        # API endpoint for external calls
        def clone_voice_base64_api(text_to_speak, reference_audio_b64, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
            return clone_voice_api(text_to_speak, reference_audio_b64, exaggeration, cfg_pace, random_seed, temperature)

        # Hidden API interface
        with gr.Row(visible=False):
            api_text_input = gr.Textbox()
            api_audio_input = gr.Textbox()
            api_exaggeration_input = gr.Slider(minimum=0.25, maximum=1.0, value=0.6)
            api_cfg_pace_input = gr.Slider(minimum=0.2, maximum=1.0, value=0.3)
            api_seed_input = gr.Number(value=0, precision=0)
            api_temperature_input = gr.Slider(minimum=0.05, maximum=2.0, value=0.6)
            api_audio_output = gr.Audio(type="numpy")
            api_status_output = gr.Textbox()
            api_btn = gr.Button()

        api_btn.click(
            fn=clone_voice_base64_api,
            inputs=[api_text_input, api_audio_input, api_exaggeration_input, api_cfg_pace_input, api_seed_input, api_temperature_input],
            outputs=[api_audio_output, api_status_output],
            api_name="clone_voice"
        )

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        quiet=False,
        share=False
    )

if __name__ == "__main__":
    main()
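
# Example client-side usage (a minimal sketch; the Space name below is a placeholder,
# not the real deployment): calling the hidden "clone_voice" endpoint with gradio_client.
#   from gradio_client import Client
#   client = Client("your-username/your-space-name")
#   audio, status = client.predict(
#       "Hello, this is my cloned voice.",        # text_to_speak
#       "https://example.com/reference.wav",      # reference audio: URL or base64 data URI
#       0.6, 0.3, 0, 0.6,                         # exaggeration, cfg_pace, random_seed, temperature
#       api_name="/clone_voice",
#   )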