jawakja committed on
Commit 99f77c1 · verified · 1 Parent(s): 83b24f2

Update app.py

Files changed (1): app.py  +178 -91

app.py CHANGED
@@ -5,21 +5,50 @@ import cv2
  import os
  import tempfile
  import shutil
- from transformers import AutoTokenizer, AutoModelForCausalLM
  from sentence_transformers import SentenceTransformer
  import faiss

- # Load Qwen-VL-Chat
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)
- model = AutoModelForCausalLM.from_pretrained(
-     "Qwen/Qwen-VL-Chat",
-     device_map="auto",
-     torch_dtype=torch.bfloat16,
-     trust_remote_code=True
- ).eval()

- # Embedding model
- embed_model = SentenceTransformer('all-MiniLM-L6-v2')

  # Global state for FAISS
  chunks = []
@@ -27,98 +56,153 @@ index = None

  # PDF processing
  def extract_chunks_from_pdf(pdf_path, chunk_size=1000, overlap=200):
-     doc = fitz.open(pdf_path)
-     text = ""
-     for page in doc:
-         text += page.get_text()
-     return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size - overlap)]

  def build_faiss_index(chunks):
-     embeddings = embed_model.encode(chunks, convert_to_numpy=True)
-     dim = embeddings.shape[1]
-     idx = faiss.IndexFlatL2(dim)
-     idx.add(embeddings)
-     return idx

  def rag_query(query, chunks, index, top_k=3):
-     q_emb = embed_model.encode([query], convert_to_numpy=True)
-     D, I = index.search(q_emb, top_k)
-     return "\n\n".join([chunks[i] for i in I[0]])

- # Vision/Text chat
  def chat_with_qwen(text=None, image=None):
-     elements = []
-     if image:
-         elements.append({"image": image})
-     if text:
-         elements.append({"text": text})
-     if not elements:
-         return "Please upload or type something."
-     query = tokenizer.from_list_format(elements)
-     response, _ = model.chat(tokenizer, query, history=None)
-     return response
-
- # Video frame extraction
- def extract_video_frames(video_path, max_frames=3):
-     cap = cv2.VideoCapture(video_path)
-     frames, count = [], 0
-     while len(frames) < max_frames:
-         success, frame = cap.read()
-         if not success:
-             break
-         frames.append(frame)
-         count += 1
-         cap.set(cv2.CAP_PROP_POS_FRAMES, count * 30)
-     cap.release()
-     return frames
-
- # Main chatbot logic
  def multimodal_chat(message, history, image=None, video=None, pdf=None):
      global chunks, index

-     # PDF-based RAG
-     if pdf:
-         chunks = extract_chunks_from_pdf(pdf.name)
-         index = build_faiss_index(chunks)
-         context = rag_query(message, chunks, index)
-         final_prompt = f"Context:\n{context}\n\nQuestion: {message}"
-         response = chat_with_qwen(final_prompt)
-         return response

-     # Image
-     if image:
-         response = chat_with_qwen(message, image)
-         return response

-     # Video (extract frames and send all in one call)
-     if video:
-         temp_dir = tempfile.mkdtemp()
-         video_path = os.path.join(temp_dir, "vid.mp4")
-         shutil.copy(video, video_path)
-         frames = extract_video_frames(video_path)
-
-         # Save and collect image paths
-         images = []
-         for i, frame in enumerate(frames):
-             temp_img_path = os.path.join(temp_dir, f"frame_{i}.jpg")
-             frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-             cv2.imwrite(temp_img_path, frame_rgb)
-             images.append(temp_img_path)
-
-         # Combine all frames and text into one query
-         elements = [{"image": img} for img in images]
-         if message:
-             elements.append({"text": message})

-         query = tokenizer.from_list_format(elements)
-         response, _ = model.chat(tokenizer, query, history=None)
-         return response

-     # Text only
-     if message:
-         return chat_with_qwen(message)

-     return "Please input a message, image, video, or PDF."
  # ---- Gradio UI ---- #
  with gr.Blocks(css="""
@@ -148,7 +232,7 @@ padding: 16px;
  footer {display: none !important;}
  """) as demo:
      gr.Markdown(
-         "<h1 style='text-align: center;'>Multimodal Chatbot powered by LLAVACMVRL and QWEN-VL</h1>"
          "<p style='text-align: center;'>Ask questions with text, images, videos, or PDFs in a smart and multimodal way.</p>"
      )

@@ -165,6 +249,8 @@ footer {display: none !important;}
      pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF")

      def user_send(message, history, image, video, pdf):
          response = multimodal_chat(message, history, image, video, pdf)
          history.append((message, response))
          return "", history
@@ -172,5 +258,6 @@ footer {display: none !important;}
      send_btn.click(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot])
      txt.submit(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot])

- # Launch the app
  demo.launch()
  import os
  import tempfile
  import shutil
+ import logging
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
  from sentence_transformers import SentenceTransformer
  import faiss

+ # Setup logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)

+ # Check available resources
+ logger.info(f"CUDA available: {torch.cuda.is_available()}")
+ if torch.cuda.is_available():
+     logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
+     logger.info(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9} GB")
+
+ # Configure quantization for lower memory usage
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.float16,
+ )
+
+ try:
+     # Load Qwen-2.5-Omni-3B with memory optimizations
+     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Omni-3B", trust_remote_code=True)
+     model = AutoModelForCausalLM.from_pretrained(
+         "Qwen/Qwen2.5-Omni-3B",
+         device_map="auto",
+         quantization_config=bnb_config,
+         trust_remote_code=True
+     ).eval()
+     logger.info("Model loaded successfully")
+ except Exception as e:
+     logger.error(f"Error loading model: {e}")
+     model = None
+     tokenizer = None
+
+ # Use a smaller embedding model
+ try:
+     embed_model = SentenceTransformer('paraphrase-MiniLM-L3-v2')
+     logger.info("Embedding model loaded successfully")
+ except Exception as e:
+     logger.error(f"Error loading embedding model: {e}")
+     embed_model = None

  # Global state for FAISS
  chunks = []
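The BitsAndBytesConfig above requests 4-bit NF4 weights with float16 compute, which should cut weight memory to roughly a quarter of an fp16 load. A minimal sanity check that could follow the try/except, assuming the load succeeded (get_memory_footprint and hf_device_map are standard transformers/accelerate attributes; the check itself is not part of this commit):

if model is not None:
    # expect roughly a quarter of the fp16 footprint for a 4-bit NF4 load
    logger.info(f"Model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
    logger.info(f"Device map: {getattr(model, 'hf_device_map', 'not set')}")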
 
  # PDF processing
  def extract_chunks_from_pdf(pdf_path, chunk_size=1000, overlap=200):
+     try:
+         doc = fitz.open(pdf_path)
+         text = ""
+         for page in doc:
+             text += page.get_text()
+         return [text[i:i + chunk_size] for i in range(0, len(text), chunk_size - overlap)]
+     except Exception as e:
+         logger.error(f"PDF extraction error: {e}")
+         return ["Error extracting PDF content"]

  def build_faiss_index(chunks):
+     try:
+         if not embed_model:
+             return None
+         embeddings = embed_model.encode(chunks, convert_to_numpy=True)
+         dim = embeddings.shape[1]
+         idx = faiss.IndexFlatL2(dim)
+         idx.add(embeddings)
+         return idx
+     except Exception as e:
+         logger.error(f"FAISS index error: {e}")
+         return None

  def rag_query(query, chunks, index, top_k=3):
+     if not index or not embed_model:
+         return "Embedding model not available"
+     try:
+         q_emb = embed_model.encode([query], convert_to_numpy=True)
+         D, I = index.search(q_emb, top_k)
+         return "\n\n".join([chunks[i] for i in I[0]])
+     except Exception as e:
+         logger.error(f"RAG query error: {e}")
+         return "Error retrieving context"
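Taken together, these three helpers form the PDF retrieval path: with the defaults, a new chunk starts every chunk_size - overlap = 800 characters, so adjacent chunks share 200 characters of context. A minimal sketch of exercising them outside Gradio (the file name sample.pdf and the question are placeholders, not part of the commit):

chunks = extract_chunks_from_pdf("sample.pdf")   # placeholder path
index = build_faiss_index(chunks)                # None if the embedding model failed to load
if index is not None:
    context = rag_query("What is the main topic of the document?", chunks, index, top_k=3)
    print(context)                               # the three closest chunks, separated by blank lines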
 
+ # Vision/Text chat with Qwen-2.5-Omni
  def chat_with_qwen(text=None, image=None):
+     if not model or not tokenizer:
+         return "Model failed to load due to resource constraints. Try a smaller model or upgrade your space."
+
+     try:
+         # For Qwen-2.5-Omni-3B
+         messages = []
+
+         if image:
+             # Add the image as a message
+             messages.append({"role": "user", "content": [
+                 {"image": image},
+                 {"text": text if text else "Please describe this image."}
+             ]})
+         else:
+             # Text-only query
+             messages.append({"role": "user", "content": text})
+
+         # Generate response
+         response = model.chat(tokenizer, messages)
+         return response
+     except Exception as e:
+         logger.error(f"Chat error: {e}")
+         return f"Error generating response: {str(e)}"
+
+ # Video frame extraction - more memory efficient
+ def extract_video_frames(video_path, max_frames=2):
+     try:
+         cap = cv2.VideoCapture(video_path)
+         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+         frames = []
+
+         # Take fewer, evenly distributed frames
+         if total_frames > 0:
+             frame_indices = [int(i * total_frames / max_frames) for i in range(max_frames)]
+             for idx in frame_indices:
+                 cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
+                 success, frame = cap.read()
+                 if success:
+                     frames.append(frame)
+         cap.release()
+         return frames
+     except Exception as e:
+         logger.error(f"Video frame extraction error: {e}")
+         return []
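With this sampling scheme, frame index i is int(i * total_frames / max_frames), so the default max_frames=2 picks the first frame and the midpoint of the clip rather than two adjacent frames at the start. A quick check of the arithmetic with an assumed clip length:

total_frames = 300   # assumed clip length, for illustration only
max_frames = 2
frame_indices = [int(i * total_frames / max_frames) for i in range(max_frames)]
print(frame_indices)  # [0, 150]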
+
+ # Main chatbot logic with error handling
  def multimodal_chat(message, history, image=None, video=None, pdf=None):
      global chunks, index
+
+     if not model:
+         return "Model not loaded due to memory constraints. Try upgrading your Hugging Face space."

+     try:
+         # PDF-based RAG
+         if pdf:
+             chunks = extract_chunks_from_pdf(pdf.name)
+             index = build_faiss_index(chunks)
+             if index:
+                 context = rag_query(message, chunks, index)
+                 final_prompt = f"I'll provide some context, then ask a question. Context:\n{context}\n\nQuestion: {message}"
+                 response = chat_with_qwen(final_prompt)
+             else:
+                 response = "Could not process PDF due to resource constraints"
+             return response

+         # Image
+         if image:
+             response = chat_with_qwen(message, image)
+             return response

+         # Video (extract frames and process one by one)
+         if video:
+             temp_dir = tempfile.mkdtemp()
+             try:
+                 video_path = os.path.join(temp_dir, "vid.mp4")
+                 shutil.copy(video, video_path)
+                 frames = extract_video_frames(video_path)

+                 # Only process if we got frames
+                 if frames:
+                     # Save frames and process them
+                     frame_descriptions = []
+                     for i, frame in enumerate(frames):
+                         temp_img_path = os.path.join(temp_dir, f"frame_{i}.jpg")
+                         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+                         cv2.imwrite(temp_img_path, frame_rgb)
+
+                         # Get description for this frame
+                         frame_query = "Describe this video frame in detail."
+                         frame_description = chat_with_qwen(frame_query, temp_img_path)
+                         frame_descriptions.append(f"Frame {i+1}: {frame_description}")
+
+                     # Combine frame descriptions and answer the user's question
+                     combined_context = "\n\n".join(frame_descriptions)
+                     final_prompt = f"I analyzed some video frames and here's what I found:\n\n{combined_context}\n\nBased on these video frames, {message if message else 'please describe what is happening in this video.'}"
+                     response = chat_with_qwen(final_prompt)
+                     return response
+                 else:
+                     return "Could not extract video frames"
+             finally:
+                 # Cleanup temp files
+                 shutil.rmtree(temp_dir, ignore_errors=True)

+         # Text only
+         if message:
+             return chat_with_qwen(message)

+         return "Please input a message, image, video, or PDF."
+     except Exception as e:
+         logger.error(f"General error in multimodal_chat: {e}")
+         return f"Error processing your request: {str(e)}. This may be due to memory constraints."
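multimodal_chat dispatches on the first attachment it finds, in the order PDF, then image, then video, and only falls back to plain text when nothing is uploaded. A minimal text-only call, assuming the model loaded (the question string is illustrative):

reply = multimodal_chat("What can this chatbot do?", history=[])  # no image/video/pdf attached
print(reply)  # handled by the final chat_with_qwen(message) branch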
 
  # ---- Gradio UI ---- #
  with gr.Blocks(css="""
  footer {display: none !important;}
  """) as demo:
      gr.Markdown(
+         "<h1 style='text-align: center;'>Multimodal Chatbot powered by Qwen-2.5-Omni-3B</h1>"
          "<p style='text-align: center;'>Ask questions with text, images, videos, or PDFs in a smart and multimodal way.</p>"
      )

      pdf_input = gr.File(file_types=[".pdf"], label="Upload PDF")

      def user_send(message, history, image, video, pdf):
+         if not message and not image and not video and not pdf:
+             return "", history
          response = multimodal_chat(message, history, image, video, pdf)
          history.append((message, response))
          return "", history

      send_btn.click(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot])
      txt.submit(user_send, [txt, state, image_input, video_input, pdf_input], [txt, chatbot])

+ # Launch the app with memory logging
+ logger.info("Starting Gradio app")
  demo.launch()
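Because a single request can trigger several model generations (for example, one per extracted video frame), queuing the Gradio app before launch is a common pattern on Spaces. A hedged variant of the last lines, not part of this commit (queue() is standard Gradio; max_size is an illustrative value):

logger.info("Starting Gradio app")
demo.queue(max_size=8)  # serialize long-running model calls
demo.launch()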