import gradio as gr
import torch
import numpy as np
import os
from transformers import pipeline
from pyannote.audio import Model, Inference
from dia.model import Dia
from dac.utils import load_model as load_dac_model
from diffusers import DiffusionPipeline
import torchaudio  # already a pyannote.audio dependency; used below for resampling

# Access HF_TOKEN from environment variables (Secrets)
HF_TOKEN = os.environ.get("HF_TOKEN")
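# pyannote/segmentation is a gated checkpoint; warn early if the token is absent
if not HF_TOKEN:
    print("Warning: HF_TOKEN is not set; gated model downloads may fail.")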

# device_map="auto" lets Accelerate shard supported models across all
# visible GPUs (e.g., 4x L4)
device_map = "auto"

print("Loading models...")

# Load Descript Audio Codec (RVQ) at startup
print("Loading RVQ Codec...")
rvq = load_dac_model(tag="latest", model_type="44khz")
rvq.eval()
if torch.cuda.is_available():
    rvq = rvq.to("cuda")

# Load segmentation model with authentication
print("Loading Segmentation Model...")
seg_model = Model.from_pretrained(
    "pyannote/segmentation",
    use_auth_token=HF_TOKEN
)
# Inference expects a torch.device, not an int index (-1 is not a valid device)
seg_inference = Inference(
    seg_model,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

# Use segmentation model for VAD
vad = seg_inference
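
# Minimal helper sketch, assuming pyannote's Inference returns a
# SlidingWindowFeature of per-frame, per-speaker activations; collapses it
# into coarse (start, end) speech regions. The 0.5 threshold is an
# arbitrary choice, not a tuned value.
def speech_regions(scores, threshold=0.5):
    frames = scores.sliding_window                  # maps frame index -> time span
    active = scores.data.max(axis=-1) > threshold   # any speaker active?
    regions, start = [], None
    for i, is_active in enumerate(active):
        if is_active and start is None:
            start = frames[i].start
        elif not is_active and start is not None:
            regions.append((start, frames[i].start))
            start = None
    if start is not None:
        regions.append((start, frames[len(active) - 1].end))
    return regions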

# Load Ultravox via its custom pipeline (trust_remote_code registers the task)
print("Loading Ultravox...")
ultravox_pipe = pipeline(
    model="fixie-ai/ultravox-v0_4",
    trust_remote_code=True,
    device_map=device_map,
    torch_dtype=torch.float16
)
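
# Per the Ultravox model card, the pipeline takes a dict with 'audio'
# (mono float at 16 kHz), 'sampling_rate', and chat-style 'turns'. The
# system prompt below is an illustrative choice, not from the model card.
ULTRAVOX_TURNS = [
    {"role": "system", "content": "You are a friendly voice assistant. Reply briefly and naturally."},
]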

# Load the audio-diffusion model. "audio-to-audio" is not a transformers
# pipeline task, and teticio/audio-diffusion-* checkpoints are diffusers
# models, so load through diffusers instead (requires a diffusers release
# that still ships AudioDiffusionPipeline).
print("Loading Diffusion Model...")
diff_pipe = DiffusionPipeline.from_pretrained(
    "teticio/audio-diffusion-instrumental-hiphop-256",
    torch_dtype=torch.float16,
)
if torch.cuda.is_available():
    diff_pipe = diff_pipe.to("cuda")

# Load Dia TTS. Dia.from_pretrained handles weight loading and device
# placement itself (it is a wrapper class, not a plain nn.Module), so
# accelerate's empty-weights dispatch does not apply here.
print("Loading Dia TTS...")
dia = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")

print("All models loaded successfully!")

# Gradio inference function
def process_audio(audio):
    try:
        if audio is None:
            return None, "No audio input provided"

        # gr.Audio(type="numpy") yields (sample_rate, np.ndarray)
        sr, array = audio
        array = to_mono_float32(array)
        waveform = torch.tensor(array).unsqueeze(0)  # (channel, time)

        # VAD segmentation, collapsed to speech regions by the helper above
        # (the result is informational in this demo)
        scores = vad({"waveform": waveform, "sample_rate": sr})
        regions = speech_regions(scores)
        print(f"Detected {len(regions)} speech region(s)")

        # RVQ encode/decode round-trip. DAC's preprocess() asserts the codec's
        # native sample rate (44.1 kHz), so resample first.
        if sr != rvq.sample_rate:
            waveform = torchaudio.functional.resample(waveform, sr, rvq.sample_rate)
        audio_tensor = waveform.unsqueeze(0)  # (batch, channel, time)
        if torch.cuda.is_available():
            audio_tensor = audio_tensor.to("cuda")
        with torch.no_grad():
            x = rvq.preprocess(audio_tensor, rvq.sample_rate)
            z, codes, latents, _, _ = rvq.encode(x)
            decoded = rvq.decode(z)
        decoded_np = decoded.squeeze().cpu().numpy()

        # Ultravox ASR -> LLM; the model card specifies 16 kHz mono input
        ultra_audio = torchaudio.functional.resample(
            torch.tensor(array).unsqueeze(0), sr, 16000
        ).squeeze().numpy()
        ultra_out = ultravox_pipe(
            {"audio": ultra_audio, "sampling_rate": 16000, "turns": ULTRAVOX_TURNS},
            max_new_tokens=128,
        )
        text = ultra_out if isinstance(ultra_out, str) else str(ultra_out)

        # Diffusion-based prosody enhancement. The pipeline's mel stage expects
        # its own sample rate, so resample the codec output; as in the original
        # flow, the result is not fed further downstream.
        mel_sr = diff_pipe.mel.get_sample_rate()
        prosody_in = torchaudio.functional.resample(
            torch.tensor(decoded_np).unsqueeze(0), rvq.sample_rate, mel_sr
        ).squeeze().numpy()
        prosody_audio = diff_pipe(raw_audio=prosody_in).audios[0]

        # Dia TTS; Dia expects speaker tags such as [S1]/[S2] (not emotion
        # tags) and returns a numpy waveform at 44.1 kHz
        tts_audio = dia.generate(f"[S1] {text}")
        tts_np = np.asarray(tts_audio).squeeze()

        # Peak-normalize, guarding against all-zero output
        peak = np.max(np.abs(tts_np))
        if peak > 0:
            tts_np = tts_np / peak * 0.95

        # Return at Dia's output rate, not the microphone's input rate
        return (44100, tts_np), text

    except Exception as e:
        print(f"Error in process_audio: {e}")
        return None, f"Processing error: {str(e)}"

# Gradio UI
with gr.Blocks(title="Maya-AI: Supernatural Speech Agent") as demo:
    gr.Markdown("# Maya-AI: Supernatural Speech Agent")
    gr.Markdown("Record audio to interact with the AI agent that understands emotions and responds naturally.")
    
    with gr.Row():
        with gr.Column():
            # Gradio 4.x renamed 'source' to 'sources' (a list)
            audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Record Your Voice")
            btn = gr.Button("Send", variant="primary")
        
        with gr.Column():
            audio_out = gr.Audio(label="AI Response")
            txt_out = gr.Textbox(label="Transcribed & Generated Text", lines=3)
    
    btn.click(fn=process_audio, inputs=audio_in, outputs=[audio_out, txt_out])
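
# Each request runs several large models and can exceed default HTTP timeouts;
# Gradio's built-in request queue (demo.queue()) keeps concurrent users from
# timing out.
demo.queue()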

if __name__ == "__main__":
    demo.launch()