Devakumar868 committed on
Commit
c5ef34e
·
verified ·
1 Parent(s): ad69d0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -0
app.py CHANGED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile
2
+ import gradio as gr
3
+ from nemo.collections.asr import EncDecRNNTBPEModel
4
+ from speechbrain.pretrained import EncoderClassifier
5
+ from transformers import DiffusionPipeline, AutoModelForCausalLM, AutoTokenizer
6
+ from dia.model import Dia
7
+ import soundfile as sf
8
# Load models once at import time so every request reuses the same instances.
#
# NOTE(review): two of the imports at the top of this file look wrong and
# should be confirmed against the installed packages before deploying:
#   * `DiffusionPipeline` is provided by the `diffusers` package, not
#     `transformers`.
#   * `EncDecRNNTBPEModel` is exposed under `nemo.collections.asr.models`.

# Speech-to-text model.
asr = EncDecRNNTBPEModel.from_pretrained("nvidia/parakeet-tdt-0.6b-v2")

# The IEMOCAP wav2vec2 checkpoint ships its own inference class in
# `custom_interface.py`; its model card loads it through `foreign_class`
# rather than the generic `EncoderClassifier.from_hparams`, whose required
# modules that checkpoint's hyperparameter file does not define.
from speechbrain.pretrained.interfaces import foreign_class

emotion = foreign_class(
    source="speechbrain/emotion-recognition-wav2vec2-IEMOCAP",
    pymodule_file="custom_interface.py",
    classname="CustomEncoderWav2vec2Classifier",
)

# NOTE(review): `diffuser` is not referenced anywhere else in the visible part
# of this file, yet it is moved onto the GPU. Kept so the module-level name
# stays available; consider removing it if nothing else uses it.
diffuser = DiffusionPipeline.from_pretrained("teticio/audio-diffusion-256").to("cuda")

# NOTE(review): "Vicuna-7B" carries no organisation prefix, so it only resolves
# if a local directory of that name exists. Confirm the intended checkpoint
# (for example an `lmsys/...` Hub repository) before deploying.
llm_tokenizer = AutoTokenizer.from_pretrained("Vicuna-7B")
llm = AutoModelForCausalLM.from_pretrained("Vicuna-7B").half().to("cuda")

# Text-to-speech model used to voice the reply.
tts = Dia.from_pretrained("nari-labs/Dia-1.6B")

# FastAPI application that the Gradio UI is mounted onto further down.
app = FastAPI()
17
def _wav_path_from_input(audio_file):
    """Return ``(path, is_temporary)`` for the audio handed to ``process``.

    A string or path-like value is passed through unchanged. A
    ``(sample_rate, samples)`` pair -- the form ``gr.Audio`` delivers when its
    ``type`` is left at the default -- is written to a temporary WAV file so
    the file-based model APIs can read it; the caller must delete that file.
    """
    import os
    import tempfile

    if isinstance(audio_file, (str, os.PathLike)):
        return os.fspath(audio_file), False

    sample_rate, samples = audio_file
    handle, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(handle)
    try:
        sf.write(tmp_path, samples, sample_rate)
    except Exception:
        # Do not leave a half-written temp file behind if encoding fails.
        os.remove(tmp_path)
        raise
    return tmp_path, True


def _first_transcript(asr_output):
    """Extract the transcript string of the first file from ASR output."""
    first = asr_output[0]
    # Depending on the NeMo release, transducer models return either a list of
    # results or a ``(best_hypotheses, all_hypotheses)`` pair of lists.
    if isinstance(first, (list, tuple)):
        first = first[0]
    # Newer releases wrap the text in a hypothesis object with a `.text` field;
    # older ones return the plain string.
    return getattr(first, "text", first)


def _emotion_label(prediction):
    """Extract the predicted label from the emotion classifier's output."""
    if isinstance(prediction, dict):
        return prediction["label"]
    # SpeechBrain's `classify_file` returns
    # ``(out_prob, score, index, text_lab)``; the label list comes last.
    label = prediction[-1]
    if isinstance(label, (list, tuple)):
        label = label[0]
    return label


def process(audio_file):
    """Run the speech -> text -> reply -> speech pipeline for one recording.

    Args:
        audio_file: Path to an audio file, or a ``(sample_rate, samples)``
            pair as produced by a Gradio audio input.

    Returns:
        A 4-tuple ``(transcript, emotion_label, reply_text, reply_wav_path)``.

    Raises:
        ValueError: If no audio was supplied (``audio_file`` is ``None``).
    """
    import os

    if audio_file is None:
        raise ValueError("No audio received; record a clip before submitting.")

    wav_path, is_temporary = _wav_path_from_input(audio_file)
    try:
        # ASR
        text = _first_transcript(asr.transcribe([wav_path]))
        # Emotion
        emo = _emotion_label(emotion.classify_file(wav_path))
    finally:
        if is_temporary:
            os.remove(wav_path)

    # LLM response
    inputs = llm_tokenizer(text, return_tensors="pt").to("cuda")
    resp = llm.generate(**inputs, max_new_tokens=128)
    # A causal LM's `generate` output starts with the prompt tokens; drop them
    # (and special tokens) so the reply is not an echo of the transcript.
    prompt_length = inputs["input_ids"].shape[-1]
    reply = llm_tokenizer.decode(
        resp[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # TTS
    wav = tts.generate(f"[S1] {reply} [S2]")
    # NOTE(review): 44100 is assumed to match the TTS model's output rate.
    sf.write("reply.wav", wav, 44100)
    return text, emo, reply, "reply.wav"
32
+
33
# Gradio UI
# `type="filepath"` makes Gradio hand `process` a path on disk instead of the
# default (sample_rate, samples) tuple, which is what the file-based model
# calls in `process` expect.
# NOTE(review): `source=` is the Gradio 3.x spelling; Gradio 4+ renamed it to
# `sources=["microphone"]`. Confirm against the pinned Gradio version.
iface = gr.Interface(
    fn=process,
    inputs=gr.Audio(source="microphone", type="filepath"),
    outputs=[
        gr.Textbox(label="Transcript"),
        gr.Textbox(label="Emotion"),
        gr.Textbox(label="Reply"),
        gr.Audio(label="Audio Reply"),
    ],
    live=False,
)

# Passing `enable_queue=True` to the constructor is deprecated and has no
# effect; queuing is enabled through `.queue()`. That matters here because
# `launch()` is never called when the UI is mounted on an existing app.
iface.queue()

# Use Gradio's public helper rather than the internal `routes.App.create_app`,
# so Gradio performs its own setup when attaching the UI to the FastAPI app.
app = gr.mount_gradio_app(app, iface, path="/")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)