import os
import json
import random
import re

import torch
import numpy as np
import gradio as gr
from chatterbox.tts import ChatterboxTTS
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from torch import nn
# === Settings ===
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_REPO = "SebastianBodza/Kartoffelbox-v0.1"
T3_CHECKPOINT_FILE = "t3_kartoffelbox.safetensors"
MAX_CHARS = 5000
CHUNK_CHAR_LIMIT = 300
SETTINGS_DIR = "settings"

# === Init ===
os.makedirs(SETTINGS_DIR, exist_ok=True)

MODEL = None
print(f"🚀 Running on device: {DEVICE}")
def get_or_load_model():
    """Load the Chatterbox TTS model once and cache it in the global MODEL."""
    global MODEL
    if MODEL is None:
        print("Model not loaded, initializing...")
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)

        # Download the fine-tuned T3 weights and load them into the base model
        checkpoint_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=T3_CHECKPOINT_FILE,
            token=os.environ.get("HUGGING_FACE_HUB_TOKEN") or None,
        )
        t3_state = load_file(checkpoint_path, device="cpu")
        MODEL.t3.load_state_dict(t3_state)

        # Expand the position embeddings so longer inputs fit
        pos_emb_module = MODEL.t3.text_pos_emb
        old_pos = pos_emb_module.emb.num_embeddings
        if MAX_CHARS > old_pos:
            emb_dim = pos_emb_module.emb.embedding_dim
            new_emb = nn.Embedding(MAX_CHARS, emb_dim)
            with torch.no_grad():
                new_emb.weight[:old_pos] = pos_emb_module.emb.weight
            pos_emb_module.emb = new_emb
            print(f"Expanded position embeddings: {old_pos} → {MAX_CHARS}")

        MODEL.t3.to(DEVICE)
        MODEL.s3gen.to(DEVICE)
        print(f"Model loaded. Device: {MODEL.device}")
    return MODEL
# Preload the model at startup so the first request does not pay the loading cost
try:
    get_or_load_model()
except Exception as e:
    print(f"CRITICAL: Failed to load model: {e}")
def set_seed(seed: int):
    torch.manual_seed(seed)
    if DEVICE == "cuda":
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
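# Note on set_seed: fixing the torch/numpy/random seeds makes sampling repeatable
# for a given seed value, though some CUDA kernels can still introduce slight
# non-determinism unless torch.use_deterministic_algorithms(True) is also enabled.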
def split_text_into_chunks(text, max_length=CHUNK_CHAR_LIMIT):
    """Split text at sentence boundaries into chunks of roughly max_length characters."""
    sentences = re.split(r'(?<=[.!?]) +', text)
    chunks = []
    chunk = ""
    for sentence in sentences:
        if len(chunk) + len(sentence) < max_length:
            chunk += " " + sentence
        else:
            if chunk:
                chunks.append(chunk.strip())
            chunk = sentence
    if chunk:
        chunks.append(chunk.strip())
    return chunks
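# Rough illustration of the chunking behaviour: with the default limit of 300
# characters, a long paragraph is cut only at sentence boundaries, e.g.
# "Satz eins. Satz zwei. …" becomes ["Satz eins. Satz zwei.", …]; a single
# sentence longer than the limit is kept as its own oversized chunk.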
# === Save/load settings ===
def list_presets():
    return [f[:-5] for f in os.listdir(SETTINGS_DIR) if f.endswith(".json") and f != "last.json"]


def load_preset(name):
    path = os.path.join(SETTINGS_DIR, name + ".json")
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    return None


def save_preset(name, data):
    path = os.path.join(SETTINGS_DIR, name + ".json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
    if name != "last":
        save_preset("last", data)  # also remember as the most recently used profile
def generate_tts_audio(text_input, audio_prompt_path_input, exaggeration_input, temperature_input, seed_num_input, cfgw_input):
    model = get_or_load_model()
    if seed_num_input != 0:
        set_seed(int(seed_num_input))

    full_audio = []
    chunks = split_text_into_chunks(text_input[:MAX_CHARS])
    print(f"Splitting text into {len(chunks)} chunks…")
    for i, chunk in enumerate(chunks):
        print(f"▶️ Chunk {i+1}/{len(chunks)}: {chunk[:60]}...")
        wav = model.generate(
            chunk,
            audio_prompt_path=audio_prompt_path_input,
            exaggeration=exaggeration_input,
            temperature=temperature_input,
            cfg_weight=cfgw_input,
        )
        full_audio.append(wav.squeeze(0).cpu().numpy())
    audio_concat = np.concatenate(full_audio)
    return (model.sr, audio_concat)
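# gr.Audio accepts a (sample_rate, numpy_array) tuple as an output value, so the
# concatenated waveform can be returned directly without writing a file.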
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# 🥔 Kartoffel-TTS (Chatterbox)\nLangtext → Sprachstil mit Profilen")
    with gr.Row():
        with gr.Column():
            preset_dropdown = gr.Dropdown(label="🔄 Preset wählen", choices=list_presets(), value=None)
            preset_name = gr.Textbox(label="📝 Name zum Speichern", value="mein-profil")
            text = gr.Textbox(
                value="Hier kannst du einen längeren deutschen Text eingeben…",
                label=f"Text (max {MAX_CHARS} Zeichen)",
                max_lines=12,
            )
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Referenz-Audiodatei (optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac",
            )
            exaggeration = gr.Slider(0.25, 2, step=0.05, label="Exaggeration", value=0.5)
            cfg_weight = gr.Slider(0.2, 1, step=0.05, label="CFG/Pace", value=0.3)

            with gr.Accordion("Weitere Optionen", open=False):
                seed_num = gr.Number(value=0, label="Zufalls-Seed (0 = zufällig)")
                temp = gr.Slider(0.05, 5, step=0.05, label="Temperature", value=0.6)

            save_btn = gr.Button("💾 Einstellungen speichern")
            run_btn = gr.Button("🎤 Audio generieren")

        with gr.Column():
            audio_output = gr.Audio(label="🔊 Ergebnis")
    # Wire up the UI callbacks
    def on_preset_selected(name):
        if name:
            p = load_preset(name)
            if p:
                return p["exaggeration"], p["temperature"], p["seed"], p["cfg"]
        return gr.update(), gr.update(), gr.update(), gr.update()

    preset_dropdown.change(
        on_preset_selected,
        inputs=[preset_dropdown],
        outputs=[exaggeration, temp, seed_num, cfg_weight],
    )
    def save_current_settings(name, exaggeration, temperature, seed, cfg):
        save_preset(name, {
            "exaggeration": exaggeration,
            "temperature": temperature,
            "seed": seed,
            "cfg": cfg,
        })
        return gr.update(choices=list_presets())

    save_btn.click(
        fn=save_current_settings,
        inputs=[preset_name, exaggeration, temp, seed_num, cfg_weight],
        outputs=[preset_dropdown],
    )

    run_btn.click(
        fn=generate_tts_audio,
        inputs=[text, ref_wav, exaggeration, temp, seed_num, cfg_weight],
        outputs=[audio_output],
    )
    # Load the last used profile at startup
    if os.path.exists(os.path.join(SETTINGS_DIR, "last.json")):
        last = load_preset("last")
        if last:
            exaggeration.value = last["exaggeration"]
            temp.value = last["temperature"]
            seed_num.value = last["seed"]
            cfg_weight.value = last["cfg"]

# 👇 Robust launch – important for an .exe build without a console!
demo.launch(
    quiet=True,
    show_error=True,
    prevent_thread_lock=False,
)
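# By default Gradio serves the UI on http://127.0.0.1:7860; pass server_name and
# server_port to launch() if a different bind address or port is needed.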