import gradio as gr
import os
import traceback
import torch
import gc
from huggingface_hub import hf_hub_download
import shutil
import spaces
try:
    from config import MODEL_REPO_ID, MODEL_FILES, LOCAL_MODEL_PATH
except ImportError:
    MODEL_REPO_ID = "ramimu/chatterbox-voice-cloning-model"
    LOCAL_MODEL_PATH = "./chatterbox_model_files"
    MODEL_FILES = ["s3gen.pt", "t3_cfg.pt", "ve.pt", "tokenizer.json"]
try:
    from chatterbox.tts import ChatterboxTTS
    chatterbox_available = True
    print("Chatterbox TTS imported successfully")
except ImportError as e:
    print(f"Failed to import ChatterboxTTS: {e}")
    chatterbox_available = False
# Global model variable - will be loaded inside GPU function
model = None
model_loaded = False
# Text length limits for the model
MAX_CHARS_PER_GENERATION = 1000 # Safe limit for single generation
MAX_CHARS_TOTAL = 5000 # Maximum we'll accept via API
def download_model_files():
    """Download model files with error handling."""
    print(f"Checking for model files in {LOCAL_MODEL_PATH}...")
    os.makedirs(LOCAL_MODEL_PATH, exist_ok=True)
    for filename in MODEL_FILES:
        local_path = os.path.join(LOCAL_MODEL_PATH, filename)
        if not os.path.exists(local_path):
            print(f"Downloading {filename} from {MODEL_REPO_ID}...")
            try:
                downloaded_path = hf_hub_download(
                    repo_id=MODEL_REPO_ID,
                    filename=filename,
                    cache_dir="./cache",
                    force_download=False
                )
                shutil.copy2(downloaded_path, local_path)
                print(f"✓ Downloaded and copied {filename}")
            except Exception as e:
                print(f"✗ Failed to download {filename}: {e}")
                raise e
        else:
            print(f"✓ {filename} already exists locally")
    print("All model files are ready!")
def load_model_on_gpu():
    """Load model inside GPU context - only called within @spaces.GPU decorated function."""
    global model, model_loaded
    if model_loaded and model is not None:
        return True
    if not chatterbox_available:
        print("ERROR: Chatterbox TTS library not available")
        return False
    try:
        print("Loading model inside GPU context...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading model on device: {device}")
        # Try different loading methods
        try:
            model = ChatterboxTTS.from_local(LOCAL_MODEL_PATH, device)
            print("✓ Model loaded successfully using from_local method.")
        except Exception as e1:
            print(f"from_local failed: {e1}")
            try:
                model = ChatterboxTTS.from_pretrained(device)
                print("✓ Model loaded successfully with from_pretrained.")
            except Exception as e2:
                print(f"from_pretrained failed: {e2}")
                model = load_model_manually(device)
        if model and hasattr(model, 'to'):
            model = model.to(device)
        if model and hasattr(model, 'eval'):
            model.eval()
        model_loaded = True
        print("✓ Model loaded successfully in GPU context")
        return True
    except Exception as e:
        print(f"ERROR: Failed to load model in GPU context: {e}")
        traceback.print_exc()
        model = None
        model_loaded = False
        return False
def load_model_manually(device):
    """Manual model loading with proper error handling."""
    import pathlib
    import json
    model_path = pathlib.Path(LOCAL_MODEL_PATH)
    print("Manual loading with correct constructor signature...")
    s3gen_path = model_path / "s3gen.pt"
    ve_path = model_path / "ve.pt"
    tokenizer_path = model_path / "tokenizer.json"
    t3_cfg_path = model_path / "t3_cfg.pt"
    s3gen = torch.load(s3gen_path, map_location='cpu')
    ve = torch.load(ve_path, map_location='cpu')
    t3_cfg = torch.load(t3_cfg_path, map_location='cpu')
    with open(tokenizer_path, 'r') as f:
        tokenizer_data = json.load(f)
    try:
        from chatterbox.models.tokenizers.tokenizer import EnTokenizer
        tokenizer = EnTokenizer.from_dict(tokenizer_data)
    except Exception:
        tokenizer = tokenizer_data
    model = ChatterboxTTS(
        t3=t3_cfg,
        s3gen=s3gen,
        ve=ve,
        tokenizer=tokenizer,
        device=device
    )
    print("✓ Model loaded successfully with manual constructor.")
    return model
def cleanup_gpu_memory():
    """Clean up GPU memory - only call within GPU context."""
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()
    except Exception as e:
        print(f"Warning: GPU cleanup failed: {e}")
def truncate_text_safely(text, max_chars=MAX_CHARS_PER_GENERATION):
    """Truncate text to safe length while preserving sentence boundaries."""
    if len(text) <= max_chars:
        return text, False
    # Find the last sentence ending before the limit
    truncated = text[:max_chars]
    # Look for sentence endings
    for ending in ['. ', '! ', '? ']:
        last_sentence = truncated.rfind(ending)
        if last_sentence > max_chars * 0.7:  # Don't truncate too aggressively
            return text[:last_sentence + 1].strip(), True
    # Fallback to word boundary
    last_space = truncated.rfind(' ')
    if last_space > max_chars * 0.8:
        return text[:last_space].strip(), True
    # Last resort: hard truncate
    return truncated.strip(), True
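# Illustrative helper (not wired into the Gradio app below): a minimal sketch
# of how text longer than MAX_CHARS_PER_GENERATION could be split into
# sentence-aligned chunks for the "chunked generation" approach that the
# status messages refer to. The function name is hypothetical, not part of
# the Chatterbox API.
def split_text_into_chunks(text, max_chars=MAX_CHARS_PER_GENERATION):
    """Split text into sentence-aligned chunks no longer than max_chars."""
    chunks = []
    remaining = text.strip()
    while remaining:
        chunk, _ = truncate_text_safely(remaining, max_chars)
        if not chunk:  # Safety guard for pathological inputs
            chunk = remaining[:max_chars]
        chunks.append(chunk)
        remaining = remaining[len(chunk):].lstrip()
    return chunks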
# Download model files during startup (CPU only)
if chatterbox_available:
    try:
        download_model_files()
        print("Model files downloaded. Model will be loaded on first GPU request.")
    except Exception as e:
        print(f"ERROR during model file download: {e}")
@spaces.GPU
def clone_voice(text_to_speak, reference_audio_path, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """Main voice cloning function - runs on GPU."""
    global model, model_loaded
    # Input validation
    if not chatterbox_available:
        return None, "Error: Chatterbox TTS library not available. Please check installation."
    if not text_to_speak or text_to_speak.strip() == "":
        return None, "Error: Please enter some text to speak."
    if reference_audio_path is None:
        return None, "Error: Please upload a reference audio file (.wav or .mp3)."
    # Check text length and truncate if necessary
    original_length = len(text_to_speak)
    if original_length > MAX_CHARS_TOTAL:
        return None, f"Error: Text is too long ({original_length:,} characters). Maximum allowed is {MAX_CHARS_TOTAL:,} characters. Please use the chunked generation API for longer texts."
    # Truncate to safe generation length
    text_to_use, was_truncated = truncate_text_safely(text_to_speak, MAX_CHARS_PER_GENERATION)
    try:
        # Load model if not already loaded
        if not model_loaded:
            print("Loading model for the first time...")
            if not load_model_on_gpu():
                return None, "Error: Failed to load model. Please check the logs for details."
        if model is None:
            return None, "Error: Model not loaded. Please check the logs for details."
        print("Processing request:")
        print(f"  Original text length: {original_length:,} characters")
        print(f"  Processing length: {len(text_to_use):,} characters")
        print(f"  Truncated: {was_truncated}")
        print(f"  Audio: '{reference_audio_path}'")
        print(f"  Parameters: exag={exaggeration}, cfg={cfg_pace}, seed={random_seed}, temp={temperature}")
        # Clean GPU memory before generation
        cleanup_gpu_memory()
        # Set random seed if specified
        if random_seed > 0:
            torch.manual_seed(random_seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(random_seed)
        # Check CUDA availability and memory
        if torch.cuda.is_available():
            print(f"CUDA memory before generation: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")
        # Generate audio with error handling
        try:
            with torch.no_grad():
                output_wav_data = model.generate(
                    text=text_to_use,
                    audio_prompt_path=reference_audio_path,
                    exaggeration=exaggeration,
                    cfg_weight=cfg_pace,
                    temperature=temperature
                )
        except RuntimeError as e:
            if "CUDA" in str(e) or "out of memory" in str(e) or "device-side assert" in str(e):
                print(f"CUDA error during generation: {e}")
                cleanup_gpu_memory()
                return None, f"CUDA error: Text may be too long for single generation. Try shorter text (under {MAX_CHARS_PER_GENERATION} characters) or use the chunked generation API for longer content."
            else:
                raise
        # Get sample rate
        try:
            sample_rate = model.sr
        except AttributeError:
            sample_rate = 24000
        # Process output
        if isinstance(output_wav_data, str):
            result = output_wav_data
        else:
            import numpy as np
            if hasattr(output_wav_data, 'cpu'):
                output_wav_data = output_wav_data.cpu().numpy()
            if output_wav_data.ndim > 1:
                output_wav_data = output_wav_data.squeeze()
            result = (sample_rate, output_wav_data)
        # Clean up GPU memory after generation
        cleanup_gpu_memory()
        if torch.cuda.is_available():
            print(f"CUDA memory after generation: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")
        print("✓ Audio generated successfully")
        # Prepare success message
        success_msg = "Success: Audio generated successfully!"
        if was_truncated:
            success_msg += f" Note: Text was truncated from {original_length:,} to {len(text_to_use):,} characters for optimal generation. Use the chunked generation API for longer texts."
        return result, success_msg
    except Exception as e:
        print(f"ERROR during audio generation: {e}")
        traceback.print_exc()
        # Clean up on error
        try:
            cleanup_gpu_memory()
        except Exception:
            pass
        # Provide specific error messages
        error_msg = str(e)
        if "CUDA" in error_msg or "device-side assert" in error_msg:
            return None, f"CUDA error: {error_msg}. Try shorter text (under {MAX_CHARS_PER_GENERATION} characters) or use the chunked generation API."
        elif "out of memory" in error_msg:
            return None, f"GPU memory error: {error_msg}. Please try with shorter text."
        else:
            return None, f"Error during audio generation: {error_msg}. Check logs for more details."
def clone_voice_api(text_to_speak, reference_audio_url, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
    """API wrapper function."""
    import requests
    import tempfile
    import os
    import base64
    temp_audio_path = None
    try:
        # Handle different audio input formats
        if reference_audio_url.startswith('data:audio'):
            header, encoded = reference_audio_url.split(',', 1)
            audio_data = base64.b64decode(encoded)
            ext = '.mp3' if 'mp3' in header else '.wav'
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
                temp_file.write(audio_data)
                temp_audio_path = temp_file.name
        elif reference_audio_url.startswith('http'):
            response = requests.get(reference_audio_url, timeout=30)
            response.raise_for_status()
            ext = '.mp3' if reference_audio_url.endswith('.mp3') else '.wav'
            with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as temp_file:
                temp_file.write(response.content)
                temp_audio_path = temp_file.name
        else:
            temp_audio_path = reference_audio_url
        # Call the GPU function
        audio_output, status = clone_voice(text_to_speak, temp_audio_path, exaggeration, cfg_pace, random_seed, temperature)
        return audio_output, status
    except Exception as e:
        print(f"API Error: {e}")
        return None, f"API Error: {str(e)}"
    finally:
        # Clean up temporary file
        if temp_audio_path and temp_audio_path != reference_audio_url:
            try:
                os.unlink(temp_audio_path)
            except OSError:
                pass
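# Illustrative end-to-end sketch (not used by the UI): drive chunked
# generation client-side by synthesizing each chunk with clone_voice and
# concatenating the waveforms. The name clone_voice_chunked is hypothetical;
# it relies on the split_text_into_chunks helper sketched above and assumes
# each successful call returns a (sample_rate, numpy_array) tuple.
def clone_voice_chunked(text_to_speak, reference_audio_path, **kwargs):
    import numpy as np
    pieces = []
    sample_rate = 24000
    for chunk in split_text_into_chunks(text_to_speak):
        audio, status = clone_voice(chunk, reference_audio_path, **kwargs)
        if audio is None or isinstance(audio, str):
            return audio, status  # Propagate errors or file-path outputs as-is
        sample_rate, wav = audio
        pieces.append(wav)
    if not pieces:
        return None, "Error: No text chunks to generate."
    return (sample_rate, np.concatenate(pieces)), "Success: chunked audio generated."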
def main():
    print("Starting Advanced Gradio interface...")
    with gr.Blocks(title="🎙️ Advanced Chatterbox Voice Cloning") as demo:
        gr.Markdown("# 🎙️ Advanced Chatterbox Voice Cloning")
        gr.Markdown("Clone any voice using advanced AI technology with fine-tuned controls.")
        # Add warning about text length
        gr.Markdown(f"""
        **⚠️ Text Length Limits:**
        - **Single Generation**: Up to {MAX_CHARS_PER_GENERATION:,} characters (optimal quality)
        - **API Maximum**: Up to {MAX_CHARS_TOTAL:,} characters (may be truncated)
        - **For longer texts**: Use the chunked generation API in your application
        """)
        with gr.Row():
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label=f"Text to Speak (max {MAX_CHARS_TOTAL:,} characters)",
                    placeholder="Enter the text you want the cloned voice to say...",
                    lines=5,
                    max_lines=10
                )
                audio_input = gr.Audio(
                    type="filepath",
                    label="Reference Audio (Upload a short .wav or .mp3 clip)",
                    sources=["upload", "microphone"]
                )
                with gr.Accordion("🔧 Advanced Settings", open=False):
                    with gr.Row():
                        exaggeration_input = gr.Slider(
                            minimum=0.25, maximum=1.0, value=0.6, step=0.05,
                            label="Exaggeration", info="Controls voice characteristic emphasis"
                        )
                        cfg_pace_input = gr.Slider(
                            minimum=0.2, maximum=1.0, value=0.3, step=0.05,
                            label="CFG/Pace", info="Classifier-free guidance weight"
                        )
                    with gr.Row():
                        seed_input = gr.Number(
                            value=0, label="Random Seed", info="Set to 0 for random results", precision=0
                        )
                        temperature_input = gr.Slider(
                            minimum=0.05, maximum=2.0, value=0.6, step=0.05,
                            label="Temperature", info="Controls randomness in generation"
                        )
                generate_btn = gr.Button("🎵 Generate Voice Clone", variant="primary", size="lg")
            with gr.Column(scale=1):
                audio_output = gr.Audio(label="Generated Audio", type="numpy")
                status_output = gr.Textbox(label="Status", lines=3)
        # Connect the interface
        generate_btn.click(
            fn=clone_voice_api,
            inputs=[text_input, audio_input, exaggeration_input, cfg_pace_input, seed_input, temperature_input],
            outputs=[audio_output, status_output],
            api_name="predict"
        )
        # API endpoint for external calls
        def clone_voice_base64_api(text_to_speak, reference_audio_b64, exaggeration=0.6, cfg_pace=0.3, random_seed=0, temperature=0.6):
            return clone_voice_api(text_to_speak, reference_audio_b64, exaggeration, cfg_pace, random_seed, temperature)
        # Hidden API interface
        with gr.Row(visible=False):
            api_text_input = gr.Textbox()
            api_audio_input = gr.Textbox()
            api_exaggeration_input = gr.Slider(minimum=0.25, maximum=1.0, value=0.6)
            api_cfg_pace_input = gr.Slider(minimum=0.2, maximum=1.0, value=0.3)
            api_seed_input = gr.Number(value=0, precision=0)
            api_temperature_input = gr.Slider(minimum=0.05, maximum=2.0, value=0.6)
            api_audio_output = gr.Audio(type="numpy")
            api_status_output = gr.Textbox()
            api_btn = gr.Button()
        api_btn.click(
            fn=clone_voice_base64_api,
            inputs=[api_text_input, api_audio_input, api_exaggeration_input, api_cfg_pace_input, api_seed_input, api_temperature_input],
            outputs=[api_audio_output, api_status_output],
            api_name="clone_voice"
        )
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
        quiet=False,
        share=False
    )
if __name__ == "__main__":
    main()
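# Example client call against the hidden /clone_voice endpoint (illustrative;
# assumes the `gradio_client` package is installed and the Space name is
# adjusted to wherever this app is actually deployed):
#
#   from gradio_client import Client
#   client = Client("ramimu/voice_cloning")
#   audio, status = client.predict(
#       "Hello from a cloned voice.",            # text_to_speak
#       "https://example.com/reference.wav",     # reference audio URL or base64 data URI
#       0.6, 0.3, 0, 0.6,                        # exaggeration, cfg/pace, seed, temperature
#       api_name="/clone_voice",
#   )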