Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from snac import SNAC | |
| import soundfile as sf | |
| import tempfile | |
| import spaces | |
# --- global handles (lazy-loaded) ---
# All three start out as None ("not loaded yet") and are filled in during the
# first generation request, so importing this module stays cheap.
model = tokenizer = snac_model = None
def load_models(device: str):
    """Load the Maya1 model/tokenizer and the SNAC decoder once per process.

    The loaded objects are cached in the module-level ``model``,
    ``tokenizer`` and ``snac_model`` globals, so repeated calls are cheap.

    Args:
        device: ``"cuda"`` or ``"cpu"``. On CUDA the language model is loaded
            in bfloat16 with ``device_map="auto"``; otherwise float32 with no
            device map.

    Returns:
        The SNAC model when this call created it, otherwise ``None``. The
        SNAC model is not moved to ``device`` here; callers do that once the
        GPU is actually available.
    """
    global model, tokenizer, snac_model

    if tokenizer is None or model is None:
        on_cuda = device == "cuda"
        dtype = torch.bfloat16 if on_cuda else torch.float32
        print(f"[load_models] loading Maya1 (dtype={dtype}, device={device})")
        model = AutoModelForCausalLM.from_pretrained(
            "maya-research/maya1",
            torch_dtype=dtype,
            # device_map only on CUDA; on CPU keep None to avoid accelerate errors
            device_map="auto" if on_cuda else None,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(
            "maya-research/maya1",
            trust_remote_code=True,
        )
        # generate() is given pad_token_id, so make sure one exists.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

    if snac_model is not None:
        return None

    print("[load_models] loading SNAC 24kHz decoder")
    # Cache the decoder here instead of relying on the caller to store the
    # return value; otherwise a caller that drops it would reload SNAC on
    # every call. The fresh instance is still returned for existing callers.
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
    return snac_model
# Token-id window that the generated audio tokens are filtered against, and
# the frame layout used to regroup them for SNAC (values as used by the
# original extraction code, which cites the model card).
_SNAC_MIN_ID = 128266
_SNAC_MAX_ID = 156937
_SNAC_CODE_MODULUS = 4096  # presumably the per-level codebook size -- TODO confirm
_SNAC_TOKENS_PER_FRAME = 7
_SNAC_SAMPLE_RATE = 24000


def _unpack_snac_frames(snac_tokens):
    """Regroup a flat list of SNAC token ids into three per-level code lists.

    Each complete 7-token frame contributes 1 code to the first list
    (position 0), 2 codes to the second (positions 1 and 4) and 4 codes to
    the third (positions 2, 3, 5 and 6). Trailing tokens that do not fill a
    whole frame are ignored.
    """
    frame_count = len(snac_tokens) // _SNAC_TOKENS_PER_FRAME
    level_0, level_1, level_2 = [], [], []
    for index in range(frame_count):
        start = index * _SNAC_TOKENS_PER_FRAME
        frame = snac_tokens[start:start + _SNAC_TOKENS_PER_FRAME]
        c = [(tok - _SNAC_MIN_ID) % _SNAC_CODE_MODULUS for tok in frame]
        level_0.append(c[0])
        level_1.extend((c[1], c[4]))
        level_2.extend((c[2], c[3], c[5], c[6]))
    return [level_0, level_1, level_2]


# ZeroGPU only attaches a GPU while a @spaces.GPU-decorated function runs, and
# the Space needs at least one such function. The longer duration leaves room
# for the lazy model load that happens inside the first call (tune as needed).
@spaces.GPU(duration=120)
def generate_speech(text, voice_description, temperature, top_p, max_tokens):
    """Synthesize ``text`` with Maya1 and return the path of a 24 kHz WAV file.

    Args:
        text: Text to speak; must contain non-whitespace characters.
        voice_description: Natural-language voice description. A neutral
            default is used when it is empty.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        max_tokens: Upper bound on newly generated tokens.

    Returns:
        Filesystem path of the written ``.wav`` file (for
        ``gr.Audio(type="filepath")``).

    Raises:
        gr.Error: If ``text`` is empty or no complete SNAC frame was generated.
    """
    global snac_model

    # `text`/`voice_description` can be None when called through the API.
    if not text or not text.strip():
        raise gr.Error("Enter some text.")
    if not voice_description or not voice_description.strip():
        voice_description = "Realistic voice with neutral tone and conversational pacing."

    # ZeroGPU gives us CUDA during this call
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # load / ensure models exist; load_models returns SNAC only when created
    snac_fresh = load_models(device)
    if snac_fresh is not None:
        snac_model = snac_fresh

    # move models to the active device (ZeroGPU alloc happened)
    if device == "cuda":
        model.to(device)
        snac_model.to(device)

    # prompt exactly like the model card
    prompt = f'<description="{voice_description}"> {text}'
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[1]

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=int(max_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=None,
            repetition_penalty=1.1,
        )

    # Keep only the newly generated ids. Convert to a Python list once instead
    # of comparing 0-d tensors one by one (each comparison forces a device sync).
    generated_ids = outputs[0, prompt_length:].tolist()
    snac_tokens = [t for t in generated_ids if _SNAC_MIN_ID <= t <= _SNAC_MAX_ID]
    if len(snac_tokens) < _SNAC_TOKENS_PER_FRAME:
        raise gr.Error("No SNAC tokens generated. Try longer text and max_tokens=1200–1500.")

    codes = _unpack_snac_frames(snac_tokens)
    codes_tensor = [
        torch.tensor(level, dtype=torch.long, device=device).unsqueeze(0)
        for level in codes
    ]
    with torch.inference_mode():
        latents = snac_model.quantizer.from_codes(codes_tensor)
        audio = snac_model.decoder(latents)[0, 0].cpu().numpy()

    # Reserve a temp path and close the handle *before* soundfile opens it;
    # writing to a still-open NamedTemporaryFile fails on some platforms.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        wav_path = tmp.name
    sf.write(wav_path, audio, _SNAC_SAMPLE_RATE)
    return wav_path
| # ------------------- UI ------------------- | |
# Preset name -> voice description, in the order shown in the dropdown.
# "Custom" maps to an empty description.
voice_presets = dict((
    ("Male - American", "Realistic male voice in the 30s age with american accent. Normal pitch, warm timbre, conversational pacing."),
    ("Female - British", "Clear female voice in the 20s age with British accent. Pleasant tone, articulate delivery, moderate pacing."),
    ("Male - Deep", "Deep male voice with authoritative tone. Low pitch, resonant timbre, steady pacing."),
    ("Female - Energetic", "Energetic female voice with enthusiastic tone. Higher pitch, bright timbre, upbeat pacing."),
    ("Neutral - Professional", "Professional neutral voice with clear articulation. Balanced pitch, crisp tone, measured pacing."),
    ("Custom", ""),
))
def update_voice_description(preset):
    """Return the description text for ``preset``, or "" if it is unknown."""
    if preset in voice_presets:
        return voice_presets[preset]
    return ""
# Two-column layout: inputs and sampling controls on the left, generated
# audio on the right.
with gr.Blocks(theme=gr.themes.Soft(), title="Maya1 Text-to-Speech") as demo:
    gr.HTML("""
    <div style="text-align:center;padding:16px">
    <h1>🎙️ Maya1 Text-to-Speech</h1>
    <p style="color:#666">Generate emotional, realistic speech with natural-language voice design</p>
    <p style="font-size:12px;color:#28a745">⚡ ZeroGPU inference</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Text to Speak",
                value="Hello! This is Maya1 <laugh> the best open source voice AI model with emotions.",
                lines=5,
            )
            voice_preset = gr.Dropdown(
                choices=list(voice_presets.keys()),
                value="Male - American",
                label="Voice Preset",
            )
            voice_description = gr.Textbox(
                label="Voice Description",
                value=voice_presets["Male - American"],
                lines=3,
            )
            with gr.Accordion("Advanced", open=False):
                temperature = gr.Slider(0.1, 1.0, value=0.7, step=0.1, label="Temperature")
                top_p = gr.Slider(0.5, 1.0, value=0.9, step=0.05, label="Top-p")
                max_tokens = gr.Slider(500, 2000, value=1000, step=100, label="Max tokens")
            generate_btn = gr.Button("🎤 Generate Speech", variant="primary")
        with gr.Column(scale=1):
            audio_output = gr.Audio(label="Generated Speech", type="filepath", interactive=False)

    # Picking a preset overwrites the (still editable) description box.
    voice_preset.change(
        fn=update_voice_description,
        inputs=[voice_preset],
        outputs=[voice_description],
    )

    # Register an explicit API endpoint so Spaces never shows "No API found".
    # This is done with `api_name` on the click listener: the earlier
    # `gr.api(fn=..., name=...)` call used a keyword that gr.api does not
    # accept (its keyword is `api_name`) and therefore failed at startup.
    generate_btn.click(
        fn=generate_speech,
        inputs=[text_input, voice_description, temperature, top_p, max_tokens],
        outputs=[audio_output],
        api_name="generate_speech",
    )
if __name__ == "__main__":
    # Enable the request queue, then start the server with errors surfaced
    # to the user (show_error=True).
    demo.queue().launch(show_error=True)