import json
from datetime import datetime

import gradio as gr
import requests
import whisper

import langchain
from langchain import LLMChain, PromptTemplate
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import OpenAIChat
from langchain.vectorstores import FAISS
class ChatbotClass:
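    """Voice- and text-driven QA chatbot over a vehicle owner's manual.

    Pipeline: Whisper transcribes audio input, an LLM router decides whether
    the question needs the manual, a keyword chain generates search terms,
    FAISS retrieves matching passages, and a "stuff" QA chain answers with
    that context. Responses can also be rendered as a talking-head video via
    the api.exh.ai animation endpoint.
    """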
    def __init__(self):
        FOLDER_PATH = './data/eqe-manual'
        self.whisper_model = whisper.load_model(name='tiny')
        self.embeddings = HuggingFaceEmbeddings()
        self.index = FAISS.load_local(
            folder_path=FOLDER_PATH, embeddings=self.embeddings
        )
        self.llm = OpenAIChat(temperature=0)
        # One shared memory: the context chain and the conversation chain
        # both read and write the same chat history.
        self.memory = ConversationBufferMemory(
            memory_key="chat_history", input_key="human_input", return_messages=True
        )
        self.keyword_chain = self.init_keyword_chain()
        self.context_chain = self.init_context_chain()
        self.document_retrieval_chain = self.init_document_retrieval()
        self.conversation_chain = self.init_conversation()
    def format_history(self, memory):
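        """Convert LangChain chat memory into a list of (human, ai) tuples,
        the pair format the Gradio Chatbot component expects."""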
        history = memory.chat_memory.messages
        formatted_history = []
        user_response = None
        for h in history:
            if isinstance(h, langchain.schema.HumanMessage):
                user_response = h.content
            elif isinstance(h, langchain.schema.AIMessage):
                # Pair each AI reply with the human message that preceded it.
                formatted_history.append((user_response, h.content))
        return formatted_history
    def init_document_retrieval(self):
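        """Build a router chain that answers "Yes" or "No": does this request
        need vehicle-specific context from the owner's manual?"""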
retrieve_documents_template = """This function retrieves exerts from a Vehicle Owner's Manual. The function is useful for adding vehicle-specific context to answer questions. Based on a request, determine if vehicle specific information is needed. Respond with "Yes" or "No". If the answer is both, respond with "Yes":\nrequest: How do I change the tire?\nresponse: Yes\nrequest: Hello\nresponse: No\nrequest: I was in an accident. What should I do?\nresponse: Yes\nrequest: {request}\nresponse:""" | |
prompt_template = PromptTemplate( | |
input_variables=["request"], | |
template=retrieve_documents_template | |
) | |
document_retrieval_chain = LLMChain( | |
llm=self.llm, prompt=prompt_template, verbose=True | |
) | |
return document_retrieval_chain | |
    def init_keyword_chain(self):
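        """Build a chain that turns a user question into search keywords
        for the manual's vector index."""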
keyword_template = """You are a vehicle owner searching for content in your vehicle's owner manual. Your job is to come up with keywords to use when searching inside your manual, based on a question you have. | |
Question: {question} | |
Keywords:""" | |
prompt_template = PromptTemplate( | |
input_variables=["question"], template=keyword_template | |
) | |
keyword_chain = LLMChain( | |
llm=self.llm, prompt=prompt_template, verbose=True) | |
return keyword_chain | |
    def init_context_chain(self):
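        """Build a "stuff"-type QA chain that answers from retrieved manual
        excerpts plus the running chat history."""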
context_template = """You are a friendly and helpful chatbot having a conversation with a human. | |
Given the following extracted parts of a long document and a question, create a final answer. | |
{context} | |
{chat_history} | |
Human: {human_input} | |
Chatbot:""" | |
context_prompt = PromptTemplate( | |
input_variables=["chat_history", "human_input", "context"], | |
template=context_template | |
) | |
self.memory = ConversationBufferMemory( | |
memory_key="chat_history", input_key="human_input", return_messages=True | |
) | |
context_chain = load_qa_chain( | |
self.llm, chain_type="stuff", memory=self.memory, prompt=context_prompt | |
) | |
return context_chain | |
    def init_conversation(self):
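        """Build a plain conversational chain (no retrieval) that shares the
        same memory as the context chain."""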
template = """You are a chatbot having a conversation with a human. | |
{chat_history} | |
Human: {human_input} | |
Chatbot:""" | |
prompt = PromptTemplate( | |
input_variables=["chat_history", "human_input"], | |
template=template | |
) | |
conversation_chain = LLMChain( | |
llm=self.llm, | |
prompt=prompt, | |
verbose=True, | |
memory=self.memory, | |
) | |
return conversation_chain | |
    def transcribe_audio(self, audio_file):
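        """Transcribe an audio file to text with the local Whisper model."""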
        result = self.whisper_model.transcribe(audio_file)
        return result['text']
    def ask_question(self, query, k=4):
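        """Route a query: if the router chain says the manual is needed,
        generate keywords, retrieve the top-k passages from FAISS, and answer
        with that context; otherwise fall back to plain conversation."""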
        # Strip whitespace so a response like "Yes\n" still matches.
        tool_usage = self.document_retrieval_chain.run(query).strip()
        print('\033[1;32m' f'search manual?: {tool_usage}' '\033[0m')
        chat_history = self.format_history(self.memory)
        if tool_usage == 'Yes':
            keywords = self.keyword_chain.run(question=query)
            print('\033[1;32m' f'keywords: {keywords}' '\033[0m')
            context = self.index.similarity_search(query=keywords, k=k)
            result = self.context_chain.run(
                input_documents=context, human_input=query
            )
        else:
            result = self.conversation_chain.run(query)
        return [(query, result)], chat_history
    def invoke_exh_api(
        self,
        bot_response,
        bot_name='Zippy',
        voice_name='Fiona',
        idle_url='https://ugc-idle.s3-us-west-2.amazonaws.com/4a6a607a466bdf6605bbd97ef146751b.mp4',
        animation_pipeline='high_quality',
        bearer_token='eyJhbGciOiJIUzUxMiJ9.eyJ1c2VybmFtZSI6IndlYiJ9.LSzIQx6h61l5FXs52s0qcY8WqauET6z9nnxgSzvoNBx8RYEKm8OpOohcK8wjuwteV4ZGug4NOjoGQoUZIKH84A',
    ):
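        """Render bot_response as a lip-synced talking-head video via the
        api.exh.ai animation endpoint and save it to a timestamped .mp4.
        Callers are expected to keep responses under 200 characters (see the
        check below and the truncation in predict_wrapper)."""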
        if len(bot_response) > 200:
            print('Input is over 200 characters. Shorten the message')
        url = 'https://api.exh.ai/animations/v1/generate_lipsync'
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.0.0 Safari/537.36 Edg/110.0.1587.46',
            'authority': 'api.exh.ai',
            'accept': '*/*',
            'accept-encoding': 'gzip, deflate, br',
            'accept-language': 'en-US,en;q=0.9',
            'authorization': f'Bearer {bearer_token}',
            'content-type': 'application/json',
            'origin': 'https://admin.exh.ai',
            'referer': 'https://admin.exh.ai/',
            'sec-ch-ua': '"Chromium";v="110", "Not A(Brand";v="24", "Microsoft Edge";v="110"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"Windows"',
            'sec-fetch-dest': 'empty',
            'sec-fetch-mode': 'cors',
            'sec-fetch-site': 'same-site',
        }
        data = {
            'bot_name': bot_name,
            'bot_response': bot_response,
            'voice_name': voice_name,
            'idle_url': idle_url,
            'animation_pipeline': animation_pipeline,
        }
        r = requests.post(url, headers=headers, data=json.dumps(data))
        # Write the returned video to a uniquely named file so successive
        # requests do not clobber each other.
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S%f')
        outfile = f'talking_head_{timestamp}.mp4'
        with open(outfile, 'wb') as f:
            f.write(r.content)
        return outfile
    def predict(self, input_data, state=None, k=4, input_type='audio'):
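        """Answer one input (audio or text) and append the chat history to
        the running Gradio state."""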
        # Avoid a mutable default argument: create a fresh state per call.
        if state is None:
            state = []
        if input_type == 'audio':
            txt = self.transcribe_audio(input_data[0])
        else:
            txt = input_data[1]
        result, chat_history = self.ask_question(txt, k=k)
        state.append(chat_history)
        return result, state
    def predict_wrapper(self, input_text=None, input_audio=None):
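        """Gradio entry point: prefer audio input when present, answer the
        question, then generate an avatar video from the truncated response."""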
        if input_audio is not None:
            result, state = self.predict(
                input_data=(input_audio,), input_type='audio'
            )
        else:
            result, state = self.predict(
                input_data=('', input_text), input_type='text'
            )
        # Keep the spoken response under the animation API's 200-character cap.
        response = result[0][1][:195]
        avatar = self.invoke_exh_api(response)
        return result, avatar
man_chatbot = ChatbotClass()
iface = gr.Interface(
    fn=man_chatbot.predict_wrapper,
    inputs=[gr.inputs.Textbox(label="Text Input"),
            gr.inputs.Audio(source="microphone", type='filepath')],
    outputs=[gr.outputs.Textbox(label="Result"),
             gr.outputs.Video().style(width=360, height=360, container=True)]
)
iface.launch()
'''
Alternative chat UI using Gradio Blocks:
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    state = gr.State([])
    with gr.Row():
        txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter").style(
            container=False)
    k_slider = gr.Slider(minimum=1, maximum=10, default=4, label='k')
    txt.submit(man_chatbot.predict, [txt, state, k_slider], [chatbot, state])
demo.launch()
'''