"""Gradio voice-chat demo: speak into the microphone, get an emotion-aware spoken reply.

Per turn: microphone audio -> mono float32 at 16 kHz -> pyannote voice-activity
detection (keep only the speech regions) -> speech-emotion classification ->
Ultravox (audio-in LLM, answers with text) -> Dia text-to-speech ->
peak-normalized 44.1 kHz waveform played back and logged in the chat panel.

NOTE(review): the RVQ codec and audio-diffusion stages are not wired in. They
were loaded via ``torch.hub`` from Hugging Face repo ids (torch.hub resolves
GitHub repos, so the loads failed), and their outputs were never used.
"""

import gradio as gr
import numpy as np
import torch
from dia.model import Dia
from pyannote.audio import Pipeline as VAD
from transformers import pipeline

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_SR = 16_000  # input rate used for every analysis model below (Ultravox card uses 16 kHz)
DIA_SR = 44_100  # sample rate of the waveform returned by Dia
OUTPUT_PEAK = 0.95  # target peak after normalization; leaves headroom below full scale
MAX_REPLY_TOKENS = 64  # caps the Ultravox reply length (and therefore TTS time)
CONCURRENCY_LIMIT = 20  # NOTE(review): concurrent GPU jobs on shared weights; may need lowering
QUEUE_MAX_SIZE = 50

# ---------------------------------------------------------------------------
# Model loading (runs once at import, as before)
# ---------------------------------------------------------------------------

# Ultravox is a custom (remote-code) architecture, not a CSM model; its own
# pipeline handles feature extraction + generation and returns reply text.
ultravox = pipeline(
    model="fixie-ai/ultravox-v0_4",
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,
)

# Processor and model now come from the SAME checkpoint.
# NOTE(review): confirm this checkpoint ships a sequence-classification head;
# if it does not, the predicted labels are meaningless.
emotion_clf = pipeline(
    "audio-classification",
    model="r-f/wav2vec-english-speech-emotion-recognition",
    device=DEVICE,
)

vad = VAD.from_pretrained("pyannote/voice-activity-detection")
if vad is None:
    # pyannote returns None instead of raising when a gated checkpoint cannot
    # be fetched (e.g. license not accepted / no Hugging Face token).
    raise RuntimeError(
        "Could not load pyannote/voice-activity-detection; "
        "accept its license on Hugging Face and log in with a token."
    )

dia = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")


# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------

def _resample(wav, src_sr, dst_sr):
    """Linearly resample a 1-D float32 waveform from ``src_sr`` to ``dst_sr`` Hz.

    Crude (no anti-aliasing filter) but dependency-free; speech carries little
    energy above 8 kHz, so it is adequate for the 16 kHz analysis models.
    """
    if src_sr == dst_sr or wav.size < 2:
        return wav
    n_out = max(1, int(round(wav.size * dst_sr / src_sr)))
    src_t = np.arange(wav.size) / src_sr
    dst_t = np.arange(n_out) / dst_sr
    return np.interp(dst_t, src_t, wav).astype(np.float32)


def _to_mono_16k(audio):
    """Convert an incoming audio value to a contiguous mono float32 array at MODEL_SR.

    Accepts Gradio's ``type="numpy"`` value ``(sample_rate, ndarray)`` or a
    Hugging Face datasets-style dict ``{"array": ..., "sampling_rate": ...}``
    (the form the original ``process`` indexed). Returns None for no/empty audio.
    Integer PCM (Gradio's microphone gives int16) is scaled into [-1, 1).
    """
    if audio is None:
        return None
    if isinstance(audio, dict):
        sr, data = audio["sampling_rate"], audio["array"]
    else:
        sr, data = audio
    wav = np.asarray(data)
    if wav.size == 0:
        return None
    if np.issubdtype(wav.dtype, np.integer):
        wav = wav.astype(np.float32) / float(np.iinfo(wav.dtype).max + 1)
    wav = wav.astype(np.float32)
    if wav.ndim == 2:
        # Channels are the short axis: (samples, channels) from Gradio,
        # (channels, samples) from multi-channel datasets audio.
        wav = wav.mean(axis=int(np.argmin(wav.shape)))
    elif wav.ndim != 1:
        raise ValueError(f"expected 1-D or 2-D audio, got shape {wav.shape}")
    return np.ascontiguousarray(_resample(wav, int(sr), MODEL_SR), dtype=np.float32)


def _speech_only(wav):
    """Return ``wav`` (MODEL_SR) with non-speech removed by the pyannote VAD; may be empty."""
    waveform = torch.from_numpy(wav).unsqueeze(0)  # pyannote wants (channel, time)
    regions = vad({"waveform": waveform, "sample_rate": MODEL_SR}).get_timeline().support()
    pieces = [wav[int(seg.start * MODEL_SR):int(seg.end * MODEL_SR)] for seg in regions]
    pieces = [piece for piece in pieces if piece.size]
    return np.concatenate(pieces) if pieces else np.zeros(0, dtype=np.float32)


def _peak_normalize(wav, peak=OUTPUT_PEAK):
    """Scale ``wav`` so its largest absolute sample equals ``peak``; silence is returned unchanged."""
    wav = np.asarray(wav, dtype=np.float32)
    current = float(np.max(np.abs(wav))) if wav.size else 0.0
    if current == 0.0:
        return wav  # avoid dividing by zero on silent output
    return wav / current * peak


# ---------------------------------------------------------------------------
# Model steps
# ---------------------------------------------------------------------------

def _detect_emotion(wav):
    """Return the top speech-emotion label for ``wav`` (MODEL_SR), or "unknown"."""
    scores = emotion_clf({"raw": wav, "sampling_rate": MODEL_SR}, top_k=1)
    return scores[0]["label"] if scores else "unknown"


def _generate_reply(wav, emotion):
    """Ask Ultravox for a short text reply to the spoken ``wav``, conditioned on ``emotion``."""
    turns = [
        {
            "role": "system",
            "content": (
                "You are a friendly voice assistant. "
                f"The user sounds {emotion}; reply with matching empathy "
                "in one or two short sentences."
            ),
        },
    ]
    # Per the Ultravox model card the remote-code pipeline returns the decoded
    # reply text. NOTE(review): re-check the return type when upgrading.
    reply = ultravox(
        {"audio": wav, "turns": turns, "sampling_rate": MODEL_SR},
        max_new_tokens=MAX_REPLY_TOKENS,
    )
    return str(reply).strip()


@torch.inference_mode()
def _respond(audio):
    """Run the full pipeline on one utterance.

    Returns ``(emotion, reply_text, (DIA_SR, waveform))``, or None when there
    is no audio, no detected speech, or nothing to say.
    """
    wav = _to_mono_16k(audio)
    if wav is None:
        return None
    speech = _speech_only(wav)
    if speech.size == 0:
        return None
    emotion = _detect_emotion(speech)
    reply = _generate_reply(speech, emotion)
    if not reply:
        return None
    # Dia scripts are speaker-tagged; [S1] marks a single speaker.
    spoken = dia.generate(f"[S1] {reply}")
    if spoken is None:
        return None
    return emotion, reply, (DIA_SR, _peak_normalize(spoken))


def process(audio):
    """Turn one recorded utterance into a spoken reply.

    Args:
        audio: Gradio ``(sample_rate, ndarray)`` tuple or a datasets-style
            ``{"array", "sampling_rate"}`` dict.

    Returns:
        ``(44100, float32 waveform)`` suitable for ``gr.Audio``, or None when
        no speech was detected.
    """
    result = _respond(audio)
    return None if result is None else result[2]


def _on_record(audio, history):
    """Gradio handler: play the reply and append the turn to the chat history."""
    result = _respond(audio)
    if result is None:
        return None, history, history
    emotion, reply, spoken = result
    history = [*history, (f"(voice message, sounds {emotion})", reply)]
    return spoken, history, history


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------

with gr.Blocks() as demo:
    state = gr.State([])  # per-session chat history: list of (user, assistant) pairs
    audio_in = gr.Audio(sources=["microphone"], type="numpy")
    audio_out = gr.Audio(type="numpy", autoplay=True)  # reply goes here, not back into the mic
    chat = gr.Chatbot()
    record = gr.Button("Record")
    record.click(
        _on_record,
        inputs=[audio_in, state],
        outputs=[audio_out, chat, state],
    )

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=CONCURRENCY_LIMIT, max_size=QUEUE_MAX_SIZE).launch()