Staticaliza committed on
Commit 767aa72 · verified · 1 Parent(s): 4dbefe2

Update app.py

Files changed (1)
  1. app.py +52 -66
app.py CHANGED
@@ -1,72 +1,58 @@
- # app.py
- import os, shlex, subprocess, torch, numpy as np, gradio as gr, torchaudio, spaces
  from zonos.model import Zonos
  from zonos.conditioning import make_cond_dict, supported_language_codes

- # ── optional perf wheels (safe to ignore if they fail) ───────────────────────────
- cmds = [
-     "pip install flash-attn --no-build-isolation",
-     "pip install https://github.com/state-spaces/mamba/releases/download/v2.2.4/"
-     "mamba_ssm-2.2.4+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
-     "pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/"
-     "v1.5.0.post8/causal_conv1d-1.5.0.post8+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
- ]
- for c in cmds:
-     try:
-         subprocess.run(shlex.split(c), check=True)
-     except subprocess.CalledProcessError:
-         print("wheel skipped:", c.split()[2 if c.startswith('pip') else -1])
-
- # ── disable torch.compile: zerogpu lacks full cuda props ────────────────────────
- os.environ["TORCH_COMPILE_DISABLE"] = "1"
- os.environ["TORCHINDUCTOR_DISABLE"] = "1"
- import torch._dynamo; torch._dynamo.disable()
-
- device = "cuda"  # zerogpu maps this transparently
- MODEL_NAME = "Zyphra/Zonos-v0.1-transformer"  # hybrid commented out for now
-
- _cached_model: Zonos | None = None
- def get_model() -> Zonos:
-     global _cached_model
-     if _cached_model is None:
-         _cached_model = Zonos.from_pretrained(MODEL_NAME, device=device).eval()
-     return _cached_model
-
- def _speaker_embed(audio):
-     if audio is None: return None
-     sr, wav = audio
-     if wav.dtype.kind in "iu": wav = wav.astype(np.float32) / np.iinfo(wav.dtype).max
-     wav = torch.from_numpy(wav).unsqueeze(0)
-     return get_model().make_speaker_embedding(wav, sr)
-
- @spaces.GPU
- def tts(text, language, speaker_audio,
-         e1,e2,e3,e4,e5,e6,e7,e8, speaking_rate, pitch_std):
-     m = get_model()
-     speaker = _speaker_embed(speaker_audio)
-     emotion = [e1,e2,e3,e4,e5,e6,e7,e8]
-     cond = make_cond_dict(
-         text=text, language=language, speaker=speaker, emotion=emotion,
-         speaking_rate=float(speaking_rate), pitch_std=float(pitch_std), device=device
-     )
      with torch.no_grad():
-         codes = m.generate(m.prepare_conditioning(cond))
-     wav = m.autoencoder.decode(codes)[0].cpu()
-     return (m.autoencoder.sampling_rate, wav.numpy())
-
- langs = supported_language_codes  # from the library itself

  with gr.Blocks() as demo:
-     txt = gr.Textbox(label="text")
-     lng = gr.Dropdown(langs, value="en-us", label="language")
-     aud = gr.Audio(type="numpy", label="speaker ref (optional)")
-     emos = [gr.Slider(0,1,0.3 if i==0 else 0.0,0.05,label=l)
-             for i,l in enumerate(["happiness","sadness","disgust","fear",
-                                   "surprise","anger","other","neutral"])]
-     rate = gr.Slider(0,40,15,1,label="speaking_rate")
-     pitch= gr.Slider(0,400,20,1,label="pitch_std")
-     out = gr.Audio(label="output")
-     gr.Button("generate").click(tts,[txt,lng,aud,*emos,rate,pitch],out)
-
- if __name__ == "__main__":
-     demo.launch()
 
+ import os, shlex, subprocess, torch
+
+ # extra wheels (safe to skip if they fail)
+ for cmd, env in [
+     ("pip install flash-attn --no-build-isolation", {"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}),
+     ("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl", {}),
+     ("pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.0.post8/causal_conv1d-1.5.0.post8+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl", {}),
+ ]:
+     try: subprocess.run(shlex.split(cmd), env=os.environ | env, check=True)
+     except subprocess.CalledProcessError: pass
+
+ # hard-nuke torch.compile everywhere
+ os.environ["TORCH_COMPILE_DISABLE"]="1"
+ os.environ["TORCHINDUCTOR_DISABLE"]="1"
+ torch._dynamo.disable()
+ torch.compile=lambda fn,*a,**k:fn
+
+ import torchaudio, gradio as gr, spaces, numpy as np
  from zonos.model import Zonos
  from zonos.conditioning import make_cond_dict, supported_language_codes

+ device="cuda"
+ MODEL_NAMES=["Zyphra/Zonos-v0.1-transformer","Zyphra/Zonos-v0.1-hybrid"]
+ MODELS={n:Zonos.from_pretrained(n,device=device).eval() for n in MODEL_NAMES}
+
+ def _spk(model,aud):
+     if aud is None: return None
+     sr,wav=aud
+     if wav.dtype.kind in "iu": wav=wav.astype(np.float32)/np.iinfo(wav.dtype).max
+     return model.make_speaker_embedding(torch.from_numpy(wav).unsqueeze(0),sr)
+
+ @spaces.GPU(duration=120)
+ def tts(m,text,lang,speaker,
+         h,sad,disg,fear,sur,ang,oth,neu,
+         speak,pitch):
+     model=MODELS[m]
+     emotion=[h,sad,disg,fear,sur,ang,oth,neu]
+     cond=make_cond_dict(text=text,language=lang,speaker=_spk(model,speaker),
+                         emotion=emotion,speaking_rate=float(speak),
+                         pitch_std=float(pitch),device=device)
      with torch.no_grad():
+         codes=model.generate(model.prepare_conditioning(cond))
+     wav=model.autoencoder.decode(codes)[0].cpu()
+     return (model.autoencoder.sampling_rate,wav.numpy())

+ langs=supported_language_codes
  with gr.Blocks() as demo:
+     mc=gr.Dropdown(MODEL_NAMES,value=MODEL_NAMES[0],label="model")
+     txt=gr.Textbox(label="text")
+     lng=gr.Dropdown(langs,value="en-us",label="language")
+     spk=gr.Audio(type="numpy",label="speaker ref")
+     emos=[gr.Slider(0,1,0.3 if i==0 else 0.0,0.05,label=l) for i,l in
+           enumerate(["happiness","sad","disgust","fear","surprise","anger","other","neutral"])]
+     rate=gr.Slider(0,40,15,1,label="speaking_rate")
+     pit=gr.Slider(0,400,20,1,label="pitch_std")
+     out=gr.Audio(label="output")
+     gr.Button("generate").click(tts,[mc,txt,lng,spk,*emos,rate,pit],out)
+ if __name__=="__main__": demo.launch()
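
A note on the torch.compile shim: rebinding torch.compile to lambda fn,*a,**k: fn makes any torch.compile(model, ...) call inside Zonos hand the model back unchanged, so everything runs eagerly on ZeroGPU. A minimal standalone sketch of the same pattern (the double function is illustrative, not from the commit):

    import os
    import torch

    # Mirror the commit: env flags for code paths that consult them,
    # plus its bare torch._dynamo.disable() call.
    os.environ["TORCH_COMPILE_DISABLE"] = "1"
    os.environ["TORCHINDUCTOR_DISABLE"] = "1"
    torch._dynamo.disable()

    # Pass-through shim: torch.compile(fn, ...) now just returns fn.
    torch.compile = lambda fn, *a, **k: fn

    def double(x):  # illustrative stand-in for a model's forward
        return x * 2

    assert torch.compile(double) is double
    assert torch.compile(double, mode="max-autotune") is double

One caveat: the shim covers the torch.compile(fn) call form but not the keyword-only decorator form torch.compile(mode=...), which would raise TypeError because fn is a required positional argument of the lambda.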
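The dtype check in _spk (carried over from the old _speaker_embed) is what lets the Space accept raw recordings: gr.Audio(type="numpy") can return integer PCM, and that line rescales it to float32 before the speaker embedding is built. A quick standalone check of just the normalization (sample values are illustrative):

    import numpy as np
    import torch

    wav = np.array([0, 16384, -32768], dtype=np.int16)  # illustrative int16 PCM
    if wav.dtype.kind in "iu":  # signed/unsigned integer samples?
        wav = wav.astype(np.float32) / np.iinfo(wav.dtype).max
    print(wav)  # approx [0.0, 0.5, -1.0], scaled by 32767
    wav = torch.from_numpy(wav).unsqueeze(0)  # (1, num_samples), as _spk builds it

Dividing by iinfo.max (32767 for int16) rather than 32768 means the most negative sample lands just below -1.0; harmless for embedding, but worth knowing.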
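For reference, the new tts signature threads the model dropdown value through as the first argument, ahead of text, language, the speaker reference, the eight emotion sliders, and the two prosody sliders. A hypothetical direct call outside the UI (values mirror the slider defaults, and this assumes a context where @spaces.GPU can allocate a device, e.g. inside the running Space):

    sr, audio = tts(
        "Zyphra/Zonos-v0.1-transformer",           # model (dropdown value)
        "hello world", "en-us", None,              # text, language, no speaker ref
        0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,    # happiness..neutral sliders
        15, 20,                                    # speaking_rate, pitch_std
    )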