Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Third-party imports for the voice-assistant pipeline
# (ASR -> emotion -> LLM -> TTS, served via FastAPI + Gradio).
from fastapi import FastAPI, UploadFile
import gradio as gr
import soundfile as sf

from nemo.collections.asr import EncDecRNNTBPEModel
from speechbrain.pretrained import EncoderClassifier
# BUG FIX: DiffusionPipeline is provided by the `diffusers` package, not
# `transformers`. Importing it from transformers raises ImportError at
# startup, which is the Space's "Runtime error".
from diffusers import DiffusionPipeline
from transformers import AutoModelForCausalLM, AutoTokenizer
from dia.model import Dia
import torch

# Pick the device once so the app also starts on CPU-only hardware
# instead of crashing with CUDA errors on `.to("cuda")`.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load models (all resolved from the Hugging Face Hub at startup).
asr = EncDecRNNTBPEModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")
emotion = EncoderClassifier.from_hparams(source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP")
diffuser = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-256").to(DEVICE)
# BUG FIX: "Vicuna-7B" is not a valid Hub repo id, so from_pretrained
# cannot resolve it; use the canonical lmsys checkpoint.
llm_tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
llm = AutoModelForCausalLM.from_pretrained("lmsys/vicuna-7b-v1.5")
if DEVICE == "cuda":
    # fp16 only on GPU: half precision on CPU is slow and poorly supported.
    llm = llm.half()
llm = llm.to(DEVICE)
tts = Dia.from_pretrained("nari-labs/Dia-1.6B")

app = FastAPI()
def process(audio_file):
    """Run the full voice-assistant pipeline on one recorded clip.

    Parameters
    ----------
    audio_file : str
        Path to the recorded audio file (the Gradio input must be
        configured with type="filepath" for this to hold).

    Returns
    -------
    tuple
        (transcript, emotion label, LLM reply text, path to reply wav).
    """
    # ASR: transcribe() takes a list of file paths and returns a
    # parallel list of transcripts.
    text = asr.transcribe([audio_file])[0]
    # Emotion recognition on the same clip.
    # NOTE(review): recent SpeechBrain versions return a tuple
    # (out_prob, score, index, text_lab) from classify_file — verify
    # that the ["label"] indexing matches the installed version.
    emo = emotion.classify_file(audio_file)["label"]
    # LLM reply. Send inputs to wherever the model actually lives
    # instead of a hard-coded "cuda", so CPU-only hosts work too.
    inputs = llm_tokenizer(text, return_tensors="pt").to(llm.device)
    resp = llm.generate(**inputs, max_new_tokens=128)
    # BUG FIX: decode only the newly generated tokens and skip special
    # tokens — the original echoed the prompt plus </s> markers back
    # into the reply.
    prompt_len = inputs["input_ids"].shape[1]
    reply = llm_tokenizer.decode(resp[0][prompt_len:], skip_special_tokens=True)
    # TTS: Dia expects speaker-tagged text.
    wav = tts.generate(f"[S1] {reply} [S2]")
    sf.write("reply.wav", wav, 44100)
    return text, emo, reply, "reply.wav"
# Gradio UI.
# BUG FIXES for gradio 4.x:
#  * Audio(source=...) was renamed to sources=[...] — the old keyword
#    raises TypeError at startup.
#  * type="filepath" is required: process() treats its input as a path
#    (the default numpy (sr, data) tuple would crash asr.transcribe).
#  * enable_queue= was removed from Interface; call .queue() instead.
iface = gr.Interface(
    fn=process,
    inputs=gr.Audio(sources=["microphone"], type="filepath"),
    outputs=[
        gr.Textbox(label="Transcript"),
        gr.Textbox(label="Emotion"),
        gr.Textbox(label="Reply"),
        gr.Audio(label="Audio Reply"),
    ],
    live=False,
)
iface.queue()
# Mount the Gradio UI onto the FastAPI app.
# BUG FIX: gr.routes.App.create_app is a private API that was removed;
# the supported way to serve Gradio under FastAPI is mount_gradio_app,
# which returns the (modified) FastAPI app to serve.
app = gr.mount_gradio_app(app, iface, path="/")

if __name__ == "__main__":
    import uvicorn

    # 7860 is the port Hugging Face Spaces expects the app to bind.
    uvicorn.run(app, host="0.0.0.0", port=7860)