import random
from typing import Optional
import numpy as np
import torch
from chatterbox.src.chatterbox.tts import ChatterboxTTS
import gradio as gr
import spaces

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {DEVICE}")

# Maximum number of input characters sent to the model. The UI label and the
# truncation in generate_tts_audio both use this value.
MAX_CHARS = 300

# --- Global Model Initialization ---
MODEL = None


def get_or_load_model():
    """Load the ChatterboxTTS model once and return the cached instance.

    The model is created with ``ChatterboxTTS.from_pretrained(DEVICE)``. If it
    has a ``to`` method and its ``device`` attribute (when present) differs
    from ``DEVICE``, it is moved to ``DEVICE``. The module-level ``MODEL``
    cache is populated only after loading and any device move succeed. A
    failed attempt therefore leaves ``MODEL`` as ``None`` and is retried on
    the next call, instead of caching a half-initialized model.

    Returns:
        The cached ChatterboxTTS instance.

    Raises:
        Exception: Any error from ``from_pretrained`` or ``to``. It is
        printed and then re-raised unchanged.
    """
    global MODEL
    if MODEL is None:
        print("Model not loaded, initializing...")
        try:
            model = ChatterboxTTS.from_pretrained(DEVICE)
            # getattr: a model exposing `to` without a `device` attribute is
            # moved rather than crashing with AttributeError.
            if hasattr(model, "to") and str(getattr(model, "device", None)) != DEVICE:
                moved = model.to(DEVICE)
                # Support both in-place `.to()` (returns self/None) and a
                # `.to()` that returns a new object.
                if moved is not None:
                    model = moved
            MODEL = model
            print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    return MODEL


# Attempt to load the model at startup.
try:
    get_or_load_model()
except Exception as e:
    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")


def set_seed(seed: int):
    """Seed torch, numpy and Python's ``random`` for reproducibility.

    ``np.random.seed`` only accepts values in ``[0, 2**32)``. The seed is
    therefore reduced modulo ``2**32`` for numpy. Seeds already in that range
    are unchanged, and negative or very large values no longer raise.
    torch and ``random`` receive the seed as given.
    """
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed % 2**32)


@spaces.GPU
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: Optional[str] = None,
    exaggeration_input: float = 0.5,
    temperature_input: float = 0.8,
    seed_num_input: int = 0,
    cfgw_input: float = 0.5,
) -> tuple[int, np.ndarray]:
    """
    Generate high-quality speech audio from text using ChatterboxTTS model with optional reference audio styling.

    This tool synthesizes natural-sounding speech from input text. When a reference audio file
    is provided, it captures the speaker's voice characteristics and speaking style. The generated audio
    maintains the prosody, tone, and vocal qualities of the reference speaker, or uses default voice if no reference is provided.

    Args:
        text_input (str): The text to synthesize into speech (maximum 300 characters; longer input is truncated)
        audio_prompt_path_input (str, optional): File path or URL to the reference audio file that defines the target voice style. Defaults to None.
        exaggeration_input (float, optional): Controls speech expressiveness (0.25-2.0, neutral=0.5, extreme values may be unstable). Defaults to 0.5.
        temperature_input (float, optional): Controls randomness in generation (0.05-5.0, higher=more varied). Defaults to 0.8.
        seed_num_input (int, optional): Random seed for reproducible results (0 or empty for random generation). Defaults to 0.
        cfgw_input (float, optional): CFG/Pace weight controlling generation guidance (0.2-1.0). Defaults to 0.5.

    Returns:
        tuple[int, np.ndarray]: A tuple containing the sample rate (int) and the generated audio waveform (numpy.ndarray)
    """
    current_model = get_or_load_model()

    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")

    # gr.Number yields None when the field is cleared; treat that like 0 (random).
    if seed_num_input is not None and seed_num_input != 0:
        set_seed(int(seed_num_input))

    text = text_input or ""
    if len(text) > MAX_CHARS:
        print(f"Input text is {len(text)} characters; truncating to {MAX_CHARS}.")
        text = text[:MAX_CHARS]

    print(f"Generating audio for text: '{text[:50]}...'")

    # Handle optional audio prompt
    generate_kwargs = {
        "exaggeration": exaggeration_input,
        "temperature": temperature_input,
        "cfg_weight": cfgw_input,
    }

    if audio_prompt_path_input:
        generate_kwargs["audio_prompt_path"] = audio_prompt_path_input

    wav = current_model.generate(text, **generate_kwargs)
    print("Audio generation complete.")
    # .detach().cpu() is a no-op for a CPU tensor, but keeps .numpy() from
    # failing if generate() ever returns a tensor that lives on the GPU.
    return (current_model.sr, wav.squeeze(0).detach().cpu().numpy())


with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Chatterbox TTS Demo
        Generate high-quality speech from text with reference audio styling.
        """
    )
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                value=(
                    "Now let's make my mum's favourite. So three mars bars into the pan. "
                    "Then we add the tuna and just stir for a bit, just let the chocolate "
                    "and fish infuse. A sprinkle of olive oil and some tomato ketchup. "
                    "Now smell that. Oh boy this is going to be incredible."
                ),
                label=f"Text to synthesize (max chars {MAX_CHARS})",
                max_lines=5,
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
            )
            exaggeration = gr.Slider(
                0.25, 2, step=.05,
                label="Exaggeration (Neutral = 0.5, extreme values can be unstable)",
                value=.5,
            )
            cfg_weight = gr.Slider(
                0.2, 1, step=.05, label="CFG/Pace", value=0.5
            )

            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)

            run_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    run_btn.click(
        fn=generate_tts_audio,
        inputs=[
            text,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
        ],
        outputs=[audio_output],
    )

demo.launch(mcp_server=True)