Devakumar868 committed
Commit 0e0768b · verified · 1 Parent(s): d9c827c

Update app.py

Files changed (1)
  1. app.py +100 -42
app.py CHANGED
@@ -9,25 +9,28 @@ from dia.model import Dia
 from dac.utils import load_model as load_dac_model
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 
-# Retrieve HF_TOKEN from Secrets
+# Environment token from HF Secrets
 HF_TOKEN = os.environ["HF_TOKEN"]
-
-# Automatically shard across 4× L4 GPUs
 device_map = "auto"
 
-# 1. Load Descript Audio Codec (RVQ)
+print("Loading models...")
+
+# 1. RVQ Codec (Descript Audio Codec)
+print("Loading RVQ Codec...")
 rvq = load_dac_model(tag="latest", model_type="44khz")
 rvq.eval()
 if torch.cuda.is_available():
     rvq = rvq.to("cuda")
 
-# 2. Load Voice Activity Detection via Pyannote
+# 2. Voice Activity Detection
+print("Loading VAD...")
 vad_pipe = PyannotePipeline.from_pretrained(
     "pyannote/voice-activity-detection",
     use_auth_token=HF_TOKEN
 )
 
-# 3. Load Ultravox (speech-to-text + LLM) via Transformers
+# 3. Ultravox ASR+LLM
+print("Loading Ultravox...")
 ultravox_pipe = pipeline(
     model="fixie-ai/ultravox-v0_4",
     trust_remote_code=True,
@@ -35,14 +38,18 @@ ultravox_pipe = pipeline(
     torch_dtype=torch.float16
 )
 
-# 4. Load Audio Diffusion model via Diffusers
+# 4. Audio Diffusion Model
+print("Loading Audio Diffusion...")
 diff_pipe = DiffusionPipeline.from_pretrained(
-    "teticio/audio-diffusion-instrumental-hiphop-256"
+    "teticio/audio-diffusion-instrumental-hiphop-256",
+    torch_dtype=torch.float16
 ).to("cuda")
 
-# 5. Load Dia TTS with meta-weight initialization and multi-GPU dispatch
+# 5. Dia TTS Model
+print("Loading Dia TTS...")
 with init_empty_weights():
     dia = Dia.from_pretrained("nari-labs/Dia-1.6B")
+
 dia = load_checkpoint_and_dispatch(
     dia,
     "nari-labs/Dia-1.6B",
@@ -50,41 +57,92 @@ dia = load_checkpoint_and_dispatch(
     dtype=torch.float16
 )
 
-# Inference function
-def process_audio(audio):
-    sr, array = audio
-    array = array.numpy() if torch.is_tensor(array) else array
-
-    # 2.1 VAD: segment speech regions (not used further here)
-    _ = vad_pipe(array, sampling_rate=sr)
+print("All models loaded successfully!")
 
-    # 1.1 RVQ encode/decode for discrete audio tokens
-    x = torch.tensor(array).unsqueeze(0).to("cuda")
-    codes = rvq.encode(x)
-    decoded = rvq.decode(codes).squeeze().cpu().numpy()
-
-    # 3. Ultravox ASR + LLM to generate response text
-    ultra_out = ultravox_pipe({"array": decoded, "sampling_rate": sr})
-    text = ultra_out.get("text", "")
-
-    # 4. Diffusion-based prosody enhancement
-    pros = diff_pipe(raw_audio=decoded)["audios"][0]
-
-    # 5. Dia TTS synthesis with neutral emotion tag
-    tts = dia.generate(f"[emotion:neutral] {text}")
-    tts_np = tts.squeeze().cpu().numpy()
-    tts_np = tts_np / np.max(np.abs(tts_np)) * 0.95
-
-    return (sr, tts_np), text
-
-# Gradio UI
+# Audio processing function
+def process_audio(audio):
+    try:
+        if audio is None:
+            return None, "No audio input provided"
+
+        sr, array = audio
+
+        # Ensure numpy array
+        if torch.is_tensor(array):
+            array = array.numpy()
+
+        # Voice Activity Detection
+        vad_result = vad_pipe({"waveform": torch.tensor(array).unsqueeze(0), "sample_rate": sr})
+
+        # RVQ encode/decode for audio compression
+        audio_tensor = torch.tensor(array).unsqueeze(0)
+        if torch.cuda.is_available():
+            audio_tensor = audio_tensor.to("cuda")
+        codes = rvq.encode(audio_tensor)
+        decoded = rvq.decode(codes).squeeze().cpu().numpy()
+
+        # Ultravox ASR + LLM processing
+        ultra_out = ultravox_pipe({"array": decoded, "sampling_rate": sr})
+        text = ultra_out.get("text", "I understand your audio input.")
+
+        # Audio diffusion for prosody enhancement
+        try:
+            prosody_result = diff_pipe(raw_audio=decoded)
+            if "audios" in prosody_result:
+                prosody_audio = prosody_result["audios"][0]
+            else:
+                prosody_audio = decoded
+        except Exception as e:
+            print(f"Diffusion processing error: {e}")
+            prosody_audio = decoded
+
+        # Dia TTS generation
+        tts_output = dia.generate(f"[emotion:neutral] {text}")
+
+        # Convert to numpy and normalize
+        if torch.is_tensor(tts_output):
+            tts_np = tts_output.squeeze().cpu().numpy()
+        else:
+            tts_np = tts_output
+
+        # Normalize audio output
+        if len(tts_np) > 0:
+            tts_np = tts_np / np.max(np.abs(tts_np)) * 0.95
+
+        return (sr, tts_np), text
+
+    except Exception as e:
+        print(f"Error in process_audio: {e}")
+        return None, f"Processing error: {str(e)}"
+
+# Gradio Interface
 with gr.Blocks(title="Maya AI 📈") as demo:
-    gr.Markdown("## Maya-AI: Supernatural Conversational Agent")
-    audio_in = gr.Audio(source="microphone", type="numpy", label="Your Voice")
-    send_btn = gr.Button("Send")
-    audio_out = gr.Audio(label="AI’s Response")
-    text_out = gr.Textbox(label="Generated Text")
-    send_btn.click(process_audio, inputs=audio_in, outputs=[audio_out, text_out])
+    gr.Markdown("# Maya-AI: Supernatural Conversational Agent")
+    gr.Markdown("Record audio to interact with the AI agent that understands emotions and responds naturally.")
+
+    with gr.Row():
+        with gr.Column():
+            audio_in = gr.Audio(
+                source="microphone",
+                type="numpy",
+                label="Record Your Voice"
+            )
+            send_btn = gr.Button("Send", variant="primary")
+
+        with gr.Column():
+            audio_out = gr.Audio(label="AI Response")
+            text_out = gr.Textbox(
+                label="Generated Text",
+                lines=3,
+                placeholder="AI response will appear here..."
+            )
+
+    # Event handler
+    send_btn.click(
+        fn=process_audio,
+        inputs=audio_in,
+        outputs=[audio_out, text_out]
+    )
 
 if __name__ == "__main__":
     demo.launch()
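
A minimal smoke-test sketch of the new process_audio() path outside the Gradio UI, assuming app.py's dependencies are installed and its models load in the current environment; the 44.1 kHz rate, 1-second duration, and 440 Hz test tone are illustrative values (chosen to match the 44khz DAC codec), not anything specified in this commit.

# Hypothetical smoke test for the updated process_audio() in app.py.
# Importing app triggers the module-level model loading shown in the diff.
import numpy as np
from app import process_audio

sr = 44100                                   # illustrative rate matching the 44khz DAC codec
t = np.linspace(0.0, 1.0, sr, endpoint=False)
tone = (0.1 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)

# process_audio expects the (sample_rate, ndarray) tuple that gr.Audio(type="numpy") produces
result, text = process_audio((sr, tone))
print("Generated text:", text)
if result is not None:
    out_sr, out_audio = result               # on error the function returns (None, message)
    print("TTS output:", out_sr, "Hz,", getattr(out_audio, "shape", None))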