Devakumar868 committed
Commit 653911d · verified · 1 Parent(s): 72d9597

Update app.py

Files changed (1)
  1. app.py +135 -39
app.py CHANGED
@@ -1,46 +1,142 @@
  import gradio as gr
- from transformers import AutoProcessor, CsmForConditionalGeneration
- from dia.model import Dia
+ from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
+ import torch
+ import numpy as np
  from pyannote.audio import Pipeline as VAD
- import torch, numpy as np
-
- # Load models
- ultra_proc = AutoProcessor.from_pretrained("fixie-ai/ultravox-v0_4")
- ultra_model = CsmForConditionalGeneration.from_pretrained("fixie-ai/ultravox-v0_4", device_map="auto", torch_dtype=torch.float16)
- ser = AutoProcessor.from_pretrained("r-f/wav2vec-english-speech-emotion-recognition")
- ser_model = torch.hub.load("jonatasgrosman/wav2vec2-large-xlsr-53-english", "wav2vec2_large_xlsr", pretrained=True).to("cuda")
- diff_pipe = torch.hub.load("teticio/audio-diffusion-instrumental-hiphop-256", "audio_diffusion").to("cuda")
- rvq = torch.hub.load("ibm/DAC.speech.v1.0", "DAC_speech_v1_0").to("cuda")
- vad = VAD.from_pretrained("pyannote/voice-activity-detection")
- dia = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
-
- def process(audio):
-     # VAD
-     speech = vad({"waveform": audio["array"], "sample_rate": audio["sampling_rate"]})
-     # RVQ encode/decode
-     codes = rvq.encode(audio["array"])
-     dec_audio = rvq.decode(codes)
-     # Emotion
-     emo_inputs = ser(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt")
-     emotion = ser_model(**emo_inputs).logits.argmax(-1).item()
-     # Ultravox generation
-     inputs = ultra_proc(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").to("cuda")
-     speech_out = ultra_model.generate(**inputs, output_audio=True)
-     # Diffuse and clone voice
-     audio_diff = diff_pipe(speech_out.audio).audios[0]
-     # TTS
-     text = f"[S1][emotion={emotion}]" + " ".join(["..."])  # placeholder
-     dia_audio = dia.generate(text)
-     # Normalize
-     dia_audio = dia_audio / np.max(np.abs(dia_audio)) * 0.95
-     return 44100, dia_audio
-
- with gr.Blocks() as demo:
-     state = gr.State([])
-     audio_in = gr.Audio(source="microphone", type="numpy")
-     chat = gr.Chatbot()
-     record = gr.Button("Record")
-     record.click(process, inputs=audio_in, outputs=[audio_in]).then(
-         lambda a: chat.update(value=[("User", ""), ("AI", "")]),
+ import dac
+
+ # Load models with proper error handling
+ def load_models():
+     try:
+         # Ultravox via transformers (no separate package needed)
+         ultra_proc = AutoProcessor.from_pretrained("fixie-ai/ultravox-v0_4", trust_remote_code=True)
+         ultra_model = AutoModelForCausalLM.from_pretrained("fixie-ai/ultravox-v0_4", device_map="auto", torch_dtype=torch.float16, trust_remote_code=True)
+
+         # Speech emotion recognition via transformers pipeline
+         emotion_pipeline = pipeline("audio-classification", model="ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", device=0 if torch.cuda.is_available() else -1)
+
+         # Audio diffusion (using diffusers instead of torch.hub for HF compatibility)
+         from diffusers import DiffusionPipeline
+         diff_pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-instrumental-hiphop-256")
+
+         # Descript Audio Codec
+         from dac.utils import load_model as load_dac_model
+         rvq = load_dac_model(tag="latest", model_type="44khz")
+         rvq.eval()
+         if torch.cuda.is_available():
+             rvq = rvq.to("cuda")
+
+         # VAD
+         vad = VAD.from_pretrained("pyannote/voice-activity-detection")
+
+         # Dia TTS
+         from dia.model import Dia
+         dia = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
+
+         return ultra_proc, ultra_model, emotion_pipeline, diff_pipe, rvq, vad, dia
+
+     except Exception as e:
+         print(f"Error loading models: {e}")
+         return None, None, None, None, None, None, None
+
+ # Initialize models
+ ultra_proc, ultra_model, emotion_pipeline, diff_pipe, rvq, vad, dia = load_models()
+
+ def process_audio(audio):
+     try:
+         if audio is None:
+             return None, "No audio input provided"
+
+         # Convert audio to proper format
+         audio_array = audio[1] if isinstance(audio, tuple) else audio["array"]
+         sample_rate = audio[0] if isinstance(audio, tuple) else audio["sampling_rate"]
+
+         # Ensure audio is numpy array
+         if torch.is_tensor(audio_array):
+             audio_array = audio_array.numpy()
+
+         # VAD processing
+         if vad is not None:
+             speech_segments = vad({"waveform": torch.from_numpy(audio_array).unsqueeze(0), "sample_rate": sample_rate})
+
+         # Emotion recognition
+         emotion_result = "neutral"
+         if emotion_pipeline is not None:
+             try:
+                 emotion_pred = emotion_pipeline(audio_array, sampling_rate=sample_rate)
+                 emotion_result = emotion_pred[0]["label"] if emotion_pred else "neutral"
+             except:
+                 emotion_result = "neutral"
+
+         # RVQ encode/decode
+         if rvq is not None:
+             try:
+                 audio_tensor = torch.from_numpy(audio_array).float().unsqueeze(0)
+                 if torch.cuda.is_available():
+                     audio_tensor = audio_tensor.to("cuda")
+                 encoded = rvq.encode(audio_tensor)
+                 decoded_audio = rvq.decode(encoded)
+                 if torch.cuda.is_available():
+                     decoded_audio = decoded_audio.cpu()
+                 audio_array = decoded_audio.squeeze().numpy()
+             except Exception as e:
+                 print(f"RVQ processing error: {e}")
+
+         # Ultravox generation
+         response_text = "I understand your audio input."
+         if ultra_proc is not None and ultra_model is not None:
+             try:
+                 inputs = ultra_proc(audio_array, sampling_rate=sample_rate, return_tensors="pt")
+                 if torch.cuda.is_available():
+                     inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+                 with torch.no_grad():
+                     outputs = ultra_model.generate(**inputs, max_new_tokens=50)
+                 response_text = ultra_proc.decode(outputs[0], skip_special_tokens=True)
+             except Exception as e:
+                 print(f"Ultravox generation error: {e}")
+                 response_text = f"Detected emotion: {emotion_result}"
+
+         # TTS generation
+         output_audio = None
+         if dia is not None:
+             try:
+                 tts_text = f"[emotion:{emotion_result}] {response_text}"
+                 output_audio = dia.generate(tts_text)
+                 if torch.is_tensor(output_audio):
+                     output_audio = output_audio.cpu().numpy()
+                 # Normalize audio
+                 if output_audio is not None:
+                     output_audio = output_audio / np.max(np.abs(output_audio)) * 0.95
+             except Exception as e:
+                 print(f"TTS generation error: {e}")
+
+         return (sample_rate, output_audio) if output_audio is not None else None, response_text
+
+     except Exception as e:
+         return None, f"Processing error: {str(e)}"
+
+ # Create Gradio interface
+ with gr.Blocks(title="Supernatural Speech AI") as demo:
+     gr.Markdown("# Supernatural Speech AI Agent")
+     gr.Markdown("Record audio to interact with the AI agent that understands emotions and responds naturally.")
+
+     with gr.Row():
+         with gr.Column():
+             audio_input = gr.Audio(source="microphone", type="numpy", label="Record Audio")
+             process_btn = gr.Button("Process Audio", variant="primary")
+
+         with gr.Column():
+             audio_output = gr.Audio(label="AI Response")
+             text_output = gr.Textbox(label="Response Text", lines=3)
+
+     conversation_history = gr.State([])
+
+     process_btn.click(
+         fn=process_audio,
+         inputs=[audio_input],
+         outputs=[audio_output, text_output]
      )
+
+ if __name__ == "__main__":
  demo.queue(concurrency_limit=20, max_size=50).launch()
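
A note on the `audio` payload that `process_audio` unpacks: with `type="numpy"`, Gradio delivers microphone input as a `(sample_rate, np.ndarray)` tuple, and the array is typically int16 and may be stereo. A minimal sketch of the normalization the handler implicitly assumes (the helper name `to_float_mono` is illustrative, not part of the commit):

```python
import numpy as np

def to_float_mono(audio):
    """Convert Gradio's (sample_rate, np.ndarray) payload to mono float32 in [-1, 1]."""
    sample_rate, data = audio
    if data.ndim == 2:                      # (samples, channels) -> average to mono
        data = data.mean(axis=1)
    if np.issubdtype(data.dtype, np.integer):
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    return sample_rate, data.astype(np.float32)
```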
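The pyannote pipeline used for `speech_segments` returns an `Annotation`, not audio. A hedged sketch of how the detected speech regions are usually read back, assuming a recent `pyannote.audio` release (the gated `pyannote/voice-activity-detection` checkpoint may additionally require `use_auth_token`):

```python
import torch
from pyannote.audio import Pipeline

vad = Pipeline.from_pretrained("pyannote/voice-activity-detection")  # may need use_auth_token=...

waveform = torch.zeros(1, 16000, dtype=torch.float32)  # (channel, time)
annotation = vad({"waveform": waveform, "sample_rate": 16000})

# Each segment carries start/end times in seconds.
for speech in annotation.get_timeline().support():
    print(f"speech: {speech.start:.2f}s -> {speech.end:.2f}s")
```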
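Similarly, `process_audio` feeds `rvq.encode(...)` straight into `rvq.decode(...)`. In the `descript-audio-codec` package the commit imports, `encode` is generally expected to take a `[batch, 1, samples]` tensor at the model's sample rate and return a tuple, with `decode` consuming only the continuous latents. A minimal sketch of the round trip under those assumptions:

```python
import torch
from dac.utils import load_model

model = load_model(tag="latest", model_type="44khz")  # same loader the commit uses
model.eval()

audio = torch.randn(1, 1, 44100)                 # [batch, channels, samples] at 44.1 kHz
x = model.preprocess(audio, sample_rate=44100)   # pad to the codec's hop length
with torch.no_grad():
    z, codes, latents, _, _ = model.encode(x)    # encode returns (z, codes, latents, ...)
    reconstructed = model.decode(z)              # decode takes the quantized latents z
```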