Devakumar868 committed (verified)
Commit d9c827c · 1 Parent(s): 27767e2

Update app.py

Files changed (1):
  1. app.py +25 -14
app.py CHANGED
@@ -9,21 +9,25 @@ from dia.model import Dia
 from dac.utils import load_model as load_dac_model
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
+# Retrieve HF_TOKEN from Secrets
 HF_TOKEN = os.environ["HF_TOKEN"]
+
+# Automatically shard across 4× L4 GPUs
 device_map = "auto"
 
-# 1. RVQ Codec
+# 1. Load Descript Audio Codec (RVQ)
 rvq = load_dac_model(tag="latest", model_type="44khz")
 rvq.eval()
-if torch.cuda.is_available(): rvq = rvq.to("cuda")
+if torch.cuda.is_available():
+    rvq = rvq.to("cuda")
 
-# 2. VAD
+# 2. Load Voice Activity Detection via Pyannote
 vad_pipe = PyannotePipeline.from_pretrained(
     "pyannote/voice-activity-detection",
     use_auth_token=HF_TOKEN
 )
 
-# 3. Ultravox
+# 3. Load Ultravox (speech-to-text + LLM) via Transformers
 ultravox_pipe = pipeline(
     model="fixie-ai/ultravox-v0_4",
     trust_remote_code=True,
@@ -31,48 +35,55 @@ ultravox_pipe = pipeline(
     torch_dtype=torch.float16
 )
 
-# 4. Audio Diffusion
+# 4. Load Audio Diffusion model via Diffusers
 diff_pipe = DiffusionPipeline.from_pretrained(
     "teticio/audio-diffusion-instrumental-hiphop-256"
 ).to("cuda")
 
-# 5. Dia TTS
+# 5. Load Dia TTS with meta-weight initialization and multi-GPU dispatch
 with init_empty_weights():
-    dia = Dia.from_pretrained("nari-labs/Dia-1.6B")  # no extra kwargs
+    dia = Dia.from_pretrained("nari-labs/Dia-1.6B")
 dia = load_checkpoint_and_dispatch(
     dia,
     "nari-labs/Dia-1.6B",
-    device_map="auto",
+    device_map=device_map,
     dtype=torch.float16
 )
-# Inference
+
+# Inference function
 def process_audio(audio):
     sr, array = audio
     array = array.numpy() if torch.is_tensor(array) else array
 
+    # 2.1 VAD: segment speech regions (not used further here)
     _ = vad_pipe(array, sampling_rate=sr)
+
+    # 1.1 RVQ encode/decode for discrete audio tokens
     x = torch.tensor(array).unsqueeze(0).to("cuda")
     codes = rvq.encode(x)
     decoded = rvq.decode(codes).squeeze().cpu().numpy()
 
+    # 3. Ultravox ASR + LLM to generate response text
     ultra_out = ultravox_pipe({"array": decoded, "sampling_rate": sr})
     text = ultra_out.get("text", "")
 
+    # 4. Diffusion-based prosody enhancement
     pros = diff_pipe(raw_audio=decoded)["audios"][0]
 
+    # 5. Dia TTS synthesis with neutral emotion tag
     tts = dia.generate(f"[emotion:neutral] {text}")
     tts_np = tts.squeeze().cpu().numpy()
     tts_np = tts_np / np.max(np.abs(tts_np)) * 0.95
 
     return (sr, tts_np), text
 
-# UI
+# Gradio UI
 with gr.Blocks(title="Maya AI 📈") as demo:
     gr.Markdown("## Maya-AI: Supernatural Conversational Agent")
-    audio_in = gr.Audio(source="microphone", type="numpy")
-    send_btn = gr.Button("Send")
-    audio_out = gr.Audio()
-    text_out = gr.Textbox()
+    audio_in = gr.Audio(source="microphone", type="numpy", label="Your Voice")
+    send_btn = gr.Button("Send")
+    audio_out = gr.Audio(label="AI’s Response")
+    text_out = gr.Textbox(label="Generated Text")
     send_btn.click(process_audio, inputs=audio_in, outputs=[audio_out, text_out])
 
 if __name__ == "__main__":
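
Background on the device_map change in the Dia block: the file uses Accelerate's big-model loading pattern (initialize on the meta device, then load the checkpoint and dispatch it across devices). Below is a minimal sketch of that pattern, assuming a hypothetical `my-org/my-model` Transformers checkpoint purely for illustration; note that `load_checkpoint_and_dispatch` expects a local checkpoint path, so a Hub repo is typically downloaded to disk first.

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "my-org/my-model"  # hypothetical repo id, for illustration only

# Build the module on the meta device: no memory is allocated for weights yet.
config = AutoConfig.from_pretrained(repo_id)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# load_checkpoint_and_dispatch needs a local path, so materialize the repo on disk.
checkpoint_dir = snapshot_download(repo_id)

# Load weights shard by shard and place layers according to device_map="auto",
# spreading the model across the available GPUs (and CPU RAM if needed).
model = load_checkpoint_and_dispatch(
    model,
    checkpoint_dir,
    device_map="auto",
    dtype=torch.float16,
)

With device_map="auto" (the value the script stores in its module-level device_map variable), Accelerate infers a per-layer placement from available memory, which is what lets a single checkpoint span several GPUs without manual .to() calls.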