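# app.py -- CPSL Chatbot: retrieval-augmented generation over dataset.txt using
# LangChain, FAISS, and 01-ai/Yi-Coder-9B-Chat, served through a Gradio interface.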
import gradio as gr
import spaces
import torch
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
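# Note: these import paths match the pre-0.2 LangChain layout; on newer releases the
# integrations live in langchain_community (e.g. langchain_community.vectorstores.FAISS).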
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Load and process documents
doc_loader = TextLoader("dataset.txt")
docs = doc_loader.load()
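# Chunk into ~1000-character pieces; the 100-character overlap keeps passages that
# straddle a chunk boundary retrievable.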
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
split_docs = text_splitter.split_documents(docs)
# Create vector database
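# The short name "all-MiniLM-L6-v2" resolves to sentence-transformers/all-MiniLM-L6-v2;
# FAISS keeps the chunk embeddings in an in-memory similarity index.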
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectordb = FAISS.from_documents(split_docs, embeddings)
# Load model and create pipeline
model_name = "01-ai/Yi-Coder-9B-Chat"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# fp16 halves GPU memory use; fall back to fp32 when running on CPU.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
qa_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=250,
    pad_token_id=tokenizer.eos_token_id,
)
# Set up LangChain
llm = HuggingFacePipeline(pipeline=qa_pipeline)
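# Retrieve the 5 most similar chunks for each query.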
retriever = vectordb.as_retriever(search_kwargs={"k": 5})
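# chain_type="stuff" concatenates all retrieved chunks into a single prompt for the LLM.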
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=False,
)

def preprocess_query(query):
    # Plain string handling -- no GPU work, so @spaces.GPU is unnecessary here.
    # Lowercase once so both keyword checks are case-insensitive (the original
    # "script" check was case-sensitive while "code" was not).
    query_lower = query.lower()
    if "script" in query_lower or "code" in query_lower:
        return f"Write a CPSL script: {query}"
    return query

def clean_response(response):
    # String post-processing only, so no @spaces.GPU needed; keep the text after
    # the first "Answer:" marker when the model echoes the prompt.
    result = response.get("result", "")
    if "Answer:" in result:
        return result.split("Answer:")[1].strip()
    return result.strip()

@spaces.GPU
def chatbot_response(user_input):
    processed_query = preprocess_query(user_input)
    raw_response = qa_chain.invoke({"query": processed_query})
    return clean_response(raw_response)

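# Build the Gradio UI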
with gr.Blocks() as demo:  # Blocks itself stays undecorated; @spaces.GPU belongs on the callback
    gr.Markdown("# CPSL Chatbot")
    chat_history = gr.Chatbot()
    user_input = gr.Textbox(label="Your Message:")
    send_button = gr.Button("Send")

    @spaces.GPU
    def interact(user_message, history):
        bot_reply = chatbot_response(user_message)
        history.append((user_message, bot_reply))
        return history

    send_button.click(interact, inputs=[user_input, chat_history], outputs=[chat_history])
demo.launch()