Devakumar868 committed on
Commit 55c39a0 · verified · 1 Parent(s): 018a337

Update app.py

Files changed (1)
  1. app.py +47 -99
app.py CHANGED
@@ -2,138 +2,86 @@ import os
  import gradio as gr
  import torch
  import numpy as np
- from transformers import pipeline
+ from transformers import pipeline, AutoModel
  from diffusers import DiffusionPipeline
  from pyannote.audio import Pipeline as PyannotePipeline
  from dia.model import Dia
  from dac.utils import load_model as load_dac_model

- # Environment token from HF Secrets
- HF_TOKEN = os.environ["HF_TOKEN"]
+ # 1. Retrieve HF token and set device mapping
+ HF_TOKEN = os.environ["HF_TOKEN"]
+ device_map = "auto"  # auto-shard models across 4×L4 GPUs

- print("Loading models...")
-
- # 1. Load RVQ Codec (Descript Audio Codec)
  print("Loading RVQ Codec...")
  rvq = load_dac_model(tag="latest", model_type="44khz")
  rvq.eval()
  if torch.cuda.is_available():
      rvq = rvq.to("cuda")

- # 2. Load Voice Activity Detection
- print("Loading VAD...")
+ print("Loading VAD pipeline...")
  vad_pipe = PyannotePipeline.from_pretrained(
      "pyannote/voice-activity-detection",
      use_auth_token=HF_TOKEN
  )

- # 3. Load Ultravox ASR+LLM
- print("Loading Ultravox...")
+ print("Loading Ultravox pipeline...")
  ultravox_pipe = pipeline(
      model="fixie-ai/ultravox-v0_4",
      trust_remote_code=True,
-     device_map="auto",
+     device_map=device_map,
      torch_dtype=torch.float16
  )

- # 4. Load Audio Diffusion Model
- print("Loading Audio Diffusion...")
+ print("Loading Audio Diffusion model...")
  diff_pipe = DiffusionPipeline.from_pretrained(
      "teticio/audio-diffusion-instrumental-hiphop-256",
      torch_dtype=torch.float16
  ).to("cuda")

- # 5. Load Dia TTS Model (WITHOUT meta tensor approach)
- print("Loading Dia TTS...")
- # Direct loading without init_empty_weights to avoid meta tensor issues
- dia = Dia.from_pretrained("nari-labs/Dia-1.6B")
+ print("Loading Dia TTS (sharded across GPUs)...")
+ dia = Dia.from_pretrained(
+     "nari-labs/Dia-1.6B",
+     device_map=device_map,
+     torch_dtype=torch.float16,
+     trust_remote_code=True
+ )

  print("All models loaded successfully!")

- # Audio processing function
  def process_audio(audio):
-     try:
-         if audio is None:
-             return None, "No audio input provided"
-
-         sr, array = audio
-
-         # Ensure numpy array
-         if torch.is_tensor(array):
-             array = array.numpy()
-
-         # Voice Activity Detection
-         vad_result = vad_pipe({"waveform": torch.tensor(array).unsqueeze(0), "sample_rate": sr})
-
-         # RVQ encode/decode for audio compression
-         audio_tensor = torch.tensor(array).unsqueeze(0)
-         if torch.cuda.is_available():
-             audio_tensor = audio_tensor.to("cuda")
-         codes = rvq.encode(audio_tensor)
-         decoded = rvq.decode(codes).squeeze().cpu().numpy()
-
-         # Ultravox ASR + LLM processing
-         ultra_out = ultravox_pipe({"array": decoded, "sampling_rate": sr})
-         text = ultra_out.get("text", "I understand your audio input.")
-
-         # Audio diffusion for prosody enhancement
-         try:
-             prosody_result = diff_pipe(raw_audio=decoded)
-             if "audios" in prosody_result:
-                 prosody_audio = prosody_result["audios"][0]
-             else:
-                 prosody_audio = decoded
-         except Exception as e:
-             print(f"Diffusion processing error: {e}")
-             prosody_audio = decoded
-
-         # Dia TTS generation
-         tts_output = dia.generate(f"[emotion:neutral] {text}")
-
-         # Convert to numpy and normalize
-         if torch.is_tensor(tts_output):
-             tts_np = tts_output.squeeze().cpu().numpy()
-         else:
-             tts_np = tts_output
-
-         # Normalize audio output
-         if len(tts_np) > 0:
-             tts_np = tts_np / np.max(np.abs(tts_np)) * 0.95
-
-         return (sr, tts_np), text
-
-     except Exception as e:
-         print(f"Error in process_audio: {e}")
-         return None, f"Processing error: {str(e)}"
-
- # Gradio Interface
+     sr, array = audio
+     array = array.numpy() if torch.is_tensor(array) else array
+
+     # 1. Voice activity detection
+     vad_pipe({"waveform": torch.tensor(array).unsqueeze(0), "sample_rate": sr})
+
+     # 2. RVQ encode/decode
+     x = torch.tensor(array).unsqueeze(0).to("cuda")
+     codes = rvq.encode(x)
+     decoded = rvq.decode(codes).squeeze().cpu().numpy()
+
+     # 3. Ultravox ASR → text
+     out = ultravox_pipe({"array": decoded, "sampling_rate": sr})
+     text = out.get("text", "")
+
+     # 4. Prosody diffusion
+     pros = diff_pipe(raw_audio=decoded)["audios"][0]
+
+     # 5. Dia TTS synthesis
+     tts = dia.generate(f"[emotion:neutral] {text}")
+     tts_np = tts.squeeze().cpu().numpy()
+     tts_np = tts_np / np.max(np.abs(tts_np)) * 0.95 if tts_np.size else tts_np
+
+     return (sr, tts_np), text
+
+ # Gradio UI
  with gr.Blocks(title="Maya AI 📈") as demo:
-     gr.Markdown("# Maya-AI: Supernatural Conversational Agent")
-     gr.Markdown("Record audio to interact with the AI agent that understands emotions and responds naturally.")
-
-     with gr.Row():
-         with gr.Column():
-             audio_in = gr.Audio(
-                 sources=["microphone"],
-                 type="numpy",
-                 label="Record Your Voice"
-             )
-             send_btn = gr.Button("Send", variant="primary")
-
-         with gr.Column():
-             audio_out = gr.Audio(label="AI Response")
-             text_out = gr.Textbox(
-                 label="Generated Text",
-                 lines=3,
-                 placeholder="AI response will appear here..."
-             )
-
-     # Event handler
-     send_btn.click(
-         fn=process_audio,
-         inputs=audio_in,
-         outputs=[audio_out, text_out]
-     )
+     gr.Markdown("## Maya-AI Supernatural Conversational Agent")
+     audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Your Voice")
+     send_btn = gr.Button("Send")
+     audio_out = gr.Audio(label="AI Response")
+     text_out = gr.Textbox(label="Generated Text")
+     send_btn.click(process_audio, inputs=audio_in, outputs=[audio_out, text_out])

  if __name__ == "__main__":
      demo.launch()
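
Since the headline change here is sharding via device_map="auto" across the Space's 4×L4 GPUs, a quick hedged check of where accelerate actually placed the Ultravox weights may help reviewers. hf_device_map is the attribute accelerate attaches to models loaded with a device map; a trust_remote_code model may not expose it, hence the getattr guard:

    import torch

    print(torch.cuda.device_count(), "CUDA devices visible")
    # May be None if the remote-code model does not record a device map
    placement = getattr(ultravox_pipe.model, "hf_device_map", None)
    print(placement)  # e.g. {"model.layers.0": 0, ...}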
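
To exercise the new, slimmer process_audio outside the Gradio UI, a minimal smoke-test sketch follows. It assumes the five models above loaded successfully and a CUDA GPU is available; the 440 Hz test tone and variable names are illustrative, not part of the commit:

    import numpy as np

    sr = 44100  # matches the 44khz DAC codec tag used above
    t = np.linspace(0, 1.0, sr, endpoint=False)
    wave = (0.5 * np.sin(2 * np.pi * 440.0 * t)).astype(np.float32)  # 1 s test tone

    # Same (sample_rate, array) tuple that gr.Audio(type="numpy") passes in
    (out_sr, out_wave), reply = process_audio((sr, wave))
    print(out_sr, out_wave.shape, repr(reply))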