import gradio as gr
import torch
import numpy as np
import os
from transformers import pipeline, AutoProcessor, CsmForConditionalGeneration
from pyannote.audio import Model, Inference
from dia.model import Dia
from dac.utils import load_model as load_dac_model
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
# Access HF_TOKEN from environment variables (Secrets)
HF_TOKEN = os.environ.get("HF_TOKEN")
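# Note (assumption): pyannote/segmentation is a gated model, so loading it below will fail
# without a valid token whose account has accepted the model's terms of use.
if not HF_TOKEN:
    print("Warning: HF_TOKEN is not set; gated models such as pyannote/segmentation may fail to load.")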
# Device mapping for 4× L4 GPU distribution
device_map = "auto"
print("Loading models...")
# Load Descript Audio Codec (RVQ) at startup
print("Loading RVQ Codec...")
rvq = load_dac_model(tag="latest", model_type="44khz")
rvq.eval()
if torch.cuda.is_available():
    rvq = rvq.to("cuda")
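# DAC usage sketch (following the descript-audio-codec README; kept as a reference, not executed here):
#   x = rvq.preprocess(audio_tensor, sample_rate)   # pad/validate a (batch, 1, time) float tensor
#   z, codes, latents, _, _ = rvq.encode(x)         # encode() returns a 5-tuple, not just codes
#   y = rvq.decode(z)                               # decode() takes the continuous latents z
# The 44khz checkpoint expects 44.1 kHz mono input; other sample rates should be resampled first.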
# Load segmentation model with authentication
print("Loading Segmentation Model...")
seg_model = Model.from_pretrained(
    "pyannote/segmentation",
    use_auth_token=HF_TOKEN
)
# pyannote's Inference expects a torch.device, not a transformers-style device index
seg_inference = Inference(
    seg_model,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
# Use segmentation model for VAD
vad = seg_inference
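# Note: running Inference with the pyannote segmentation model yields frame-level
# speech-activity scores (a SlidingWindowFeature), not discrete speech segments;
# turning the scores into segments would require an additional binarization step.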
# Load Ultravox via generic pipeline (without specifying task)
print("Loading Ultravox...")
ultravox_pipe = pipeline(
    model="fixie-ai/ultravox-v0_4",
    trust_remote_code=True,
    device_map=device_map,
    torch_dtype=torch.float16
)
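# Assumption: per the fixie-ai/ultravox-v0_4 model card, the pipeline expects a dict like
#   {"audio": <float32 array>, "sampling_rate": sr, "turns": [{"role": "system", "content": "..."}]}
# The {"array": ...} call in process_audio below may need to be adapted to that format.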
# Load Diffusion model
print("Loading Diffusion Model...")
diff_pipe = pipeline(
    "audio-to-audio",
    model="teticio/audio-diffusion-instrumental-hiphop-256",
    trust_remote_code=True,
    device_map=device_map,
    torch_dtype=torch.float16
)
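# Caveat: "audio-to-audio" is not a built-in transformers pipeline task, and
# teticio/audio-diffusion-instrumental-hiphop-256 is published as a diffusers checkpoint;
# if this call fails, loading it via diffusers (e.g. AudioDiffusionPipeline) is the likely fix.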
# Load Dia TTS with multi-GPU dispatch
print("Loading Dia TTS...")
with init_empty_weights():
    dia = Dia.from_pretrained(
        "nari-labs/Dia-1.6B",
        torch_dtype=torch.float16,
        trust_remote_code=True
    )
dia = load_checkpoint_and_dispatch(
    dia,
    "nari-labs/Dia-1.6B",
    device_map=device_map,
    dtype=torch.float16
)
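# Caveat: accelerate's load_checkpoint_and_dispatch expects a local checkpoint path or file,
# not a Hub repo id; the weights may need to be fetched first
# (e.g. via huggingface_hub.snapshot_download) before dispatching across the GPUs.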
print("All models loaded successfully!")
# Gradio inference function
def process_audio(audio):
    try:
        if audio is None:
            return None, "No audio input provided"
        sr, array = audio
        # Ensure audio is a float32 mono numpy array (Gradio delivers int16 PCM by default)
        if torch.is_tensor(array):
            array = array.numpy()
        array = np.asarray(array)
        if array.ndim > 1:
            array = array.mean(axis=-1)
        if array.dtype != np.float32:
            array = array.astype(np.float32) / 32768.0
        # VAD segmentation (frame-level speech-activity scores; not used downstream yet)
        segments = vad({"waveform": torch.tensor(array).unsqueeze(0), "sample_rate": sr})
        # RVQ encode/decode round trip (DAC expects a (batch, 1, time) float tensor at 44.1 kHz)
        audio_tensor = torch.tensor(array).unsqueeze(0).unsqueeze(0)
        if torch.cuda.is_available():
            audio_tensor = audio_tensor.to("cuda")
        with torch.no_grad():
            z, codes, latents, _, _ = rvq.encode(audio_tensor)
            decoded = rvq.decode(z)
        array = decoded.squeeze().cpu().numpy()
        # Ultravox ASR→LLM
        ultra_out = ultravox_pipe({"array": array, "sampling_rate": sr})
        if isinstance(ultra_out, dict):
            text = ultra_out.get("text", "I understand your audio input.")
        else:
            text = str(ultra_out)
        # Diffusion-based prosody enhancement (output currently unused downstream)
        prosody_audio = diff_pipe({"array": decoded.cpu().numpy(), "sampling_rate": sr})["array"][0]
        # Dia TTS (generate() may return a tensor or a numpy array depending on version)
        tts_audio = dia.generate(f"[emotion:neutral] {text}")
        if torch.is_tensor(tts_audio):
            tts_np = tts_audio.squeeze().cpu().numpy()
        else:
            tts_np = np.asarray(tts_audio).squeeze()
        # Peak-normalize, guarding against silent output
        peak = np.max(np.abs(tts_np))
        if peak > 0:
            tts_np = tts_np / peak * 0.95
        # Note: Dia's DAC-based decoder nominally outputs 44.1 kHz audio; if the input
        # sample rate differs, the rate returned here may need adjusting.
        return (sr, tts_np), text
    except Exception as e:
        print(f"Error in process_audio: {e}")
        return None, f"Processing error: {str(e)}"
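# Optional smoke test (illustrative only): runs the full pipeline on a synthetic 440 Hz tone
# without the UI. The tone, duration, and sample rate are arbitrary choices for a quick check.
def smoke_test():
    sr = 44100
    t = np.linspace(0, 1, sr, endpoint=False)
    tone = (0.5 * np.sin(2 * np.pi * 440 * t)).astype(np.float32)
    out_audio, out_text = process_audio((sr, tone))
    print("Smoke test output text:", out_text)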
# Gradio UI
with gr.Blocks(title="Maya-AI: Supernatural Speech Agent") as demo:
    gr.Markdown("# Maya-AI: Supernatural Speech Agent")
    gr.Markdown("Record audio to interact with the AI agent that understands emotions and responds naturally.")
    with gr.Row():
        with gr.Column():
            audio_in = gr.Audio(source="microphone", type="numpy", label="Record Your Voice")
            btn = gr.Button("Send", variant="primary")
        with gr.Column():
            audio_out = gr.Audio(label="AI Response")
            txt_out = gr.Textbox(label="Transcribed & Generated Text", lines=3)
    btn.click(fn=process_audio, inputs=audio_in, outputs=[audio_out, txt_out])
if __name__ == "__main__":
    demo.launch()