Devakumar868 committed on
Commit 1a24747 · verified · 1 Parent(s): 149c25d

Update app.py

Files changed (1)
  1. app.py +71 -131
app.py CHANGED
@@ -1,142 +1,82 @@
  import gradio as gr
- from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
  import torch
  import numpy as np
  from pyannote.audio import Pipeline as VAD
- import dac

- # Load models with proper error handling
- def load_models():
-     try:
-         # Ultravox via transformers (no separate package needed)
-         ultra_proc = AutoProcessor.from_pretrained("fixie-ai/ultravox-v0_4", trust_remote_code=True)
-         ultra_model = AutoModelForCausalLM.from_pretrained("fixie-ai/ultravox-v0_4", device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
-
-         # Speech emotion recognition via transformers pipeline
-         emotion_pipeline = pipeline("audio-classification", model="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", device=0 if torch.cuda.is_available() else -1)
-
-         # Audio diffusion (using transformers instead of torch.hub for HF compatibility)
-         from diffusers import DiffusionPipeline
-         diff_pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-instrumental-hiphop-256")
-
-         # Descript Audio Codec
-         from dac.utils import load_model as load_dac_model
-         rvq = load_dac_model(tag="latest", model_type="44khz")
-         rvq.eval()
-         if torch.cuda.is_available():
-             rvq = rvq.to("cuda")
-
-         # VAD
-         vad = VAD.from_pretrained("pyannote/voice-activity-detection")
-
-         # Dia TTS
-         from dia.model import Dia
-         dia = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
-
-         return ultra_proc, ultra_model, emotion_pipeline, diff_pipe, rvq, vad, dia
-
-     except Exception as e:
-         print(f"Error loading models: {e}")
-         return None, None, None, None, None, None, None

- # Initialize models
- ultra_proc, ultra_model, emotion_pipeline, diff_pipe, rvq, vad, dia = load_models()

  def process_audio(audio):
-     try:
-         if audio is None:
-             return None, "No audio input provided"
-
-         # Convert audio to proper format
-         audio_array = audio[1] if isinstance(audio, tuple) else audio["array"]
-         sample_rate = audio[0] if isinstance(audio, tuple) else audio["sampling_rate"]
-
-         # Ensure audio is numpy array
-         if torch.is_tensor(audio_array):
-             audio_array = audio_array.numpy()
-
-         # VAD processing
-         if vad is not None:
-             speech_segments = vad({"waveform": torch.from_numpy(audio_array).unsqueeze(0), "sample_rate": sample_rate})
-
-         # Emotion recognition
-         emotion_result = "neutral"
-         if emotion_pipeline is not None:
-             try:
-                 emotion_pred = emotion_pipeline(audio_array, sampling_rate=sample_rate)
-                 emotion_result = emotion_pred[0]["label"] if emotion_pred else "neutral"
-             except:
-                 emotion_result = "neutral"
-
-         # RVQ encode/decode
-         if rvq is not None:
-             try:
-                 audio_tensor = torch.from_numpy(audio_array).float().unsqueeze(0)
-                 if torch.cuda.is_available():
-                     audio_tensor = audio_tensor.to("cuda")
-                 encoded = rvq.encode(audio_tensor)
-                 decoded_audio = rvq.decode(encoded)
-                 if torch.cuda.is_available():
-                     decoded_audio = decoded_audio.cpu()
-                 audio_array = decoded_audio.squeeze().numpy()
-             except Exception as e:
-                 print(f"RVQ processing error: {e}")
-
-         # Ultravox generation
-         response_text = "I understand your audio input."
-         if ultra_proc is not None and ultra_model is not None:
-             try:
-                 inputs = ultra_proc(audio_array, sampling_rate=sample_rate, return_tensors="pt")
-                 if torch.cuda.is_available():
-                     inputs = {k: v.to("cuda") for k, v in inputs.items()}
-
-                 with torch.no_grad():
-                     outputs = ultra_model.generate(**inputs, max_new_tokens=50)
-                 response_text = ultra_proc.decode(outputs[0], skip_special_tokens=True)
-             except Exception as e:
-                 print(f"Ultravox generation error: {e}")
-                 response_text = f"Detected emotion: {emotion_result}"
-
-         # TTS generation
-         output_audio = None
-         if dia is not None:
-             try:
-                 tts_text = f"[emotion:{emotion_result}] {response_text}"
-                 output_audio = dia.generate(tts_text)
-                 if torch.is_tensor(output_audio):
-                     output_audio = output_audio.cpu().numpy()
-                 # Normalize audio
-                 if output_audio is not None:
-                     output_audio = output_audio / np.max(np.abs(output_audio)) * 0.95
-             except Exception as e:
-                 print(f"TTS generation error: {e}")
-
-         return (sample_rate, output_audio) if output_audio is not None else None, response_text
-
-     except Exception as e:
-         return None, f"Processing error: {str(e)}"

- # Create Gradio interface
- with gr.Blocks(title="Supernatural Speech AI") as demo:
-     gr.Markdown("# Supernatural Speech AI Agent")
-     gr.Markdown("Record audio to interact with the AI agent that understands emotions and responds naturally.")
-
-     with gr.Row():
-         with gr.Column():
-             audio_input = gr.Audio(sources=["microphone"], type="numpy", label="Record Audio")
-             process_btn = gr.Button("Process Audio", variant="primary")
-
-         with gr.Column():
-             audio_output = gr.Audio(label="AI Response")
-             text_output = gr.Textbox(label="Response Text", lines=3)
-
-     conversation_history = gr.State([])
-
-     process_btn.click(
-         fn=process_audio,
-         inputs=[audio_input],
-         outputs=[audio_output, text_output]
-     )

  if __name__ == "__main__":
-     demo.queue(concurrency_limit=20, max_size=50).launch()
 
  import gradio as gr
  import torch
  import numpy as np
+ from transformers import pipeline, AutoProcessor, CsmForConditionalGeneration
  from pyannote.audio import Pipeline as VAD
+ from dia.model import Dia
+ from dac.utils import load_model as load_dac_model
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch

+ # 2.1: Initialize Accelerator for 4×L4 GPU distribution
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
+ device_map = "auto"  # accelerate automatically shards across 4 L4 GPUs[2]

+ # 2.2: Load Descript Audio Codec (RVQ) at startup
+ rvq = load_dac_model(tag="latest", model_type="44khz")
+ rvq.eval()
+ if torch.cuda.is_available():
+     rvq = rvq.to("cuda")

+ # 2.3: Load VAD pipeline
+ vad = VAD.from_pretrained("pyannote/voice-activity-detection")
+
+ # 2.4: Load Ultravox via audio-text-to-text pipeline
+ ultravox_pipe = pipeline(
+     "audio-text-to-text",
+     model="fixie-ai/ultravox-v0_4",
+     trust_remote_code=True,
+     device_map=device_map,
+     torch_dtype=torch.float16
+ )
+
+ # 2.5: Load Diffusion model
+ diff_pipe = pipeline(
+     "audio-to-audio",
+     model="teticio/audio-diffusion-instrumental-hiphop-256",
+     trust_remote_code=True,
+     device_map=device_map,
+     torch_dtype=torch.float16
+ )
+
+ # 2.6: Load Dia TTS with multi-GPU dispatch
+ with init_empty_weights():
+     dia = Dia.from_pretrained("nari-labs/Dia-1.6B", torch_dtype=torch.float16, trust_remote_code=True)
+ dia = load_checkpoint_and_dispatch(
+     dia, "nari-labs/Dia-1.6B", device_map=device_map, dtype=torch.float16
+ )
+
+ # 2.7: Gradio inference function
  def process_audio(audio):
+     sr, array = audio["sampling_rate"], audio["array"]
+     # VAD segmentation
+     speech_segments = vad({"waveform": torch.tensor(array).unsqueeze(0), "sample_rate": sr})
+     # RVQ encode/decode
+     audio_tensor = torch.tensor(array).unsqueeze(0)
+     if torch.cuda.is_available():
+         audio_tensor = audio_tensor.to("cuda")
+     codes = rvq.encode(audio_tensor)
+     decoded = rvq.decode(codes)
+     array = decoded.squeeze().cpu().numpy()
+     # Ultravox ASR→LLM
+     ultra_out = ultravox_pipe({"array": array, "sampling_rate": sr})
+     text = ultra_out["text"]
+     # Diffusion-based prosody enhancement
+     prosody_audio = diff_pipe({"array": decoded.cpu().numpy(), "sampling_rate": sr})["array"][0]
+     # Dia TTS
+     tts_audio = dia.generate(f"[emotion:neutral] {text}")
+     tts_np = tts_audio.squeeze().cpu().numpy()
+     # Normalize
+     tts_np = tts_np / np.max(np.abs(tts_np)) * 0.95
+     return (sr, tts_np), text

+ # 2.8: Gradio UI
+ with gr.Blocks() as demo:
+     gr.Markdown("## Supernatural Speech AI Agent")
+     audio_in = gr.Audio(source="microphone", type="numpy", label="Record Your Voice")
+     btn = gr.Button("Send")
+     audio_out = gr.Audio(label="AI Response")
+     txt_out = gr.Textbox(label="Transcribed & Generated Text")
+     btn.click(fn=process_audio, inputs=audio_in, outputs=[audio_out, txt_out])

  if __name__ == "__main__":
+     demo.launch()
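
For quick local testing of the new pipeline, one detail worth noting: with type="numpy", gr.Audio hands the callback a (sample_rate, data) tuple (the shape the removed process_audio handled explicitly), whereas the new process_audio indexes audio["sampling_rate"] and audio["array"]. Below is a minimal adapter sketch under that assumption, reusing process_audio as defined in the new app.py and the plural sources=["microphone"] keyword from the removed UI code; the on_audio wrapper name and the int16-to-float32 scaling are illustrative, not part of this commit.

import gradio as gr
import numpy as np

# Hypothetical adapter (illustrative, not part of the commit): convert the
# (sample_rate, data) tuple that gr.Audio(type="numpy") delivers into the
# {"array": ..., "sampling_rate": ...} dict that process_audio above indexes.
def on_audio(audio):
    if audio is None:
        return None, "No audio input provided"
    sr, data = audio
    # Microphone recordings usually arrive as int16 PCM; scale to float32 in [-1, 1].
    if np.issubdtype(data.dtype, np.integer):
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    return process_audio({"array": data, "sampling_rate": sr})

with gr.Blocks() as demo:
    gr.Markdown("## Supernatural Speech AI Agent")
    # sources=["microphone"] (plural) is the keyword the removed UI code used.
    audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Record Your Voice")
    btn = gr.Button("Send")
    audio_out = gr.Audio(label="AI Response")
    txt_out = gr.Textbox(label="Transcribed & Generated Text")
    btn.click(fn=on_audio, inputs=audio_in, outputs=[audio_out, txt_out])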