import os
import gradio as gr
import torch
import torchaudio
import numpy as np
from transformers import pipeline
from diffusers import DiffusionPipeline
from pyannote.audio import Pipeline as PyannotePipeline
from dia.model import Dia
from dac.utils import load_model as load_dac_model

# Retrieve HF token from Secrets
HF_TOKEN = os.environ["HF_TOKEN"]
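# The token must be configured as a Space secret named HF_TOKEN;
# os.environ[...] raises KeyError if it is missing.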

# Automatic multi-GPU sharding across 4× L4 GPUs
device_map = "auto"

# 1. Descript Audio Codec (RVQ)
rvq = load_dac_model(tag="latest", model_type="44khz")
rvq.eval()
if torch.cuda.is_available():
    rvq = rvq.to("cuda")
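# Note: the 44khz DAC model works on mono float32 audio at 44.1 kHz,
# so microphone input is resampled before encoding (see process_audio).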

# 2. Voice Activity Detection via Pyannote
vad_pipe = PyannotePipeline.from_pretrained(
    "pyannote/voice-activity-detection",
    use_auth_token=HF_TOKEN
)
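# pyannote models are gated; the HF token above must belong to an account
# that has accepted the model's access conditions.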

# 3. Ultravox ASR+LLM
ultravox_pipe = pipeline(
    model="fixie-ai/ultravox-v0_4",
    trust_remote_code=True,
    device_map=device_map,
    torch_dtype=torch.float16
)
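# device_map="auto" lets accelerate shard the Ultravox weights across all visible GPUs.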

# 4. Audio Diffusion (direct load via Diffusers)
diff_pipe = DiffusionPipeline.from_pretrained(
    "teticio/audio-diffusion-instrumental-hiphop-256",
    torch_dtype=torch.float16
).to("cuda")
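# Unlike Ultravox, the diffusion pipeline is placed on a single GPU rather than sharded.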

# 5. Dia TTS
# Dia ships its own loader and is not a plain transformers checkpoint, and
# load_checkpoint_and_dispatch expects a local checkpoint path, so the model
# is loaded directly instead of being dispatched with accelerate.
# (compute_dtype="float16" is supported by recent dia releases; drop it on older ones.)
dia = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")

# 6. Inference Function
def process_audio(audio):
    sr, array = audio
    array = array.cpu().numpy() if torch.is_tensor(array) else np.asarray(array)

    # Gradio returns microphone audio as int16; convert to float32 in [-1, 1]
    if np.issubdtype(array.dtype, np.integer):
        array = array.astype(np.float32) / 32768.0
    # Collapse stereo to mono if needed
    if array.ndim > 1:
        array = array.mean(axis=-1)

    waveform = torch.tensor(array, dtype=torch.float32).unsqueeze(0)  # (channels, samples)

    # VAD (pyannote pipelines take a {"waveform", "sample_rate"} dict)
    _ = vad_pipe({"waveform": waveform, "sample_rate": sr})

    # Resample to the codec's native rate (44.1 kHz for the 44khz DAC model)
    dac_sr = rvq.sample_rate
    if sr != dac_sr:
        waveform = torchaudio.functional.resample(waveform, sr, dac_sr)

    # RVQ encode/decode (DAC expects (batch, 1, samples) and decodes from the latents z)
    device = next(rvq.parameters()).device
    with torch.no_grad():
        x = rvq.preprocess(waveform.unsqueeze(0).to(device), dac_sr)
        z, codes, latents, _, _ = rvq.encode(x)
        decoded = rvq.decode(z).squeeze().cpu().numpy()

    # Ultravox inference (the remote-code pipeline takes "audio"/"turns"/"sampling_rate"
    # and returns the generated text)
    ultra_out = ultravox_pipe(
        {"audio": decoded, "turns": [], "sampling_rate": dac_sr},
        max_new_tokens=256,
    )
    text = ultra_out if isinstance(ultra_out, str) else str(ultra_out)

    # Diffusion enhancement (computed but not fed into the TTS output below)
    pros = diff_pipe(raw_audio=decoded)["audios"][0]

    # Dia TTS (Dia conditions on [S1]/[S2] speaker tags and generates 44.1 kHz audio)
    tts = dia.generate(f"[S1] {text}")
    tts = tts.detach().cpu().numpy() if torch.is_tensor(tts) else np.asarray(tts)
    tts = tts.squeeze()
    peak = np.max(np.abs(tts))
    if peak > 0:
        tts = tts / peak * 0.95  # normalize to avoid clipping

    return (44100, tts), text

# 7. Gradio UI
with gr.Blocks(title="Maya AI 📈") as demo:
    gr.Markdown("## Maya-AI: Supernatural Conversational Agent")
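    # Gradio 3.x syntax; on Gradio 4.x, source="microphone" becomes sources=["microphone"]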
    audio_input = gr.Audio(source="microphone", type="numpy", label="Your Voice")
    send_button = gr.Button("Send")
    audio_output = gr.Audio(label="AI’s Response")
    text_output = gr.Textbox(label="Generated Text")
    send_button.click(process_audio, inputs=audio_input, outputs=[audio_output, text_output])

if __name__ == "__main__":
    demo.launch()