Devakumar868 committed
Commit 42e6e01 · verified · 1 Parent(s): 1a24747

Update app.py

Files changed (1)
  1. app.py +93 -42
app.py CHANGED
@@ -1,35 +1,50 @@
 import gradio as gr
 import torch
 import numpy as np
+import os
 from transformers import pipeline, AutoProcessor, CsmForConditionalGeneration
-from pyannote.audio import Pipeline as VAD
+from pyannote.audio import Model, Inference
 from dia.model import Dia
 from dac.utils import load_model as load_dac_model
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch

-# 2.1: Initialize Accelerator for 4×L4 GPU distribution
-from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-device_map = "auto" # accelerate automatically shards across 4 L4 GPUs[2]
+# Access HF_TOKEN from environment variables (Secrets)
+HF_TOKEN = os.environ.get("HF_TOKEN")
+
+# Device mapping for 4× L4 GPU distribution
+device_map = "auto"

-# 2.2: Load Descript Audio Codec (RVQ) at startup
+print("Loading models...")
+
+# Load Descript Audio Codec (RVQ) at startup
+print("Loading RVQ Codec...")
 rvq = load_dac_model(tag="latest", model_type="44khz")
 rvq.eval()
 if torch.cuda.is_available():
     rvq = rvq.to("cuda")

-# 2.3: Load VAD pipeline
-vad = VAD.from_pretrained("pyannote/voice-activity-detection")
+# Load segmentation model with authentication
+print("Loading Segmentation Model...")
+seg_model = Model.from_pretrained(
+    "pyannote/segmentation",
+    use_auth_token=HF_TOKEN
+)
+seg_inference = Inference(seg_model, device=0 if torch.cuda.is_available() else -1)

-# 2.4: Load Ultravox via audio-text-to-text pipeline
+# Use segmentation model for VAD
+vad = seg_inference
+
+# Load Ultravox via generic pipeline (without specifying task)
+print("Loading Ultravox...")
 ultravox_pipe = pipeline(
-    "audio-text-to-text",
     model="fixie-ai/ultravox-v0_4",
     trust_remote_code=True,
     device_map=device_map,
     torch_dtype=torch.float16
 )

-# 2.5: Load Diffusion model
+# Load Diffusion model
+print("Loading Diffusion Model...")
 diff_pipe = pipeline(
     "audio-to-audio",
     model="teticio/audio-diffusion-instrumental-hiphop-256",
@@ -38,44 +53,80 @@
     torch_dtype=torch.float16
 )

-# 2.6: Load Dia TTS with multi-GPU dispatch
+# Load Dia TTS with multi-GPU dispatch
+print("Loading Dia TTS...")
 with init_empty_weights():
-    dia = Dia.from_pretrained("nari-labs/Dia-1.6B", torch_dtype=torch.float16, trust_remote_code=True)
+    dia = Dia.from_pretrained(
+        "nari-labs/Dia-1.6B",
+        torch_dtype=torch.float16,
+        trust_remote_code=True
+    )
 dia = load_checkpoint_and_dispatch(
-    dia, "nari-labs/Dia-1.6B", device_map=device_map, dtype=torch.float16
+    dia,
+    "nari-labs/Dia-1.6B",
+    device_map=device_map,
+    dtype=torch.float16
 )

-# 2.7: Gradio inference function
+print("All models loaded successfully!")
+
+# Gradio inference function
 def process_audio(audio):
-    sr, array = audio["sampling_rate"], audio["array"]
-    # VAD segmentation
-    speech_segments = vad({"waveform": torch.tensor(array).unsqueeze(0), "sample_rate": sr})
-    # RVQ encode/decode
-    audio_tensor = torch.tensor(array).unsqueeze(0)
-    if torch.cuda.is_available():
-        audio_tensor = audio_tensor.to("cuda")
-    codes = rvq.encode(audio_tensor)
-    decoded = rvq.decode(codes)
-    array = decoded.squeeze().cpu().numpy()
-    # Ultravox ASR→LLM
-    ultra_out = ultravox_pipe({"array": array, "sampling_rate": sr})
-    text = ultra_out["text"]
-    # Diffusion-based prosody enhancement
-    prosody_audio = diff_pipe({"array": decoded.cpu().numpy(), "sampling_rate": sr})["array"][0]
-    # Dia TTS
-    tts_audio = dia.generate(f"[emotion:neutral] {text}")
-    tts_np = tts_audio.squeeze().cpu().numpy()
-    # Normalize
-    tts_np = tts_np / np.max(np.abs(tts_np)) * 0.95
-    return (sr, tts_np), text
+    try:
+        if audio is None:
+            return None, "No audio input provided"
+
+        sr, array = audio
+
+        # Ensure audio is numpy array
+        if torch.is_tensor(array):
+            array = array.numpy()
+
+        # VAD segmentation
+        segments = vad({"waveform": torch.tensor(array).unsqueeze(0), "sample_rate": sr})
+
+        # RVQ encode/decode
+        audio_tensor = torch.tensor(array).unsqueeze(0)
+        if torch.cuda.is_available():
+            audio_tensor = audio_tensor.to("cuda")
+        codes = rvq.encode(audio_tensor)
+        decoded = rvq.decode(codes)
+        array = decoded.squeeze().cpu().numpy()
+
+        # Ultravox ASR→LLM
+        ultra_out = ultravox_pipe({"array": array, "sampling_rate": sr})
+        text = ultra_out.get("text", "I understand your audio input.")
+
+        # Diffusion-based prosody enhancement
+        prosody_audio = diff_pipe({"array": decoded.cpu().numpy(), "sampling_rate": sr})["array"][0]
+
+        # Dia TTS
+        tts_audio = dia.generate(f"[emotion:neutral] {text}")
+        tts_np = tts_audio.squeeze().cpu().numpy()
+
+        # Normalize
+        tts_np = tts_np / np.max(np.abs(tts_np)) * 0.95
+
+        return (sr, tts_np), text
+
+    except Exception as e:
+        print(f"Error in process_audio: {e}")
+        return None, f"Processing error: {str(e)}"

-# 2.8: Gradio UI
-with gr.Blocks() as demo:
-    gr.Markdown("## Supernatural Speech AI Agent")
-    audio_in = gr.Audio(source="microphone", type="numpy", label="Record Your Voice")
-    btn = gr.Button("Send")
-    audio_out = gr.Audio(label="AI Response")
-    txt_out = gr.Textbox(label="Transcribed & Generated Text")
+# Gradio UI
+with gr.Blocks(title="Maya-AI: Supernatural Speech Agent") as demo:
+    gr.Markdown("# Maya-AI: Supernatural Speech Agent")
+    gr.Markdown("Record audio to interact with the AI agent that understands emotions and responds naturally.")
+
+    with gr.Row():
+        with gr.Column():
+            audio_in = gr.Audio(source="microphone", type="numpy", label="Record Your Voice")
+            btn = gr.Button("Send", variant="primary")
+
+        with gr.Column():
+            audio_out = gr.Audio(label="AI Response")
+            txt_out = gr.Textbox(label="Transcribed & Generated Text", lines=3)
+
     btn.click(fn=process_audio, inputs=audio_in, outputs=[audio_out, txt_out])

 if __name__ == "__main__":
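
Not part of the commit: a minimal, hypothetical smoke test for the updated callback. It assumes app.py has already been run or imported so the models above are loaded, and it feeds process_audio one second of synthetic audio as the (sample_rate, numpy array) tuple that gr.Audio(type="numpy") passes to the callback, i.e. the format the new "sr, array = audio" unpacking expects.

import numpy as np

# Hypothetical smoke test (not part of the commit): exercise process_audio
# offline with one second of synthetic audio in Gradio's numpy format.
sample_rate = 16000
t = np.linspace(0, 1, sample_rate, endpoint=False)
clip = (0.1 * np.sin(2 * np.pi * 220.0 * t)).astype(np.float32)  # quiet 220 Hz test tone

result_audio, result_text = process_audio((sample_rate, clip))
if result_audio is None:
    # process_audio returns (None, error message) when any stage raises
    print("Pipeline error:", result_text)
else:
    out_sr, out_wave = result_audio
    print(f"Received {out_wave.shape[0] / out_sr:.2f} s of audio; text: {result_text!r}")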