import os
import random
import torch
import spaces
import numpy as np
import gradio as gr
from chatterbox.src.chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_REPO = "SebastianBodza/Kartoffelbox-v0.1"
T3_CHECKPOINT_FILE = "t3_kartoffelbox.safetensors"
print(f"🚀 Running on device: {DEVICE}")

# --- Global Model Initialization ---
MODEL = None

def get_or_load_model():
    """Loads the ChatterboxTTS model if it hasn't been loaded already,
    and ensures it's on the correct device."""
    global MODEL
    if MODEL is None:
        print("Model not loaded, initializing...")
        try:
            MODEL = ChatterboxTTS.from_pretrained(DEVICE)
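            # Download the Kartoffelbox T3 checkpoint and overwrite the base
            # Chatterbox T3 weights with it.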
            # Token is optional; hf_hub_download falls back to cached or anonymous access if unset.
            checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=T3_CHECKPOINT_FILE, token=os.environ.get("HUGGING_FACE_HUB_TOKEN"))
            t3_state = load_file(checkpoint_path, device="cpu")
            MODEL.t3.load_state_dict(t3_state)

            if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
                MODEL.to(DEVICE)
            print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
        except Exception as e:
            print(f"Error loading model: {e}")
            raise
    return MODEL

# Attempt to load the model at startup.
try:
    get_or_load_model()
except Exception as e:
    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")

def set_seed(seed: int):
    """Sets the random seed for reproducibility across torch, numpy, and random."""
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)

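# On Hugging Face ZeroGPU Spaces, @spaces.GPU allocates a GPU for the duration
# of each call to the decorated function.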
@spaces.GPU
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float
) -> tuple[int, np.ndarray]:
    """
    Generates TTS audio using the ChatterboxTTS model.

    Args:
        text_input: The text to synthesize (max 300 characters).
        audio_prompt_path_input: Path to the reference audio file.
        exaggeration_input: Exaggeration parameter for the model.
        temperature_input: Temperature parameter for the model.
        seed_num_input: Random seed (0 for random).
        cfgw_input: CFG/Pace weight.

    Returns:
        A tuple containing the sample rate (int) and the audio waveform (numpy.ndarray).
    """
    current_model = get_or_load_model()

    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")

    if seed_num_input != 0:
        set_seed(int(seed_num_input))

    print(f"Generating audio for text: '{text_input[:50]}...'")
    wav = current_model.generate(
        text_input[:300],  # Truncate text to max chars
        audio_prompt_path=audio_prompt_path_input,
        exaggeration=exaggeration_input,
        temperature=temperature_input,
        cfg_weight=cfgw_input,
    )
    print("Audio generation complete.")
    return (current_model.sr, wav.squeeze(0).numpy())

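# --- Gradio UI ---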
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Kartoffel-TTS (Based on Chatterbox) - German Text-to-Speech Demo
        Generate high-quality speech from text with reference audio styling.
        """
    )
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                value="Tief im verwunschenen Wald, wo die Bäume uralte Geheimnisse flüsterten, lebte ein kleiner Gnom namens Fips, der die Sprache der Tiere verstand.",
                label="Text to synthesize (max chars 300)",
                max_lines=5
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
            )
            exaggeration = gr.Slider(
                0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5
            )
            cfg_weight = gr.Slider(
                0.2, 1, step=.05, label="CFG/Pace", value=0.3
            )

            with gr.Accordion("More options", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.6)

            run_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")

    run_btn.click(
        fn=generate_tts_audio,
        inputs=[
            text,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
        ],
        outputs=[audio_output],
    )

demo.launch()