Kal1510 committed
Commit 94e4aac · verified · 1 Parent(s): 3f639cf

Update app.py

Files changed (1): app.py +204 -0
app.py CHANGED
@@ -1,3 +1,198 @@
+ import os
+ import torch
+ import gradio as gr
+ from PyPDF2 import PdfReader
+ from transformers import (
+     AutoTokenizer, pipeline,
+     AutoModelForCausalLM, AutoConfig,
+     BitsAndBytesConfig
+ )
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.vectorstores import FAISS
+ from langchain.prompts import PromptTemplate
+ from langchain.chains import LLMChain
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.schema import Document
+ from langchain import HuggingFacePipeline
+
+ # ------------------------------
+ # Device setup
+ # ------------------------------
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ # ------------------------------
+ # Embedding model config
+ # ------------------------------
+ modelPath = "sentence-transformers/all-mpnet-base-v2"
+ model_kwargs = {"device": str(device)}
+ encode_kwargs = {"normalize_embeddings": False}
+
+ embeddings = HuggingFaceEmbeddings(
+     model_name=modelPath,
+     model_kwargs=model_kwargs,
+     encode_kwargs=encode_kwargs
+ )
+
+ # ------------------------------
+ # Load Mistral model in 4-bit
+ # ------------------------------
+ model_name = "mistralai/Mistral-7B-Instruct-v0.1"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = "right"
+
+ # 4-bit quantization config
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_compute_dtype=torch.float16
+ )
+
+ # Load model
+ model = AutoModelForCausalLM.from_pretrained(
+     model_name,
+     quantization_config=bnb_config,
+     device_map="auto"
+ )
+
+ # ------------------------------
+ # Improved Text Generation Pipeline
+ # ------------------------------
+ text_generation = pipeline(
+     model=model,
+     tokenizer=tokenizer,
+     task="text-generation",
+     temperature=0.7,
+     top_p=0.9,
+     top_k=50,
+     repetition_penalty=1.1,
+     return_full_text=False,
+     max_new_tokens=2000,
+     do_sample=True,
+     eos_token_id=tokenizer.eos_token_id,
+ )
+
+ # Wrap in LangChain interface
+ mistral_llm = HuggingFacePipeline(pipeline=text_generation)
+
+ # ------------------------------
+ # PDF Processing Functions
+ # ------------------------------
+ def pdf_text(pdf_docs):
+     text = ""
+     for doc in pdf_docs:
+         reader = PdfReader(doc)
+         for page in reader.pages:
+             page_text = page.extract_text()
+             if page_text:
+                 text += page_text + "\n"
+     return text
+
+ def get_chunks(text):
+     splitter = RecursiveCharacterTextSplitter(
+         chunk_size=1000,
+         chunk_overlap=200,
+         length_function=len
+     )
+     chunks = splitter.split_text(text)
+     return [Document(page_content=chunk) for chunk in chunks]
+
+ def get_vectorstore(documents):
+     db = FAISS.from_documents(documents, embedding=embeddings)
+     db.save_local("faiss_index")
+
+ # ------------------------------
+ # Conversational Prompt Template
+ # ------------------------------
+ def get_qa_prompt():
+     prompt_template = """<s>[INST]
+ You are a helpful, knowledgeable AI assistant. Answer the user's question based on the provided context.
+
+ Guidelines:
+ - Respond in a natural, conversational tone
+ - Be detailed but concise
+ - Use paragraphs and bullet points when appropriate
+ - If you don't know, say so
+ - Maintain a friendly and professional demeanor
+
+ Conversation History:
+ {chat_history}
+
+ Relevant Context:
+ {context}
+
+ Current Question: {question}
+
+ Provide a helpful response: [/INST]"""
+
+     return PromptTemplate(
+         template=prompt_template,
+         input_variables=["context", "question", "chat_history"]
+     )
+
+ # ------------------------------
+ # Chat Handling Functions
+ # ------------------------------
+ def handle_pdf_upload(pdf_files):
+     try:
+         if not pdf_files:
+             return "⚠️ Please upload at least one PDF file"
+
+         text = pdf_text(pdf_files)
+         if not text.strip():
+             return "⚠️ Could not extract text from PDFs - please try different files"
+
+         chunks = get_chunks(text)
+         get_vectorstore(chunks)
+         return f"✅ Processed {len(pdf_files)} PDF(s) with {len(chunks)} text chunks"
+     except Exception as e:
+         return f"❌ Error: {str(e)}"
+
+ def format_chat_history(chat_history):
+     return "\n".join([f"User: {q}\nAssistant: {a}" for q, a in chat_history[-3:]])
+
+ def user_query(msg, chat_history):
+     if not os.path.exists("faiss_index"):
+         chat_history.append((msg, "Please upload PDF documents first so I can help you."))
+         return "", chat_history
+
+     try:
+         # Load vector store
+         db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
+         retriever = db.as_retriever(search_kwargs={"k": 3})
+
+         # Get relevant context
+         docs = retriever.get_relevant_documents(msg)
+         context = "\n\n".join([d.page_content for d in docs])
+
+         # Generate response
+         prompt = get_qa_prompt()
+         chain = LLMChain(llm=mistral_llm, prompt=prompt)
+
+         response = chain.run({
+             "question": msg,
+             "context": context,
+             "chat_history": format_chat_history(chat_history)
+         })
+
+         # Clean response
+         response = response.strip()
+         for end_token in ["</s>", "[INST]", "[/INST]"]:
+             if response.endswith(end_token):
+                 response = response[:-len(end_token)].strip()
+
+         chat_history.append((msg, response))
+         return "", chat_history
+
+     except Exception as e:
+         error_msg = f"Sorry, I encountered an error: {str(e)}"
+         chat_history.append((msg, error_msg))
+         return "", chat_history
+
+ # ------------------------------
+ # Gradio Interface
+ # ------------------------------
  with gr.Blocks(theme=gr.themes.Soft(), title="PDF Chat Assistant") as demo:
      with gr.Row():
          gr.Markdown("""
 
@@ -79,4 +274,13 @@ with gr.Blocks(theme=gr.themes.Soft(), title="PDF Chat Assistant") as demo:
  None,
  chatbot,
  queue=False
+ )
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch(
+         server_name="0.0.0.0",
+         server_port=7861,
+         share=True,
+         debug=True
  )
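
Since loading the 4-bit Mistral-7B weights makes end-to-end testing slow, the retrieval side of this change can be sanity-checked on its own. The sketch below is illustrative and not part of the commit: it assumes the embedding and vector-store packages imported by app.py are installed, that handle_pdf_upload has already written the "faiss_index" folder, and that the query string is just a placeholder.

# Standalone check of the retrieval path used by user_query (no LLM required).
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)

# Same k=3 retrieval the app performs before prompting Mistral.
retriever = db.as_retriever(search_kwargs={"k": 3})
for doc in retriever.get_relevant_documents("What is this document about?"):  # placeholder query
    print(doc.page_content[:200], "...\n")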