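# Gradio app: a voice/text chatbot over a vehicle owner's manual. Speech is
# transcribed with Whisper, an LLM "router" decides whether the manual is
# needed, relevant passages come from a FAISS index, and the reply is rendered
# as a talking-head video via the ex-human (api.exh.ai) service.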
import json
from datetime import datetime

import gradio as gr
import requests
import whisper
from langchain import LLMChain, PromptTemplate
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import OpenAIChat
from langchain.schema import AIMessage, HumanMessage
from langchain.vectorstores import FAISS
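
# Wires together Whisper speech-to-text, a FAISS index over the owner's manual,
# and several LangChain chains: a Yes/No router, a keyword extractor, a
# retrieval-QA chain, and a plain conversation chain for small talk.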
class ChatbotClass:
    def __init__(self):
        FOLDER_PATH = './data/eqe-manual'
        self.whisper_model = whisper.load_model(name='tiny')
        self.embeddings = HuggingFaceEmbeddings()
        self.index = FAISS.load_local(
            folder_path=FOLDER_PATH, embeddings=self.embeddings
        )
        self.llm = OpenAIChat(temperature=0)
        self.memory = ConversationBufferMemory(
            memory_key="chat_history", input_key="human_input", return_messages=True
        )
        self.keyword_chain = self.init_keyword_chain()
        self.context_chain = self.init_context_chain()
        self.document_retrieval_chain = self.init_document_retrieval()
        self.conversation_chain = self.init_conversation()
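
    # Convert the stored LangChain messages into (user, bot) tuples,
    # the format Gradio's Chatbot component expects.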
    def format_history(self, memory):
        history = memory.chat_memory.messages
        formatted_history = []
        user_response = None
        for h in history:
            if isinstance(h, HumanMessage):
                user_response = h.content
            elif isinstance(h, AIMessage):
                # Only append once the AI reply arrives, so each human
                # message is paired with its response.
                formatted_history.append((user_response, h.content))
        return formatted_history
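
    # Few-shot router: asks the LLM whether a request needs vehicle-specific
    # context from the manual, answering only "Yes" or "No".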
    def init_document_retrieval(self):
        retrieve_documents_template = """This function retrieves excerpts from a Vehicle Owner's Manual. The function is useful for adding vehicle-specific context to answer questions. Based on a request, determine if vehicle-specific information is needed. Respond with "Yes" or "No". If the answer is both, respond with "Yes":
request: How do I change the tire?
response: Yes
request: Hello
response: No
request: I was in an accident. What should I do?
response: Yes
request: {request}
response:"""
        prompt_template = PromptTemplate(
            input_variables=["request"],
            template=retrieve_documents_template
        )
        document_retrieval_chain = LLMChain(
            llm=self.llm, prompt=prompt_template, verbose=True
        )
        return document_retrieval_chain
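
    # Turns the user's question into search keywords for the similarity search.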
    def init_keyword_chain(self):
        keyword_template = """You are a vehicle owner searching for content in your vehicle's owner manual. Your job is to come up with keywords to use when searching inside your manual, based on a question you have.
Question: {question}
Keywords:"""
        prompt_template = PromptTemplate(
            input_variables=["question"], template=keyword_template
        )
        keyword_chain = LLMChain(
            llm=self.llm, prompt=prompt_template, verbose=True
        )
        return keyword_chain
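
    # "Stuff" QA chain: answers from retrieved manual excerpts, sharing its
    # conversation memory with the plain conversation chain below.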
    def init_context_chain(self):
        context_template = """You are a friendly and helpful chatbot having a conversation with a human.
Given the following extracted parts of a long document and a question, create a final answer.
{context}
{chat_history}
Human: {human_input}
Chatbot:"""
        context_prompt = PromptTemplate(
            input_variables=["chat_history", "human_input", "context"],
            template=context_template
        )
        self.memory = ConversationBufferMemory(
            memory_key="chat_history", input_key="human_input", return_messages=True
        )
        context_chain = load_qa_chain(
            self.llm, chain_type="stuff", memory=self.memory, prompt=context_prompt
        )
        return context_chain
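
    # Fallback chain for requests that don't need the manual (greetings, chit-chat).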
    def init_conversation(self):
        template = """You are a chatbot having a conversation with a human.
{chat_history}
Human: {human_input}
Chatbot:"""
        prompt = PromptTemplate(
            input_variables=["chat_history", "human_input"],
            template=template
        )
        conversation_chain = LLMChain(
            llm=self.llm,
            prompt=prompt,
            verbose=True,
            memory=self.memory,
        )
        return conversation_chain
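
    # Speech-to-text via the locally loaded Whisper model.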
    def transcribe_audio(self, audio_file):
        result = self.whisper_model.transcribe(audio_file)
        return result['text']
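
    # Route a query: if the router says the manual is needed, search the FAISS
    # index with extracted keywords and answer with that context; otherwise
    # fall back to the plain conversation chain.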
    def ask_question(self, query, k=4):
        # Strip whitespace so the comparison against 'Yes' is robust to
        # leading/trailing characters in the LLM's output.
        tool_usage = self.document_retrieval_chain.run(query).strip()
        print(f'\033[1;32msearch manual?: {tool_usage}\033[0m')
        chat_history = self.format_history(self.memory)
        if tool_usage == 'Yes':
            keywords = self.keyword_chain.run(question=query)
            print(f'\033[1;32mkeywords: {keywords}\033[0m')
            context = self.index.similarity_search(query=keywords, k=k)
            result = self.context_chain.run(
                input_documents=context, human_input=query
            )
        else:
            result = self.conversation_chain.run(query)
        return [(query, result)], chat_history
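
    # Generate a lip-synced talking-head video of the bot's reply via the
    # ex-human API and save it to a timestamped .mp4 file.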
    def invoke_exh_api(self, bot_response, bot_name='Zippy', voice_name='Fiona', idle_url='https://ugc-idle.s3-us-west-2.amazonaws.com/4a6a607a466bdf6605bbd97ef146751b.mp4', animation_pipeline='high_quality', bearer_token='eyJhbGciOiJIUzUxMiJ9.eyJ1c2VybmFtZSI6IndlYiJ9.LSzIQx6h61l5FXs52s0qcY8WqauET6z9nnxgSzvoNBx8RYEKm8OpOohcK8wjuwteV4ZGug4NOjoGQoUZIKH84A'):
        if len(bot_response) > 200:
            print('Input is over 200 characters. Shorten the message')
        url = 'https://api.exh.ai/animations/v1/generate_lipsync'
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36 Edg/110.0.1587.46',
            'authority': 'api.exh.ai',
            'accept': '*/*',
            'accept-encoding': 'gzip, deflate, br',
            'accept-language': 'en-US,en;q=0.9',
            'authorization': f'Bearer {bearer_token}',
            'content-type': 'application/json',
            'origin': 'https://admin.exh.ai',
            'referer': 'https://admin.exh.ai/',
            'sec-ch-ua': '"Chromium";v="110", "Not A(Brand";v="24", "Microsoft Edge";v="110"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"Windows"',
            'sec-fetch-dest': 'empty',
            'sec-fetch-mode': 'cors',
            'sec-fetch-site': 'same-site',
        }
        data = {
            'bot_name': bot_name,
            'bot_response': bot_response,
            'voice_name': voice_name,
            'idle_url': idle_url,
            'animation_pipeline': animation_pipeline,
        }
        r = requests.post(url, headers=headers, data=json.dumps(data))
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S%f')
        outfile = f'talking_head_{timestamp}.mp4'
        with open(outfile, 'wb') as f:
            f.write(r.content)
        return outfile
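
    # Gradio-facing entry point: accepts audio or text, returns the latest
    # (query, answer) pair and the accumulated conversation state.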
    def predict(self, input_data, state=None, k=4, input_type='audio'):
        if state is None:  # avoid a shared mutable default argument
            state = []
        if input_type == 'audio':
            txt = self.transcribe_audio(input_data[0])
        else:
            txt = input_data[1]
        result, chat_history = self.ask_question(txt, k=k)
        state.append(chat_history)
        return result, state
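
    # Wrapper for gr.Interface: prefers audio input when present, and also
    # renders the answer as an avatar video.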
    def predict_wrapper(self, input_text=None, input_audio=None):
        if input_audio is not None:
            result, state = self.predict(
                input_data=(input_audio,), input_type='audio'
            )
        else:
            result, state = self.predict(
                input_data=('', input_text), input_type='text'
            )
        # Truncate so the animation API's 200-character limit is respected.
        response = result[0][1][:195]
        avatar = self.invoke_exh_api(response)
        return result, avatar
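
# The FAISS index under ./data/eqe-manual must be built ahead of time. A minimal
# sketch of one way to do that, assuming the manual has been extracted to a
# plain-text file at the hypothetical path ./data/eqe-manual.txt (chunk sizes
# are illustrative, not taken from this app):
#
#     from langchain.text_splitter import RecursiveCharacterTextSplitter
#
#     with open('./data/eqe-manual.txt') as f:
#         manual_text = f.read()
#     splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
#     chunks = splitter.split_text(manual_text)
#     index = FAISS.from_texts(chunks, HuggingFaceEmbeddings())
#     index.save_local('./data/eqe-manual')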
man_chatbot = ChatbotClass()

iface = gr.Interface(
    fn=man_chatbot.predict_wrapper,
    inputs=[gr.inputs.Textbox(label="Text Input"),
            gr.inputs.Audio(source="microphone", type='filepath')],
    outputs=[gr.outputs.Textbox(label="Result"),
             gr.outputs.Video().style(height=100, container=True)],
)
iface.launch()

# Alternative chat UI using Gradio Blocks, kept here for reference:
'''
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    state = gr.State([])
    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(
            container=False)
        k_slider = gr.Slider(minimum=1, maximum=10, default=4, label='k')
    txt.submit(man_chatbot.predict, [txt, state, k_slider], [chatbot, state])
demo.launch()
'''