# gen-k8x8mz2l / app.py
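"""Gradio Space for Maya1 text-to-speech on ZeroGPU.

Generates speech with maya-research/maya1 (a causal LM that emits SNAC
audio tokens) and decodes those tokens to 24 kHz audio with the SNAC
decoder. Models are lazy-loaded so the Space imports quickly.
"""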
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from snac import SNAC
import soundfile as sf
import tempfile
import spaces
# --- global handles (lazy-loaded) ---
model = None
tokenizer = None
snac_model = None
def load_models(device: str):
    """Load Maya1 and SNAC once, with a device-aware dtype."""
    global model, tokenizer, snac_model
    if tokenizer is None or model is None:
        dtype = torch.bfloat16 if device == "cuda" else torch.float32
        print(f"[load_models] loading Maya1 (dtype={dtype}, device={device})")
        # device_map only on CUDA; on CPU keep None to avoid accelerate errors
        device_map = "auto" if device == "cuda" else None
        model = AutoModelForCausalLM.from_pretrained(
            "maya-research/maya1",
            torch_dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            "maya-research/maya1",
            trust_remote_code=True,
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    if snac_model is None:
        print("[load_models] loading SNAC 24kHz decoder")
        # keep on CPU for now; moved to the active device inside the
        # handler, after the ZeroGPU allocation has happened
        snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
@spaces.GPU(duration=180)
def generate_speech(text, voice_description, temperature, top_p, max_tokens):
    if not text.strip():
        raise gr.Error("Enter some text.")
    if not voice_description.strip():
        voice_description = "Realistic voice with neutral tone and conversational pacing."

    # ZeroGPU attaches a CUDA device for the duration of this call
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # lazy-load models on first call
    load_models(device)

    # move models to the active device (the ZeroGPU allocation has happened)
    if device == "cuda":
        model.to(device)
        snac_model.to(device)
    # prompt format exactly as on the model card: <description="..."> text
    prompt = f'<description="{voice_description}"> {text}'
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=None,  # None falls back to the model's generation-config EOS
            repetition_penalty=1.1,
        )
    # SNAC token extraction (7-token frames), per the model card
    generated_ids = outputs[0, inputs["input_ids"].shape[1]:]
    snac_tokens = [t.item() for t in generated_ids if 128266 <= t <= 156937]
    frames = len(snac_tokens) // 7
    if frames == 0:
        raise gr.Error("No SNAC tokens generated. Try longer text and max_tokens=1200–1500.")
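    # Each 7-token frame packs one SNAC frame across three codebooks:
    # slot 0 -> codebook 1 (coarse), slots 1 and 4 -> codebook 2,
    # slots 2, 3, 5, 6 -> codebook 3 (fine). Subtracting the 128266
    # base offset and taking mod 4096 recovers the raw codebook indices.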
    codes = [[], [], []]
    for i in range(frames):
        s = snac_tokens[i * 7:(i + 1) * 7]
        codes[0].append((s[0] - 128266) % 4096)
        codes[1].extend([(s[1] - 128266) % 4096, (s[4] - 128266) % 4096])
        codes[2].extend([
            (s[2] - 128266) % 4096,
            (s[3] - 128266) % 4096,
            (s[5] - 128266) % 4096,
            (s[6] - 128266) % 4096,
        ])
    codes_tensor = [torch.tensor(c, dtype=torch.long, device=device).unsqueeze(0) for c in codes]

    with torch.inference_mode():
        audio = snac_model.decoder(snac_model.quantizer.from_codes(codes_tensor))[0, 0].cpu().numpy()

    # write to a wav file; return the path for gr.Audio(type="filepath")
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
        sf.write(f.name, audio, 24000)
        return f.name
# ------------------- UI -------------------
voice_presets = {
    "Male - American": "Realistic male voice in the 30s age with american accent. Normal pitch, warm timbre, conversational pacing.",
    "Female - British": "Clear female voice in the 20s age with British accent. Pleasant tone, articulate delivery, moderate pacing.",
    "Male - Deep": "Deep male voice with authoritative tone. Low pitch, resonant timbre, steady pacing.",
    "Female - Energetic": "Energetic female voice with enthusiastic tone. Higher pitch, bright timbre, upbeat pacing.",
    "Neutral - Professional": "Professional neutral voice with clear articulation. Balanced pitch, crisp tone, measured pacing.",
    "Custom": "",
}
def update_voice_description(preset):
    return voice_presets.get(preset, "")
with gr.Blocks(theme=gr.themes.Soft(), title="Maya1 Text-to-Speech") as demo:
    gr.HTML("""
    <div style="text-align:center;padding:16px">
        <h1>🎙️ Maya1 Text-to-Speech</h1>
        <p style="color:#666">Generate emotional, realistic speech with natural-language voice design</p>
        <p style="font-size:12px;color:#28a745">⚡ ZeroGPU inference</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Text to Speak",
                value="Hello! This is Maya1 <laugh> the best open source voice AI model with emotions.",
                lines=5,
            )
            voice_preset = gr.Dropdown(
                choices=list(voice_presets.keys()),
                value="Male - American",
                label="Voice Preset",
            )
            voice_description = gr.Textbox(
                label="Voice Description",
                value=voice_presets["Male - American"],
                lines=3,
            )
            with gr.Accordion("Advanced", open=False):
                temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
                top_p = gr.Slider(0.5, 1.0, value=0.9, step=0.05, label="Top-p")
                max_tokens = gr.Slider(500, 2000, value=1000, step=100, label="Max tokens")
            generate_btn = gr.Button("🎤 Generate Speech", variant="primary")
        with gr.Column(scale=1):
            audio_output = gr.Audio(label="Generated Speech", type="filepath", interactive=False)

    voice_preset.change(fn=update_voice_description, inputs=[voice_preset], outputs=[voice_description])
    generate_btn.click(
        fn=generate_speech,
        inputs=[text_input, voice_description, temperature, top_p, max_tokens],
        outputs=[audio_output],
    )

    # Register an explicit API endpoint so Spaces never shows "No API found"
    gr.api(fn=generate_speech, name="generate_speech")
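# Example client call (a sketch; USERNAME/SPACE is a placeholder for this Space's id):
#   from gradio_client import Client
#   client = Client("USERNAME/SPACE")
#   wav_path = client.predict(
#       "Hello there!", "Warm male voice, conversational pacing.",
#       0.7, 0.9, 1000, api_name="/generate_speech",
#   )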
if __name__ == "__main__":
    demo.queue()
    demo.launch(show_error=True)