import os

import torch
import gradio as gr
from huggingface_hub import login
from PyPDF2 import PdfReader
from transformers import (
    AutoTokenizer, pipeline,
    AutoModelForCausalLM, AutoConfig,
    BitsAndBytesConfig
)

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.schema import Document
from langchain import HuggingFacePipeline

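# Log in to the Hugging Face Hub with a token read from the `api_key` environment
# variable; this is needed to download gated models such as Mistral-7B-Instruct.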
api_key = os.getenv("api_key")

try:
    login(token=api_key)
except Exception as e:
    print(f"Login failed: {e}")

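# Use the GPU when available; the embedding model is placed on this device explicitly.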
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

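# Sentence-transformers model used to embed document chunks for FAISS retrieval.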
modelPath = "sentence-transformers/all-mpnet-base-v2"
model_kwargs = {"device": str(device)}
encode_kwargs = {"normalize_embeddings": False}

embeddings = HuggingFaceEmbeddings(
    model_name=modelPath,
    model_kwargs=model_kwargs,
    encode_kwargs=encode_kwargs
)

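# Tokenizer for Mistral-7B-Instruct; pad with the EOS token since the model
# ships without a dedicated pad token.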
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

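# 4-bit NF4 quantization with double quantization keeps the 7B model small
# enough to run on a single GPU.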
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16
)

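# Load the causal LM with the quantization config; device_map="auto" places the
# weights across available devices automatically.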
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto"
)

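# Generation settings: sampling with temperature/top-p/top-k, a light repetition
# penalty, and up to 2000 new tokens; return_full_text=False keeps the prompt
# out of the generated output.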
text_generation = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.1,
    return_full_text=False,
    max_new_tokens=2000,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id,
)

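# Wrap the Transformers pipeline so LangChain chains can call it as an LLM.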
mistral_llm = HuggingFacePipeline(pipeline=text_generation)


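# Extract raw text from every page of the uploaded PDF files.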
def pdf_text(pdf_docs):
    text = ""
    for doc in pdf_docs:
        reader = PdfReader(doc)
        for page in reader.pages:
            page_text = page.extract_text()
            if page_text:
                text += page_text + "\n"
    return text


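# Split the text into 1000-character chunks with 200 characters of overlap and
# wrap each chunk as a LangChain Document.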
def get_chunks(text):
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len
    )
    chunks = splitter.split_text(text)
    return [Document(page_content=chunk) for chunk in chunks]


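# Embed the documents and persist the FAISS index to ./faiss_index for later queries.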
def get_vectorstore(documents):
    db = FAISS.from_documents(documents, embedding=embeddings)
    db.save_local("faiss_index")


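# Prompt in Mistral's [INST] instruction format, combining recent chat history,
# retrieved context, and the current question.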
def get_qa_prompt():
    prompt_template = """<s>[INST]
You are a helpful, knowledgeable AI assistant. Answer the user's question based on the provided context.

Guidelines:
- Respond in a natural, conversational tone
- Be detailed but concise
- Use paragraphs and bullet points when appropriate
- If you don't know, say so
- Maintain a friendly and professional demeanor

Conversation History:
{chat_history}

Relevant Context:
{context}

Current Question: {question}

Provide a helpful response: [/INST]"""

    return PromptTemplate(
        template=prompt_template,
        input_variables=["context", "question", "chat_history"]
    )


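# Gradio callback: extract text from the uploaded PDFs, chunk it, and build the vector store.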
def handle_pdf_upload(pdf_files):
    try:
        if not pdf_files:
            return "⚠️ Please upload at least one PDF file"

        text = pdf_text(pdf_files)
        if not text.strip():
            return "⚠️ Could not extract text from PDFs - please try different files"

        chunks = get_chunks(text)
        get_vectorstore(chunks)
        return f"✅ Processed {len(pdf_files)} PDF(s) with {len(chunks)} text chunks"
    except Exception as e:
        return f"❌ Error: {str(e)}"


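# Keep only the last three exchanges so the prompt stays compact.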
def format_chat_history(chat_history):
    return "\n".join([f"User: {q}\nAssistant: {a}" for q, a in chat_history[-3:]])


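# Gradio callback: retrieve the top-3 matching chunks, run the LLM chain on the
# question plus context, and append the answer to the chat history.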
def user_query(msg, chat_history):
    if not os.path.exists("faiss_index"):
        chat_history.append((msg, "Please upload PDF documents first so I can help you."))
        return "", chat_history

    try:
        db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
        retriever = db.as_retriever(search_kwargs={"k": 3})

        docs = retriever.get_relevant_documents(msg)
        context = "\n\n".join([d.page_content for d in docs])

        prompt = get_qa_prompt()
        chain = LLMChain(llm=mistral_llm, prompt=prompt)

        response = chain.run({
            "question": msg,
            "context": context,
            "chat_history": format_chat_history(chat_history)
        })

        response = response.strip()
        for end_token in ["</s>", "[INST]", "[/INST]"]:
            if response.endswith(end_token):
                response = response[:-len(end_token)].strip()

        chat_history.append((msg, response))
        return "", chat_history

    except Exception as e:
        error_msg = f"Sorry, I encountered an error: {str(e)}"
        chat_history.append((msg, error_msg))
        return "", chat_history


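# Gradio UI: document upload panel on the left, chat panel on the right.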
with gr.Blocks(theme=gr.themes.Soft(), title="PDF Chat Assistant") as demo:
    with gr.Row():
        gr.Markdown("""
        # 📚 PDF Chat Assistant
        ### Have natural conversations with your documents
        """)

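    # Left column: PDF upload, processing button, status box, and usage instructions.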
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### Document Upload")
            pdf_input = gr.File(
                file_types=[".pdf"],
                file_count="multiple",
                label="Upload PDFs",
                height=100
            )
            upload_btn = gr.Button("Process Documents", variant="primary")
            status_box = gr.Textbox(label="Status", interactive=False)
            gr.Markdown("""
            **Instructions:**
            1. Upload PDF documents
            2. Click Process Documents
            3. Start chatting in the right panel
            """)

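        # Right column: chat window, message input, and example questions.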
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(
                height=600,
                bubble_full_width=False,
                avatar_images=(
                    "user.png",
                    "bot.png"
                )
            )

            with gr.Row():
                message = gr.Textbox(
                    placeholder="Type your question about the documents...",
                    show_label=False,
                    container=False,
                    scale=7,
                    autofocus=True
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1)

            with gr.Row():
                clear_chat = gr.Button("🧹 Clear Conversation")
                examples = gr.Examples(
                    examples=[
                        "Summarize the key points from the documents",
                        "What are the main findings?",
                        "Explain this in simpler terms"
                    ],
                    inputs=message,
                    label="Example Questions"
                )

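    # Wire the UI events: document processing, the Send button, the Enter key, and clearing the chat.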
    upload_btn.click(
        fn=handle_pdf_upload,
        inputs=pdf_input,
        outputs=status_box
    )

    submit_btn.click(
        fn=user_query,
        inputs=[message, chatbot],
        outputs=[message, chatbot]
    )

    message.submit(
        fn=user_query,
        inputs=[message, chatbot],
        outputs=[message, chatbot]
    )

    clear_chat.click(
        lambda: [],
        None,
        chatbot,
        queue=False
    )

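# Start the app; share=True also creates a temporary public Gradio link.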
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7861,
        share=True,
        debug=True
    )