# Maya-AI / app.py
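"""Maya-AI demo app.

Pipeline (as wired below): microphone audio -> pyannote voice-activity detection
-> RVQ codec round-trip -> wav2vec2 speech-emotion recognition -> Ultravox
speech-to-text response -> audio diffusion -> Dia TTS, served via Gradio.
"""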
import gradio as gr
import numpy as np
import torch
import torchaudio
import dac  # Descript Audio Codec (assumption; see the RVQ loading note below)
from transformers import pipeline as hf_pipeline, AutoFeatureExtractor, AutoModelForAudioClassification
from diffusers import DiffusionPipeline
from dia.model import Dia
from pyannote.audio import Pipeline as VAD
# Load models
# Ultravox consumes audio and emits text; its model card loads it through a
# trust_remote_code pipeline, not a dedicated transformers class
# (CsmForConditionalGeneration belongs to a different model family).
ultra_pipe = hf_pipeline(model="fixie-ai/ultravox-v0_4", trust_remote_code=True)
# Speech emotion recognition: feature extractor and classifier from the same checkpoint.
ser = AutoFeatureExtractor.from_pretrained("r-f/wav2vec-english-speech-emotion-recognition")
ser_model = AutoModelForAudioClassification.from_pretrained(
    "r-f/wav2vec-english-speech-emotion-recognition"
).to("cuda")
# Audio diffusion is distributed as a diffusers pipeline, not a torch.hub entry point.
diff_pipe = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-instrumental-hiphop-256").to("cuda")
# RVQ codec: torch.hub cannot resolve a Hugging Face repo id such as
# "ibm/DAC.speech.v1.0"; loading the Descript Audio Codec via the `dac`
# package is an assumption standing in for the original call.
rvq = dac.DAC.load(dac.utils.download(model_type="44khz")).to("cuda")
# pyannote pipelines are gated; pass use_auth_token=<HF token> if access is required.
vad = VAD.from_pretrained("pyannote/voice-activity-detection")
dia = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
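# Helper (not in the original): the wav2vec2-based SER model and Ultravox expect
# 16 kHz mono input, while browser microphones usually record at 44.1/48 kHz.
# A minimal sketch, assuming torchaudio is available.
def resample_to_16k(wav_t: torch.Tensor, orig_sr: int) -> torch.Tensor:
    """Resample a (channels, time) float tensor to 16 kHz."""
    if orig_sr == 16000:
        return wav_t
    return torchaudio.functional.resample(wav_t, orig_freq=orig_sr, new_freq=16000)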
def process(audio):
    # Gradio's numpy Audio component yields a (sample_rate, array) tuple, not a dict
    sample_rate, wav = audio
    if wav.ndim > 1:
        wav = wav.mean(axis=1)  # down-mix to mono
    wav = wav.astype(np.float32)
    if np.abs(wav).max() > 1.0:
        wav /= 32768.0  # int16 PCM from the browser -> [-1, 1] float
    wav_t = torch.from_numpy(wav).unsqueeze(0)  # (channel, time)
    wav16 = resample_to_16k(wav_t, sample_rate)
    # VAD expects an in-memory dict with a torch waveform; the result is
    # computed but unused downstream, as in the original
    speech = vad({"waveform": wav_t, "sample_rate": sample_rate})
    # RVQ encode/decode round-trip at the codec's native rate (also unused downstream)
    wav_rvq = torchaudio.functional.resample(wav_t, sample_rate, rvq.sample_rate)
    x = rvq.preprocess(wav_rvq.unsqueeze(0).to("cuda"), rvq.sample_rate)
    z, codes, latents, _, _ = rvq.encode(x)
    dec_audio = rvq.decode(z)
    # Emotion classification on 16 kHz audio
    emo_inputs = ser(wav16.squeeze(0).numpy(), sampling_rate=16000, return_tensors="pt").to("cuda")
    with torch.no_grad():
        emotion_id = ser_model(**emo_inputs).logits.argmax(-1).item()
    emotion = ser_model.config.id2label[emotion_id]  # e.g. "happy"
    # Ultravox generation: the model returns *text* (it has no output_audio mode);
    # this call follows the model-card pipeline usage
    turns = [{"role": "system", "content": "You are a friendly and helpful assistant."}]
    response_text = ultra_pipe(
        {"audio": wav16.squeeze(0).numpy(), "turns": turns, "sampling_rate": 16000},
        max_new_tokens=128,
    )
    # Unconditional audio-diffusion sample (the original fed Ultravox audio output,
    # which does not exist; the result stays unused downstream, as in the original)
    audio_diff = diff_pipe().audios[0]
    # TTS: the original left the text as a placeholder; speaking the Ultravox
    # response here is an assumption. [S1] is documented Dia speaker syntax;
    # [emotion=...] is kept from the original but is not documented Dia syntax.
    text = f"[S1][emotion={emotion}] {response_text}"
    dia_audio = dia.generate(text)
    # Peak-normalize, guarding against silent output
    peak = np.max(np.abs(dia_audio))
    if peak > 0:
        dia_audio = dia_audio / peak * 0.95
    return 44100, dia_audio  # Dia outputs 44.1 kHz audio
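# Minimal smoke test (not in the original): exercises process() with one second
# of silence in Gradio's (sample_rate, array) format before wiring up the UI.
# Left commented out so importing this module does not trigger model inference.
# _sr, _wav = process((16000, np.zeros(16000, dtype=np.float32)))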
with gr.Blocks() as demo:
    state = gr.State([])  # reserved for conversation history (unused, as in the original)
    audio_in = gr.Audio(sources=["microphone"], type="numpy")  # Gradio 4.x uses `sources`
    audio_out = gr.Audio(type="numpy", label="AI response")  # assumption: separate output so the reply does not overwrite the mic recording
    chat = gr.Chatbot()
    record = gr.Button("Record")
    record.click(process, inputs=audio_in, outputs=audio_out).then(
        lambda: [("User", ""), ("AI", "")], outputs=chat  # placeholder transcript, as in the original
    )
# Gradio 4.x: queue() takes default_concurrency_limit; concurrency_limit is per-event
demo.queue(default_concurrency_limit=20, max_size=50).launch()