Spaces:
Configuration error
Configuration error
Orel MAZOR
commited on
Commit
·
e4bc671
1
Parent(s):
0d0725b
Commit 1
Browse files- .DS_Store +0 -0
- .gradio/certificate.pem +31 -0
- .vscode/launch.json +15 -0
- __pycache__/agent.cpython-311.pyc +0 -0
- __pycache__/agent2.cpython-311.pyc +0 -0
- agent.py +1 -132
- agent2.py +154 -288
- app.py +2 -2
- appasync.py +210 -0
- custom_models.py +404 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-----BEGIN CERTIFICATE-----
|
2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
31 |
+
-----END CERTIFICATE-----
|
.vscode/launch.json
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
// Use IntelliSense to learn about possible attributes.
|
3 |
+
// Hover to view descriptions of existing attributes.
|
4 |
+
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
5 |
+
"version": "0.2.0",
|
6 |
+
"configurations": [
|
7 |
+
{
|
8 |
+
"name": "Python Debugger: Current File",
|
9 |
+
"type": "debugpy",
|
10 |
+
"request": "launch",
|
11 |
+
"program": "${file}",
|
12 |
+
"console": "integratedTerminal"
|
13 |
+
}
|
14 |
+
]
|
15 |
+
}
|
__pycache__/agent.cpython-311.pyc
ADDED
Binary file (49.5 kB). View file
|
|
__pycache__/agent2.cpython-311.pyc
ADDED
Binary file (19.7 kB). View file
|
|
agent.py
CHANGED
@@ -10,6 +10,7 @@ import asyncio
|
|
10 |
# Third-party imports
|
11 |
import requests
|
12 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
|
13 |
|
14 |
# LlamaIndex core imports
|
15 |
from llama_index.core import VectorStoreIndex, Document, Settings
|
@@ -120,140 +121,8 @@ def initialize_models(use_api_mode=False):
|
|
120 |
print("Initializing models in non-API mode with local models...")
|
121 |
|
122 |
try :
|
123 |
-
from typing import Optional, List, Any
|
124 |
-
from pydantic import Field, PrivateAttr
|
125 |
-
from llama_index.core.llms import CustomLLM, CompletionResponse, CompletionResponseGen, LLMMetadata
|
126 |
-
from llama_index.core.llms.callbacks import llm_completion_callback
|
127 |
-
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
128 |
-
from qwen_vl_utils import process_vision_info
|
129 |
-
import torch
|
130 |
-
|
131 |
-
class QwenVL7BCustomLLM(CustomLLM):
|
132 |
-
model_name: str = Field(default="Qwen/Qwen2.5-VL-7B-Instruct")
|
133 |
-
context_window: int = Field(default=32768)
|
134 |
-
num_output: int = Field(default=256)
|
135 |
-
_model = PrivateAttr()
|
136 |
-
_processor = PrivateAttr()
|
137 |
-
|
138 |
-
def __init__(self, **kwargs):
|
139 |
-
super().__init__(**kwargs)
|
140 |
-
self._model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
141 |
-
self.model_name, torch_dtype=torch.bfloat16, device_map='balanced'
|
142 |
-
)
|
143 |
-
self._processor = AutoProcessor.from_pretrained(self.model_name)
|
144 |
-
|
145 |
-
@property
|
146 |
-
def metadata(self) -> LLMMetadata:
|
147 |
-
return LLMMetadata(
|
148 |
-
context_window=self.context_window,
|
149 |
-
num_output=self.num_output,
|
150 |
-
model_name=self.model_name,
|
151 |
-
)
|
152 |
-
|
153 |
-
@llm_completion_callback()
|
154 |
-
def complete(
|
155 |
-
self,
|
156 |
-
prompt: str,
|
157 |
-
image_paths: Optional[List[str]] = None,
|
158 |
-
**kwargs: Any
|
159 |
-
) -> CompletionResponse:
|
160 |
-
# Prepare multimodal input
|
161 |
-
messages = [{"role": "user", "content": []}]
|
162 |
-
if image_paths:
|
163 |
-
for path in image_paths:
|
164 |
-
messages[0]["content"].append({"type": "image", "image": path})
|
165 |
-
messages[0]["content"].append({"type": "text", "text": prompt})
|
166 |
-
|
167 |
-
# Tokenize and process
|
168 |
-
text = self._processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
169 |
-
image_inputs, video_inputs = process_vision_info(messages)
|
170 |
-
inputs = self._processor(
|
171 |
-
text=[text],
|
172 |
-
images=image_inputs,
|
173 |
-
videos=video_inputs,
|
174 |
-
padding=True,
|
175 |
-
return_tensors="pt",
|
176 |
-
)
|
177 |
-
inputs = inputs.to(self._model.device)
|
178 |
-
|
179 |
-
# Generate output
|
180 |
-
generated_ids = self._model.generate(**inputs, max_new_tokens=self.num_output)
|
181 |
-
generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
|
182 |
-
output_text = self._processor.batch_decode(
|
183 |
-
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
184 |
-
)[0]
|
185 |
-
return CompletionResponse(text=output_text)
|
186 |
-
|
187 |
-
@llm_completion_callback()
|
188 |
-
def stream_complete(
|
189 |
-
self,
|
190 |
-
prompt: str,
|
191 |
-
image_paths: Optional[List[str]] = None,
|
192 |
-
**kwargs: Any
|
193 |
-
) -> CompletionResponseGen:
|
194 |
-
response = self.complete(prompt, image_paths)
|
195 |
-
for token in response.text:
|
196 |
-
yield CompletionResponse(text=token, delta=token)
|
197 |
-
|
198 |
-
|
199 |
proj_llm = QwenVL7BCustomLLM()
|
200 |
|
201 |
-
from typing import Any, List, Optional
|
202 |
-
from llama_index.core.embeddings import BaseEmbedding
|
203 |
-
from sentence_transformers import SentenceTransformer
|
204 |
-
from PIL import Image
|
205 |
-
|
206 |
-
class MultimodalCLIPEmbedding(BaseEmbedding):
|
207 |
-
"""
|
208 |
-
Custom embedding class using CLIP for multimodal capabilities.
|
209 |
-
"""
|
210 |
-
|
211 |
-
def __init__(self, model_name: str = "clip-ViT-B-32", **kwargs: Any) -> None:
|
212 |
-
super().__init__(**kwargs)
|
213 |
-
self._model = SentenceTransformer(model_name)
|
214 |
-
|
215 |
-
@classmethod
|
216 |
-
def class_name(cls) -> str:
|
217 |
-
return "multimodal_clip"
|
218 |
-
|
219 |
-
def _get_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]:
|
220 |
-
if image_path:
|
221 |
-
image = Image.open(image_path)
|
222 |
-
embedding = self._model.encode(image)
|
223 |
-
return embedding.tolist()
|
224 |
-
else:
|
225 |
-
embedding = self._model.encode(query)
|
226 |
-
return embedding.tolist()
|
227 |
-
|
228 |
-
def _get_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]:
|
229 |
-
if image_path:
|
230 |
-
image = Image.open(image_path)
|
231 |
-
embedding = self._model.encode(image)
|
232 |
-
return embedding.tolist()
|
233 |
-
else:
|
234 |
-
embedding = self._model.encode(text)
|
235 |
-
return embedding.tolist()
|
236 |
-
|
237 |
-
def _get_text_embeddings(self, texts: List[str], image_paths: Optional[List[str]] = None) -> List[List[float]]:
|
238 |
-
embeddings = []
|
239 |
-
image_paths = image_paths or [None] * len(texts)
|
240 |
-
|
241 |
-
for text, img_path in zip(texts, image_paths):
|
242 |
-
if img_path:
|
243 |
-
image = Image.open(img_path)
|
244 |
-
emb = self._model.encode(image)
|
245 |
-
else:
|
246 |
-
emb = self._model.encode(text)
|
247 |
-
embeddings.append(emb.tolist())
|
248 |
-
|
249 |
-
return embeddings
|
250 |
-
|
251 |
-
async def _aget_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]:
|
252 |
-
return self._get_query_embedding(query, image_path)
|
253 |
-
|
254 |
-
async def _aget_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]:
|
255 |
-
return self._get_text_embedding(text, image_path)
|
256 |
-
|
257 |
|
258 |
embed_model = MultimodalCLIPEmbedding()
|
259 |
embed_model.max_seq_length = 1024
|
|
|
10 |
# Third-party imports
|
11 |
import requests
|
12 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
13 |
+
from custom_models import QwenVL7BCustomLLM, BaaiMultimodalEmbedding
|
14 |
|
15 |
# LlamaIndex core imports
|
16 |
from llama_index.core import VectorStoreIndex, Document, Settings
|
|
|
121 |
print("Initializing models in non-API mode with local models...")
|
122 |
|
123 |
try :
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
proj_llm = QwenVL7BCustomLLM()
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
embed_model = MultimodalCLIPEmbedding()
|
128 |
embed_model.max_seq_length = 1024
|
agent2.py
CHANGED
@@ -5,16 +5,8 @@ from typing import Dict, Any, List
|
|
5 |
from langchain.docstore.document import Document
|
6 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
7 |
from langchain_community.retrievers import BM25Retriever
|
8 |
-
from smolagents import CodeAgent, OpenAIServerModel,
|
9 |
-
from smolagents
|
10 |
-
from smolagents.agents import ActionStep
|
11 |
-
from selenium import webdriver
|
12 |
-
from selenium.webdriver.common.by import By
|
13 |
-
from selenium.webdriver.common.keys import Keys
|
14 |
-
import helium
|
15 |
-
from PIL import Image
|
16 |
-
from io import BytesIO
|
17 |
-
from time import sleep
|
18 |
|
19 |
# Langfuse observability imports
|
20 |
from opentelemetry.sdk.trace import TracerProvider
|
@@ -22,10 +14,101 @@ from openinference.instrumentation.smolagents import SmolagentsInstrumentor
|
|
22 |
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
23 |
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
24 |
from opentelemetry import trace
|
25 |
-
from opentelemetry.trace import format_trace_id
|
26 |
from langfuse import Langfuse
|
|
|
27 |
|
28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
class BM25RetrieverTool(Tool):
|
30 |
"""
|
31 |
BM25 retriever tool for document search when text documents are available
|
@@ -59,126 +142,6 @@ class BM25RetrieverTool(Tool):
|
|
59 |
for i, doc in enumerate(docs)
|
60 |
])
|
61 |
|
62 |
-
|
63 |
-
@tool
|
64 |
-
def search_item_ctrl_f(text: str, nth_result: int = 1) -> str:
|
65 |
-
"""Search for text on the current page via Ctrl + F and jump to the nth occurrence.
|
66 |
-
|
67 |
-
Args:
|
68 |
-
text: The text string to search for on the webpage
|
69 |
-
nth_result: Which occurrence to jump to (default is 1 for first occurrence)
|
70 |
-
|
71 |
-
Returns:
|
72 |
-
str: Result of the search operation with match count and navigation status
|
73 |
-
"""
|
74 |
-
try:
|
75 |
-
driver = helium.get_driver()
|
76 |
-
elements = driver.find_elements(By.XPATH, f"//*[contains(text(), '{text}')]")
|
77 |
-
if nth_result > len(elements):
|
78 |
-
return f"Match n°{nth_result} not found (only {len(elements)} matches found)"
|
79 |
-
result = f"Found {len(elements)} matches for '{text}'."
|
80 |
-
elem = elements[nth_result - 1]
|
81 |
-
driver.execute_script("arguments[0].scrollIntoView(true);", elem)
|
82 |
-
result += f"Focused on element {nth_result} of {len(elements)}"
|
83 |
-
return result
|
84 |
-
except Exception as e:
|
85 |
-
return f"Error searching for text: {e}"
|
86 |
-
|
87 |
-
|
88 |
-
@tool
|
89 |
-
def go_back() -> str:
|
90 |
-
"""Navigate back to the previous page in browser history.
|
91 |
-
|
92 |
-
Returns:
|
93 |
-
str: Confirmation message or error description
|
94 |
-
"""
|
95 |
-
try:
|
96 |
-
driver = helium.get_driver()
|
97 |
-
driver.back()
|
98 |
-
return "Navigated back to previous page"
|
99 |
-
except Exception as e:
|
100 |
-
return f"Error going back: {e}"
|
101 |
-
|
102 |
-
|
103 |
-
@tool
|
104 |
-
def close_popups() -> str:
|
105 |
-
"""Close any visible modal or pop-up on the page by sending ESC key.
|
106 |
-
|
107 |
-
Returns:
|
108 |
-
str: Confirmation message or error description
|
109 |
-
"""
|
110 |
-
try:
|
111 |
-
driver = helium.get_driver()
|
112 |
-
webdriver.ActionChains(driver).send_keys(Keys.ESCAPE).perform()
|
113 |
-
return "Attempted to close popups"
|
114 |
-
except Exception as e:
|
115 |
-
return f"Error closing popups: {e}"
|
116 |
-
|
117 |
-
|
118 |
-
@tool
|
119 |
-
def scroll_page(direction: str = "down", amount: int = 3) -> str:
|
120 |
-
"""Scroll the webpage in the specified direction.
|
121 |
-
|
122 |
-
Args:
|
123 |
-
direction: Direction to scroll, either 'up' or 'down'
|
124 |
-
amount: Number of scroll actions to perform
|
125 |
-
|
126 |
-
Returns:
|
127 |
-
str: Confirmation message or error description
|
128 |
-
"""
|
129 |
-
try:
|
130 |
-
driver = helium.get_driver()
|
131 |
-
for _ in range(amount):
|
132 |
-
if direction.lower() == "down":
|
133 |
-
driver.execute_script("window.scrollBy(0, 300);")
|
134 |
-
elif direction.lower() == "up":
|
135 |
-
driver.execute_script("window.scrollBy(0, -300);")
|
136 |
-
sleep(0.5)
|
137 |
-
return f"Scrolled {direction} {amount} times"
|
138 |
-
except Exception as e:
|
139 |
-
return f"Error scrolling: {e}"
|
140 |
-
|
141 |
-
|
142 |
-
@tool
|
143 |
-
def get_page_text() -> str:
|
144 |
-
"""Extract all visible text from the current webpage.
|
145 |
-
|
146 |
-
Returns:
|
147 |
-
str: The visible text content of the page
|
148 |
-
"""
|
149 |
-
try:
|
150 |
-
driver = helium.get_driver()
|
151 |
-
text = driver.find_element(By.TAG_NAME, "body").text
|
152 |
-
return f"Page text (first 2000 chars): {text[:2000]}"
|
153 |
-
except Exception as e:
|
154 |
-
return f"Error getting page text: {e}"
|
155 |
-
|
156 |
-
|
157 |
-
def save_screenshot_callback(memory_step: ActionStep, agent: CodeAgent) -> None:
|
158 |
-
"""Save screenshots for web browser automation"""
|
159 |
-
try:
|
160 |
-
sleep(1.0)
|
161 |
-
driver = helium.get_driver()
|
162 |
-
if driver is not None:
|
163 |
-
# Clean up old screenshots
|
164 |
-
for previous_memory_step in agent.memory.steps:
|
165 |
-
if isinstance(previous_memory_step, ActionStep) and previous_memory_step.step_number <= memory_step.step_number - 2:
|
166 |
-
previous_memory_step.observations_images = None
|
167 |
-
|
168 |
-
png_bytes = driver.get_screenshot_as_png()
|
169 |
-
image = Image.open(BytesIO(png_bytes))
|
170 |
-
memory_step.observations_images = [image.copy()]
|
171 |
-
|
172 |
-
# Update observations with current URL
|
173 |
-
url_info = f"Current url: {driver.current_url}"
|
174 |
-
memory_step.observations = (
|
175 |
-
url_info if memory_step.observations is None
|
176 |
-
else memory_step.observations + "\n" + url_info
|
177 |
-
)
|
178 |
-
except Exception as e:
|
179 |
-
print(f"Error in screenshot callback: {e}")
|
180 |
-
|
181 |
-
|
182 |
class GAIAAgent:
|
183 |
"""
|
184 |
GAIA agent using smolagents with Gemini 2.0 Flash and Langfuse observability
|
@@ -200,6 +163,8 @@ class GAIAAgent:
|
|
200 |
model_id="gemini-2.0-flash",
|
201 |
api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
|
202 |
api_key=gemini_api_key,
|
|
|
|
|
203 |
)
|
204 |
|
205 |
# Store user and session IDs for tracking
|
@@ -207,26 +172,17 @@ class GAIAAgent:
|
|
207 |
self.session_id = session_id or "gaia-session"
|
208 |
|
209 |
# GAIA system prompt from the leaderboard
|
210 |
-
self.system_prompt = """You are a general AI assistant. I will ask you a question. Report your thoughts and
|
211 |
-
|
212 |
-
|
213 |
-
-
|
214 |
-
-
|
215 |
-
-
|
216 |
-
|
217 |
-
|
218 |
-
For document retrieval:
|
219 |
-
- Use the BM25 retriever when there are text documents attached
|
220 |
-
- Search with relevant keywords from the question
|
221 |
-
|
222 |
-
Your final answer should be as few words as possible, a number, or a comma-separated list. Don't use articles, abbreviations, or units unless specified."""
|
223 |
|
224 |
# Initialize retriever tool (will be updated when documents are loaded)
|
225 |
self.retriever_tool = BM25RetrieverTool()
|
226 |
|
227 |
-
# Initialize web driver for browser automation
|
228 |
-
self.driver = None
|
229 |
-
|
230 |
# Create the agent
|
231 |
self.agent = None
|
232 |
self._create_agent()
|
@@ -234,6 +190,13 @@ Your final answer should be as few words as possible, a number, or a comma-separ
|
|
234 |
# Initialize Langfuse client
|
235 |
self.langfuse = Langfuse()
|
236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
237 |
def _setup_langfuse_observability(self):
|
238 |
"""Set up Langfuse observability with OpenTelemetry"""
|
239 |
# Get Langfuse keys from environment variables
|
@@ -271,48 +234,17 @@ Your final answer should be as few words as possible, a number, or a comma-separ
|
|
271 |
"""Create the CodeAgent with tools"""
|
272 |
base_tools = [
|
273 |
self.retriever_tool,
|
274 |
-
|
275 |
-
go_back,
|
276 |
-
close_popups,
|
277 |
-
scroll_page,
|
278 |
-
get_page_text
|
279 |
]
|
280 |
-
|
281 |
self.agent = CodeAgent(
|
282 |
-
tools=base_tools
|
|
|
|
|
|
|
283 |
model=self.model,
|
284 |
-
|
285 |
-
|
286 |
-
additional_authorized_imports=["helium", "requests", "BeautifulSoup", "json"],
|
287 |
-
step_callbacks=[save_screenshot_callback] if self.driver else [],
|
288 |
-
max_steps=5,
|
289 |
-
description=self.system_prompt,
|
290 |
-
verbosity_level=2,
|
291 |
-
)
|
292 |
-
|
293 |
-
def initialize_browser(self):
|
294 |
-
"""Initialize browser for web automation tasks"""
|
295 |
-
try:
|
296 |
-
chrome_options = webdriver.ChromeOptions()
|
297 |
-
chrome_options.add_argument("--force-device-scale-factor=1")
|
298 |
-
chrome_options.add_argument("--window-size=1000,1350")
|
299 |
-
chrome_options.add_argument("--disable-pdf-viewer")
|
300 |
-
chrome_options.add_argument("--window-position=0,0")
|
301 |
-
chrome_options.add_argument("--no-sandbox")
|
302 |
-
chrome_options.add_argument("--disable-dev-shm-usage")
|
303 |
-
|
304 |
-
self.driver = helium.start_chrome(headless=False, options=chrome_options)
|
305 |
-
|
306 |
-
# Recreate agent with browser tools
|
307 |
-
self._create_agent()
|
308 |
-
|
309 |
-
# Import helium for the agent
|
310 |
-
self.agent.python_executor("from helium import *")
|
311 |
|
312 |
-
return True
|
313 |
-
except Exception as e:
|
314 |
-
print(f"Failed to initialize browser: {e}")
|
315 |
-
return False
|
316 |
|
317 |
def load_documents_from_file(self, file_path: str):
|
318 |
"""Load and process documents from a file for BM25 retrieval"""
|
@@ -375,35 +307,22 @@ Your final answer should be as few words as possible, a number, or a comma-separ
|
|
375 |
if task_id:
|
376 |
trace_tags.append(f"task-{task_id}")
|
377 |
|
378 |
-
#
|
379 |
-
with self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
380 |
try:
|
381 |
-
# Set
|
382 |
-
span.
|
383 |
-
span.set_attribute("langfuse.session.id", self.session_id)
|
384 |
-
span.set_attribute("langfuse.tags", trace_tags)
|
385 |
-
span.set_attribute("gaia.task_id", task_id)
|
386 |
-
span.set_attribute("gaia.question_length", len(question))
|
387 |
-
|
388 |
-
# Get trace ID for Langfuse linking
|
389 |
-
current_span = trace.get_current_span()
|
390 |
-
span_context = current_span.get_span_context()
|
391 |
-
trace_id = span_context.trace_id
|
392 |
-
formatted_trace_id = format_trace_id(trace_id)
|
393 |
-
|
394 |
-
# Create Langfuse trace
|
395 |
-
langfuse_trace = self.langfuse.trace(
|
396 |
-
id=formatted_trace_id,
|
397 |
-
name="GAIA Question Solving",
|
398 |
-
input={"question": question, "task_id": task_id},
|
399 |
user_id=self.user_id,
|
400 |
session_id=self.session_id,
|
401 |
-
tags=trace_tags
|
402 |
-
metadata={
|
403 |
-
"model": self.model.model_id,
|
404 |
-
"question_length": len(question),
|
405 |
-
"has_file": bool(task_id)
|
406 |
-
}
|
407 |
)
|
408 |
|
409 |
# Download and load file if task_id provided
|
@@ -412,47 +331,22 @@ Your final answer should be as few words as possible, a number, or a comma-separ
|
|
412 |
file_path = self.download_gaia_file(task_id)
|
413 |
if file_path:
|
414 |
file_loaded = self.load_documents_from_file(file_path)
|
415 |
-
span.set_attribute("gaia.file_loaded", file_loaded)
|
416 |
print(f"Loaded file for task {task_id}")
|
417 |
|
418 |
-
# Check if this requires web browsing
|
419 |
-
web_indicators = ["navigate", "browser", "website", "webpage", "url", "click", "search on"]
|
420 |
-
needs_browser = any(indicator in question.lower() for indicator in web_indicators)
|
421 |
-
span.set_attribute("gaia.needs_browser", needs_browser)
|
422 |
-
|
423 |
-
if needs_browser and not self.driver:
|
424 |
-
print("Initializing browser for web automation...")
|
425 |
-
browser_initialized = self.initialize_browser()
|
426 |
-
span.set_attribute("gaia.browser_initialized", browser_initialized)
|
427 |
-
|
428 |
# Prepare the prompt
|
429 |
prompt = f"""
|
430 |
-
Question: {question}
|
431 |
-
{f'Task ID: {task_id}' if task_id else ''}
|
432 |
-
{f'File loaded: Yes' if file_loaded else 'File loaded: No'}
|
433 |
|
434 |
-
Solve this step by step. Use the available tools to gather information and provide a precise answer.
|
435 |
"""
|
436 |
|
437 |
-
if needs_browser:
|
438 |
-
prompt += "\n" + helium_instructions
|
439 |
-
|
440 |
print("=== AGENT REASONING ===")
|
441 |
result = self.agent.run(prompt)
|
442 |
print("=== END REASONING ===")
|
443 |
|
444 |
-
# Update
|
445 |
-
|
446 |
-
output={"answer": str(result)},
|
447 |
-
end_time=None # Will be set automatically
|
448 |
-
)
|
449 |
-
|
450 |
-
# Add success attributes
|
451 |
-
span.set_attribute("gaia.success", True)
|
452 |
-
span.set_attribute("gaia.answer_length", len(str(result)))
|
453 |
-
|
454 |
-
# Flush Langfuse data
|
455 |
-
self.langfuse.flush()
|
456 |
|
457 |
return str(result)
|
458 |
|
@@ -460,26 +354,14 @@ Solve this step by step. Use the available tools to gather information and provi
|
|
460 |
error_msg = f"Error processing question: {str(e)}"
|
461 |
print(error_msg)
|
462 |
|
463 |
-
# Log error
|
464 |
-
span.
|
465 |
-
|
466 |
-
|
467 |
-
|
468 |
-
langfuse_trace.update(
|
469 |
-
output={"error": error_msg},
|
470 |
-
level="ERROR"
|
471 |
-
)
|
472 |
|
473 |
-
self.langfuse.flush()
|
474 |
return error_msg
|
475 |
-
|
476 |
-
finally:
|
477 |
-
# Clean up browser if initialized
|
478 |
-
if self.driver:
|
479 |
-
try:
|
480 |
-
helium.kill_browser()
|
481 |
-
except:
|
482 |
-
pass
|
483 |
|
484 |
def evaluate_answer(self, question: str, answer: str, expected_answer: str = None) -> Dict[str, Any]:
|
485 |
"""
|
@@ -506,29 +388,20 @@ Provide your rating as JSON: {{"accuracy": X, "completeness": Y, "clarity": Z, "
|
|
506 |
|
507 |
# Try to parse JSON response
|
508 |
import json
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
return {
|
515 |
-
"accuracy": 3,
|
516 |
-
"completeness": 3,
|
517 |
-
"clarity": 3,
|
518 |
-
"overall": 3,
|
519 |
-
"reasoning": "Could not parse evaluation response",
|
520 |
-
"raw_evaluation": evaluation_result
|
521 |
-
}
|
522 |
-
|
523 |
-
except Exception as e:
|
524 |
return {
|
525 |
-
"accuracy":
|
526 |
-
"completeness":
|
527 |
-
"clarity":
|
528 |
-
"overall":
|
529 |
-
"reasoning":
|
530 |
}
|
531 |
|
|
|
532 |
def add_user_feedback(self, trace_id: str, feedback_score: int, comment: str = None):
|
533 |
"""
|
534 |
Add user feedback to a specific trace
|
@@ -566,7 +439,7 @@ if __name__ == "__main__":
|
|
566 |
|
567 |
# Example question
|
568 |
question_data = {
|
569 |
-
"Question": "How many studio albums Mercedes Sosa has published between 2000-2009?",
|
570 |
"task_id": ""
|
571 |
}
|
572 |
|
@@ -575,11 +448,4 @@ if __name__ == "__main__":
|
|
575 |
question_data,
|
576 |
tags=["music-question", "discography"]
|
577 |
)
|
578 |
-
print(f"Answer: {answer}")
|
579 |
-
|
580 |
-
# Evaluate the answer
|
581 |
-
evaluation = agent.evaluate_answer(
|
582 |
-
question_data["Question"],
|
583 |
-
answer
|
584 |
-
)
|
585 |
-
print(f"Evaluation: {evaluation}")
|
|
|
5 |
from langchain.docstore.document import Document
|
6 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
7 |
from langchain_community.retrievers import BM25Retriever
|
8 |
+
from smolagents import CodeAgent, OpenAIServerModel, Tool
|
9 |
+
from smolagents import PythonInterpreterTool, SpeechToTextTool
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
# Langfuse observability imports
|
12 |
from opentelemetry.sdk.trace import TracerProvider
|
|
|
14 |
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
15 |
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
16 |
from opentelemetry import trace
|
|
|
17 |
from langfuse import Langfuse
|
18 |
+
from smolagents import SpeechToTextTool, PythonInterpreterTool
|
19 |
|
20 |
|
21 |
+
import requests
|
22 |
+
from markdownify import markdownify
|
23 |
+
from requests.exceptions import RequestException
|
24 |
+
from smolagents import tool
|
25 |
+
import re
|
26 |
+
|
27 |
+
from concurrent.futures import ThreadPoolExecutor, TimeoutError
|
28 |
+
|
29 |
+
class WebSearchTool(Tool):
|
30 |
+
name = "web_search"
|
31 |
+
description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results."""
|
32 |
+
inputs = {"query": {"type": "string", "description": "The search query to perform."}}
|
33 |
+
output_type = "string"
|
34 |
+
|
35 |
+
def __init__(self, max_results=10, **kwargs):
|
36 |
+
super().__init__()
|
37 |
+
self.max_results = max_results
|
38 |
+
try:
|
39 |
+
from duckduckgo_search import DDGS
|
40 |
+
except ImportError as e:
|
41 |
+
raise ImportError(
|
42 |
+
"You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
|
43 |
+
) from e
|
44 |
+
self.ddgs = DDGS(**kwargs)
|
45 |
+
|
46 |
+
def _perform_search(self, query: str):
|
47 |
+
"""Internal method to perform the actual search."""
|
48 |
+
return self.ddgs.text(query, max_results=self.max_results)
|
49 |
+
|
50 |
+
def forward(self, query: str) -> str:
|
51 |
+
results = []
|
52 |
+
|
53 |
+
# First attempt with timeout
|
54 |
+
with ThreadPoolExecutor(max_workers=1) as executor:
|
55 |
+
try:
|
56 |
+
future = executor.submit(self._perform_search, query)
|
57 |
+
results = future.result(timeout=30) # 30 second timeout
|
58 |
+
except TimeoutError:
|
59 |
+
print("First search attempt timed out after 30 seconds, retrying...")
|
60 |
+
results = []
|
61 |
+
|
62 |
+
# Retry if no results or timeout occurred
|
63 |
+
if len(results) == 0:
|
64 |
+
print("Retrying search...")
|
65 |
+
with ThreadPoolExecutor(max_workers=1) as executor:
|
66 |
+
try:
|
67 |
+
future = executor.submit(self._perform_search, query)
|
68 |
+
results = future.result(timeout=30) # 30 second timeout for retry
|
69 |
+
except TimeoutError:
|
70 |
+
raise Exception("Search timed out after 30 seconds on both attempts. Try a different query.")
|
71 |
+
|
72 |
+
# Final check for results
|
73 |
+
if len(results) == 0:
|
74 |
+
raise Exception("No results found after two attempts! Try a less restrictive/shorter query.")
|
75 |
+
|
76 |
+
postprocessed_results = [f"[{result['title']}]({result['href']})\n{result['body']}" for result in results]
|
77 |
+
return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
|
78 |
+
|
79 |
+
@tool
|
80 |
+
def visit_webpage(url: str) -> str:
|
81 |
+
"""Visits a webpage at the given URL and returns its content as a markdown string.
|
82 |
+
|
83 |
+
Args:
|
84 |
+
url: The URL of the webpage to visit.
|
85 |
+
|
86 |
+
Returns:
|
87 |
+
The content of the webpage converted to Markdown, or an error message if the request fails.
|
88 |
+
"""
|
89 |
+
try:
|
90 |
+
# Send a GET request to the URL
|
91 |
+
response = requests.get(url)
|
92 |
+
response.raise_for_status() # Raise an exception for bad status codes
|
93 |
+
|
94 |
+
# Parse the content as HTML with BeautifulSoup
|
95 |
+
from bs4 import BeautifulSoup
|
96 |
+
soup = BeautifulSoup(response.content, 'html.parser')
|
97 |
+
# Extract text and convert to Markdown
|
98 |
+
content = soup.get_text(separator="\n", strip=True)
|
99 |
+
markdown_content = markdownify(content)
|
100 |
+
# Clean up the markdown content
|
101 |
+
markdown_content = re.sub(r'\n+', '\n', markdown_content) # Remove excessive newlines
|
102 |
+
markdown_content = re.sub(r'\s+', ' ', markdown_content) # Remove excessive spaces
|
103 |
+
markdown_content = markdown_content.strip() # Strip leading/trailing whitespace
|
104 |
+
return markdown_content
|
105 |
+
|
106 |
+
except RequestException as e:
|
107 |
+
return f"Error fetching the webpage: {str(e)}"
|
108 |
+
except Exception as e:
|
109 |
+
return f"An unexpected error occurred: {str(e)}"
|
110 |
+
|
111 |
+
|
112 |
class BM25RetrieverTool(Tool):
|
113 |
"""
|
114 |
BM25 retriever tool for document search when text documents are available
|
|
|
142 |
for i, doc in enumerate(docs)
|
143 |
])
|
144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
class GAIAAgent:
|
146 |
"""
|
147 |
GAIA agent using smolagents with Gemini 2.0 Flash and Langfuse observability
|
|
|
163 |
model_id="gemini-2.0-flash",
|
164 |
api_base="https://generativelanguage.googleapis.com/v1beta/openai/",
|
165 |
api_key=gemini_api_key,
|
166 |
+
temperature=0.0,
|
167 |
+
top_p=1.0,
|
168 |
)
|
169 |
|
170 |
# Store user and session IDs for tracking
|
|
|
172 |
self.session_id = session_id or "gaia-session"
|
173 |
|
174 |
# GAIA system prompt from the leaderboard
|
175 |
+
self.system_prompt = """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
|
176 |
+
|
177 |
+
IMPORTANT :
|
178 |
+
- When you need to find information in a document, use the BM25 retriever tool to search for relevant sections.
|
179 |
+
- When you need to find information in a visited web page, do not use the BM25 retriever tool, but instead use the visit_webpage tool to fetch the content of the page, and then use the retrieved content to answer the question.
|
180 |
+
- In the last step of your reasoning, if you think your reasoning is not able to answer the question, answer the question directy with your internal reasoning, without using the BM25 retriever tool or the visit_webpage tool.
|
181 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
# Initialize retriever tool (will be updated when documents are loaded)
|
184 |
self.retriever_tool = BM25RetrieverTool()
|
185 |
|
|
|
|
|
|
|
186 |
# Create the agent
|
187 |
self.agent = None
|
188 |
self._create_agent()
|
|
|
190 |
# Initialize Langfuse client
|
191 |
self.langfuse = Langfuse()
|
192 |
|
193 |
+
from langfuse import get_client
|
194 |
+
self.langfuse = get_client() # ✅ Use get_client() for v3
|
195 |
+
|
196 |
+
# Store user and session IDs for tracking
|
197 |
+
self.user_id = user_id or "gaia-user"
|
198 |
+
self.session_id = session_id or "gaia-session"
|
199 |
+
|
200 |
def _setup_langfuse_observability(self):
|
201 |
"""Set up Langfuse observability with OpenTelemetry"""
|
202 |
# Get Langfuse keys from environment variables
|
|
|
234 |
"""Create the CodeAgent with tools"""
|
235 |
base_tools = [
|
236 |
self.retriever_tool,
|
237 |
+
visit_webpage,
|
|
|
|
|
|
|
|
|
238 |
]
|
|
|
239 |
self.agent = CodeAgent(
|
240 |
+
tools=base_tools + [
|
241 |
+
SpeechToTextTool(),
|
242 |
+
WebSearchTool(),
|
243 |
+
PythonInterpreterTool()],
|
244 |
model=self.model,
|
245 |
+
description=self.system_prompt,
|
246 |
+
max_steps=6 )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
247 |
|
|
|
|
|
|
|
|
|
248 |
|
249 |
def load_documents_from_file(self, file_path: str):
|
250 |
"""Load and process documents from a file for BM25 retrieval"""
|
|
|
307 |
if task_id:
|
308 |
trace_tags.append(f"task-{task_id}")
|
309 |
|
310 |
+
# Use SDK v3 context manager approach
|
311 |
+
with self.langfuse.start_as_current_span(
|
312 |
+
name="GAIA-Question-Solving",
|
313 |
+
input={"question": question, "task_id": task_id},
|
314 |
+
metadata={
|
315 |
+
"model": self.model.model_id,
|
316 |
+
"question_length": len(question),
|
317 |
+
"has_file": bool(task_id)
|
318 |
+
}
|
319 |
+
) as span:
|
320 |
try:
|
321 |
+
# Set trace attributes using v3 syntax
|
322 |
+
span.update_trace(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
323 |
user_id=self.user_id,
|
324 |
session_id=self.session_id,
|
325 |
+
tags=trace_tags
|
|
|
|
|
|
|
|
|
|
|
326 |
)
|
327 |
|
328 |
# Download and load file if task_id provided
|
|
|
331 |
file_path = self.download_gaia_file(task_id)
|
332 |
if file_path:
|
333 |
file_loaded = self.load_documents_from_file(file_path)
|
|
|
334 |
print(f"Loaded file for task {task_id}")
|
335 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
336 |
# Prepare the prompt
|
337 |
prompt = f"""
|
338 |
+
Question: {question}
|
339 |
+
{f'Task ID: {task_id}' if task_id else ''}
|
340 |
+
{f'File loaded: Yes' if file_loaded else 'File loaded: No'}
|
341 |
|
|
|
342 |
"""
|
343 |
|
|
|
|
|
|
|
344 |
print("=== AGENT REASONING ===")
|
345 |
result = self.agent.run(prompt)
|
346 |
print("=== END REASONING ===")
|
347 |
|
348 |
+
# Update span with result using v3 syntax
|
349 |
+
span.update(output={"answer": str(result)})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
350 |
|
351 |
return str(result)
|
352 |
|
|
|
354 |
error_msg = f"Error processing question: {str(e)}"
|
355 |
print(error_msg)
|
356 |
|
357 |
+
# Log error using v3 syntax
|
358 |
+
span.update(
|
359 |
+
output={"error": error_msg},
|
360 |
+
level="ERROR"
|
361 |
+
)
|
|
|
|
|
|
|
|
|
362 |
|
|
|
363 |
return error_msg
|
364 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
365 |
|
366 |
def evaluate_answer(self, question: str, answer: str, expected_answer: str = None) -> Dict[str, Any]:
|
367 |
"""
|
|
|
388 |
|
389 |
# Try to parse JSON response
|
390 |
import json
|
391 |
+
scores = json.loads(evaluation_result)
|
392 |
+
return scores
|
393 |
+
except json.JSONDecodeError:
|
394 |
+
# If JSON parsing fails, return a default structure
|
395 |
+
print("Failed to parse evaluation result as JSON. Returning default scores.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
396 |
return {
|
397 |
+
"accuracy": 0,
|
398 |
+
"completeness": 0,
|
399 |
+
"clarity": 0,
|
400 |
+
"overall": 0,
|
401 |
+
"reasoning": "Could not parse evaluation result"
|
402 |
}
|
403 |
|
404 |
+
|
405 |
def add_user_feedback(self, trace_id: str, feedback_score: int, comment: str = None):
|
406 |
"""
|
407 |
Add user feedback to a specific trace
|
|
|
439 |
|
440 |
# Example question
|
441 |
question_data = {
|
442 |
+
"Question": "How many studio albums Mercedes Sosa has published between 2000-2009? Search on the English Wikipedia webpage.",
|
443 |
"task_id": ""
|
444 |
}
|
445 |
|
|
|
448 |
question_data,
|
449 |
tags=["music-question", "discography"]
|
450 |
)
|
451 |
+
print(f"Answer: {answer}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -10,7 +10,7 @@ import pandas as pd
|
|
10 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
11 |
|
12 |
# Import your custom agent from agent.py
|
13 |
-
from
|
14 |
|
15 |
# --- Basic Agent Definition ---
|
16 |
# ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
|
@@ -18,7 +18,7 @@ class BasicAgent:
|
|
18 |
def __init__(self):
|
19 |
print("BasicAgent initialized.")
|
20 |
# Initialize your enhanced GAIA agent
|
21 |
-
self.gaia_agent =
|
22 |
|
23 |
async def __call__(self, question: str) -> str:
|
24 |
try:
|
|
|
10 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
11 |
|
12 |
# Import your custom agent from agent.py
|
13 |
+
from agent2 import GAIAAgent
|
14 |
|
15 |
# --- Basic Agent Definition ---
|
16 |
# ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
|
|
|
18 |
def __init__(self):
|
19 |
print("BasicAgent initialized.")
|
20 |
# Initialize your enhanced GAIA agent
|
21 |
+
self.gaia_agent = GAIAAgent()
|
22 |
|
23 |
async def __call__(self, question: str) -> str:
|
24 |
try:
|
appasync.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
import requests
|
4 |
+
import inspect
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
|
8 |
+
# (Keep Constants as is)
|
9 |
+
# --- Constants ---
|
10 |
+
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
11 |
+
|
12 |
+
# Import your custom agent from agent.py
|
13 |
+
from agent2 import GAIAAgent
|
14 |
+
|
15 |
+
# --- Basic Agent Definition ---
|
16 |
+
# ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
|
17 |
+
class BasicAgent:
|
18 |
+
def __init__(self):
|
19 |
+
print("BasicAgent initialized.")
|
20 |
+
# Initialize your enhanced GAIA agent
|
21 |
+
self.gaia_agent = GAIAAgent()
|
22 |
+
|
23 |
+
def __call__(self, question: str) -> str:
|
24 |
+
try:
|
25 |
+
question_data = {
|
26 |
+
"Question": question,
|
27 |
+
"task_id": "basic_agent_task"
|
28 |
+
}
|
29 |
+
answer = self.gaia_agent.solve_gaia_question(question_data)
|
30 |
+
return str(answer)
|
31 |
+
except Exception as e:
|
32 |
+
return e
|
33 |
+
|
34 |
+
def run_and_submit_all( profile: gr.OAuthProfile | None):
|
35 |
+
"""
|
36 |
+
Fetches all questions, runs the BasicAgent on them, submits all answers,
|
37 |
+
and displays the results.
|
38 |
+
"""
|
39 |
+
# --- Determine HF Space Runtime URL and Repo URL ---
|
40 |
+
space_id = os.getenv("SPACE_ID") # Get the SPACE_ID for sending link to the code
|
41 |
+
|
42 |
+
if profile:
|
43 |
+
username= f"{profile.username}"
|
44 |
+
print(f"User logged in: {username}")
|
45 |
+
else:
|
46 |
+
print("User not logged in.")
|
47 |
+
return "Please Login to Hugging Face with the button.", None
|
48 |
+
|
49 |
+
api_url = DEFAULT_API_URL
|
50 |
+
questions_url = f"{api_url}/questions"
|
51 |
+
submit_url = f"{api_url}/submit"
|
52 |
+
|
53 |
+
# 1. Instantiate Agent ( modify this part to create your agent)
|
54 |
+
try:
|
55 |
+
agent = BasicAgent()
|
56 |
+
except Exception as e:
|
57 |
+
print(f"Error instantiating agent: {e}")
|
58 |
+
return f"Error initializing agent: {e}", None
|
59 |
+
# In the case of an app running as a hugging Face space, this link points toward your codebase ( usefull for others so please keep it public)
|
60 |
+
#agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
|
61 |
+
agent_code = "Running on Kaggle"
|
62 |
+
print(agent_code)
|
63 |
+
|
64 |
+
# 2. Fetch Questions
|
65 |
+
print(f"Fetching questions from: {questions_url}")
|
66 |
+
try:
|
67 |
+
response = requests.get(questions_url, timeout=15)
|
68 |
+
response.raise_for_status()
|
69 |
+
questions_data = response.json()
|
70 |
+
if not questions_data:
|
71 |
+
print("Fetched questions list is empty.")
|
72 |
+
return "Fetched questions list is empty or invalid format.", None
|
73 |
+
print(f"Fetched {len(questions_data)} questions.")
|
74 |
+
except requests.exceptions.RequestException as e:
|
75 |
+
print(f"Error fetching questions: {e}")
|
76 |
+
return f"Error fetching questions: {e}", None
|
77 |
+
except requests.exceptions.JSONDecodeError as e:
|
78 |
+
print(f"Error decoding JSON response from questions endpoint: {e}")
|
79 |
+
print(f"Response text: {response.text[:500]}")
|
80 |
+
return f"Error decoding server response for questions: {e}", None
|
81 |
+
except Exception as e:
|
82 |
+
print(f"An unexpected error occurred fetching questions: {e}")
|
83 |
+
return f"An unexpected error occurred fetching questions: {e}", None
|
84 |
+
|
85 |
+
# 3. Run your Agent
|
86 |
+
results_log = []
|
87 |
+
answers_payload = []
|
88 |
+
print(f"Running agent on {len(questions_data)} questions...")
|
89 |
+
for item in questions_data:
|
90 |
+
task_id = item.get("task_id")
|
91 |
+
question_text = item.get("question")
|
92 |
+
print(question_text)
|
93 |
+
if not task_id or question_text is None:
|
94 |
+
print(f"Skipping item with missing task_id or question: {item}")
|
95 |
+
continue
|
96 |
+
try:
|
97 |
+
submitted_answer = agent(question_text)
|
98 |
+
answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
|
99 |
+
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
|
100 |
+
except Exception as e:
|
101 |
+
print(f"Error running agent on task {task_id}: {e}")
|
102 |
+
results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
|
103 |
+
|
104 |
+
if not answers_payload:
|
105 |
+
print("Agent did not produce any answers to submit.")
|
106 |
+
return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
|
107 |
+
|
108 |
+
# 4. Prepare Submission
|
109 |
+
submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
|
110 |
+
status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
|
111 |
+
print(status_update)
|
112 |
+
|
113 |
+
# 5. Submit
|
114 |
+
print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
|
115 |
+
try:
|
116 |
+
response = requests.post(submit_url, json=submission_data, timeout=60)
|
117 |
+
response.raise_for_status()
|
118 |
+
result_data = response.json()
|
119 |
+
final_status = (
|
120 |
+
f"Submission Successful!\n"
|
121 |
+
f"User: {result_data.get('username')}\n"
|
122 |
+
f"Overall Score: {result_data.get('score', 'N/A')}% "
|
123 |
+
f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
|
124 |
+
f"Message: {result_data.get('message', 'No message received.')}"
|
125 |
+
)
|
126 |
+
print("Submission successful.")
|
127 |
+
results_df = pd.DataFrame(results_log)
|
128 |
+
return final_status, results_df
|
129 |
+
except requests.exceptions.HTTPError as e:
|
130 |
+
error_detail = f"Server responded with status {e.response.status_code}."
|
131 |
+
try:
|
132 |
+
error_json = e.response.json()
|
133 |
+
error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
|
134 |
+
except requests.exceptions.JSONDecodeError:
|
135 |
+
error_detail += f" Response: {e.response.text[:500]}"
|
136 |
+
status_message = f"Submission Failed: {error_detail}"
|
137 |
+
print(status_message)
|
138 |
+
results_df = pd.DataFrame(results_log)
|
139 |
+
return status_message, results_df
|
140 |
+
except requests.exceptions.Timeout:
|
141 |
+
status_message = "Submission Failed: The request timed out."
|
142 |
+
print(status_message)
|
143 |
+
results_df = pd.DataFrame(results_log)
|
144 |
+
return status_message, results_df
|
145 |
+
except requests.exceptions.RequestException as e:
|
146 |
+
status_message = f"Submission Failed: Network error - {e}"
|
147 |
+
print(status_message)
|
148 |
+
results_df = pd.DataFrame(results_log)
|
149 |
+
return status_message, results_df
|
150 |
+
except Exception as e:
|
151 |
+
status_message = f"An unexpected error occurred during submission: {e}"
|
152 |
+
print(status_message)
|
153 |
+
results_df = pd.DataFrame(results_log)
|
154 |
+
return status_message, results_df
|
155 |
+
|
156 |
+
|
157 |
+
# --- Build Gradio Interface using Blocks ---
|
158 |
+
with gr.Blocks() as demo:
|
159 |
+
gr.Markdown("# Basic Agent Evaluation Runner")
|
160 |
+
gr.Markdown(
|
161 |
+
"""
|
162 |
+
**Instructions:**
|
163 |
+
|
164 |
+
1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ...
|
165 |
+
2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
|
166 |
+
3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
|
167 |
+
|
168 |
+
---
|
169 |
+
**Disclaimers:**
|
170 |
+
Once clicking on the "submit button, it can take quite some time ( this is the time for the agent to go through all the questions).
|
171 |
+
This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance for the delay process of the submit button, a solution could be to cache the answers and submit in a seperate action or even to answer the questions in async.
|
172 |
+
"""
|
173 |
+
)
|
174 |
+
|
175 |
+
gr.LoginButton()
|
176 |
+
|
177 |
+
run_button = gr.Button("Run Evaluation & Submit All Answers")
|
178 |
+
|
179 |
+
status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
|
180 |
+
# Removed max_rows=10 from DataFrame constructor
|
181 |
+
results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
|
182 |
+
|
183 |
+
run_button.click(
|
184 |
+
fn=run_and_submit_all,
|
185 |
+
outputs=[status_output, results_table]
|
186 |
+
)
|
187 |
+
|
188 |
+
if __name__ == "__main__":
|
189 |
+
print("\n" + "-"*30 + " App Starting " + "-"*30)
|
190 |
+
# Check for SPACE_HOST and SPACE_ID at startup for information
|
191 |
+
#space_host_startup = os.getenv("SPACE_HOST")
|
192 |
+
#space_id_startup = os.getenv("SPACE_ID") # Get SPACE_ID at startup
|
193 |
+
|
194 |
+
#if space_host_startup:
|
195 |
+
#print(f"✅ SPACE_HOST found: {space_host_startup}")
|
196 |
+
#print(f" Runtime URL should be: https://{space_host_startup}.hf.space")
|
197 |
+
#else:
|
198 |
+
#print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
|
199 |
+
|
200 |
+
#if space_id_startup: # Print repo URLs if SPACE_ID is found
|
201 |
+
#print(f"✅ SPACE_ID found: {space_id_startup}")
|
202 |
+
#print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
|
203 |
+
#print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
|
204 |
+
#else:
|
205 |
+
#print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")
|
206 |
+
|
207 |
+
print("-"*(60 + len(" App Starting ")) + "\n")
|
208 |
+
|
209 |
+
print("Launching Gradio Interface for Basic Agent Evaluation...")
|
210 |
+
demo.launch(debug=True, share=True)
|
custom_models.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, List, Any
|
2 |
+
from pydantic import Field, PrivateAttr
|
3 |
+
from llama_index.core.llms import CustomLLM, CompletionResponse, CompletionResponseGen, LLMMetadata
|
4 |
+
from llama_index.core.llms.callbacks import llm_completion_callback
|
5 |
+
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
6 |
+
from qwen_vl_utils import process_vision_info
|
7 |
+
import torch
|
8 |
+
from typing import Any, List, Optional
|
9 |
+
from llama_index.core.embeddings import BaseEmbedding
|
10 |
+
from sentence_transformers import SentenceTransformer
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
class QwenVL7BCustomLLM(CustomLLM):
|
14 |
+
model_name: str = Field(default="Qwen/Qwen2.5-VL-7B-Instruct")
|
15 |
+
context_window: int = Field(default=32768)
|
16 |
+
num_output: int = Field(default=256)
|
17 |
+
_model = PrivateAttr()
|
18 |
+
_processor = PrivateAttr()
|
19 |
+
|
20 |
+
def __init__(self, **kwargs):
|
21 |
+
super().__init__(**kwargs)
|
22 |
+
self._model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
23 |
+
self.model_name, torch_dtype=torch.bfloat16, device_map='balanced'
|
24 |
+
)
|
25 |
+
self._processor = AutoProcessor.from_pretrained(self.model_name)
|
26 |
+
|
27 |
+
@property
|
28 |
+
def metadata(self) -> LLMMetadata:
|
29 |
+
return LLMMetadata(
|
30 |
+
context_window=self.context_window,
|
31 |
+
num_output=self.num_output,
|
32 |
+
model_name=self.model_name,
|
33 |
+
)
|
34 |
+
|
35 |
+
@llm_completion_callback()
|
36 |
+
def complete(
|
37 |
+
self,
|
38 |
+
prompt: str,
|
39 |
+
image_paths: Optional[List[str]] = None,
|
40 |
+
**kwargs: Any
|
41 |
+
) -> CompletionResponse:
|
42 |
+
# Prepare multimodal input
|
43 |
+
messages = [{"role": "user", "content": []}]
|
44 |
+
if image_paths:
|
45 |
+
for path in image_paths:
|
46 |
+
messages[0]["content"].append({"type": "image", "image": path})
|
47 |
+
messages[0]["content"].append({"type": "text", "text": prompt})
|
48 |
+
|
49 |
+
# Tokenize and process
|
50 |
+
text = self._processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
51 |
+
image_inputs, video_inputs = process_vision_info(messages)
|
52 |
+
inputs = self._processor(
|
53 |
+
text=[text],
|
54 |
+
images=image_inputs,
|
55 |
+
videos=video_inputs,
|
56 |
+
padding=True,
|
57 |
+
return_tensors="pt",
|
58 |
+
)
|
59 |
+
inputs = inputs.to(self._model.device)
|
60 |
+
|
61 |
+
# Generate output
|
62 |
+
generated_ids = self._model.generate(**inputs, max_new_tokens=self.num_output)
|
63 |
+
generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
|
64 |
+
output_text = self._processor.batch_decode(
|
65 |
+
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
66 |
+
)[0]
|
67 |
+
return CompletionResponse(text=output_text)
|
68 |
+
|
69 |
+
@llm_completion_callback()
|
70 |
+
def stream_complete(
|
71 |
+
self,
|
72 |
+
prompt: str,
|
73 |
+
image_paths: Optional[List[str]] = None,
|
74 |
+
**kwargs: Any
|
75 |
+
) -> CompletionResponseGen:
|
76 |
+
response = self.complete(prompt, image_paths)
|
77 |
+
for token in response.text:
|
78 |
+
yield CompletionResponse(text=token, delta=token)
|
79 |
+
|
80 |
+
class MultimodalCLIPEmbedding(BaseEmbedding):
|
81 |
+
"""
|
82 |
+
Custom embedding class using CLIP for multimodal capabilities.
|
83 |
+
"""
|
84 |
+
|
85 |
+
def __init__(self, model_name: str = "clip-ViT-B-32", **kwargs: Any) -> None:
|
86 |
+
super().__init__(**kwargs)
|
87 |
+
self._model = SentenceTransformer(model_name)
|
88 |
+
|
89 |
+
@classmethod
|
90 |
+
def class_name(cls) -> str:
|
91 |
+
return "multimodal_clip"
|
92 |
+
|
93 |
+
def _get_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]:
|
94 |
+
if image_path:
|
95 |
+
image = Image.open(image_path)
|
96 |
+
embedding = self._model.encode(image)
|
97 |
+
return embedding.tolist()
|
98 |
+
else:
|
99 |
+
embedding = self._model.encode(query)
|
100 |
+
return embedding.tolist()
|
101 |
+
|
102 |
+
def _get_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]:
|
103 |
+
if image_path:
|
104 |
+
image = Image.open(image_path)
|
105 |
+
embedding = self._model.encode(image)
|
106 |
+
return embedding.tolist()
|
107 |
+
else:
|
108 |
+
embedding = self._model.encode(text)
|
109 |
+
return embedding.tolist()
|
110 |
+
|
111 |
+
def _get_text_embeddings(self, texts: List[str], image_paths: Optional[List[str]] = None) -> List[List[float]]:
|
112 |
+
embeddings = []
|
113 |
+
image_paths = image_paths or [None] * len(texts)
|
114 |
+
|
115 |
+
for text, img_path in zip(texts, image_paths):
|
116 |
+
if img_path:
|
117 |
+
image = Image.open(img_path)
|
118 |
+
emb = self._model.encode(image)
|
119 |
+
else:
|
120 |
+
emb = self._model.encode(text)
|
121 |
+
embeddings.append(emb.tolist())
|
122 |
+
|
123 |
+
return embeddings
|
124 |
+
|
125 |
+
async def _aget_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]:
|
126 |
+
return self._get_query_embedding(query, image_path)
|
127 |
+
|
128 |
+
async def _aget_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]:
|
129 |
+
return self._get_text_embedding(text, image_path)
|
130 |
+
|
131 |
+
# BAAI embedding class
|
132 |
+
# To run on Terminal before running the app, you need to install the FlagEmbedding package.
|
133 |
+
# This can be done by cloning the repository and installing it in editable mode.
|
134 |
+
#!git clone https://github.com/FlagOpen/FlagEmbedding.git
|
135 |
+
#cd FlagEmbedding/research/visual_bge
|
136 |
+
#pip install -e .
|
137 |
+
#go back to the app directory
|
138 |
+
#cd ../../..
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
class BaaiMultimodalEmbedding(BaseEmbedding):
|
143 |
+
"""
|
144 |
+
Custom embedding class using BAAI's FlagEmbedding for multimodal capabilities.
|
145 |
+
Implements the visual_bge Visualized_BGE model with bge-m3 backend.
|
146 |
+
"""
|
147 |
+
|
148 |
+
def __init__(self,
|
149 |
+
model_name_bge: str = "BAAI/bge-m3",
|
150 |
+
model_weight: str = "Visualized_m3.pth",
|
151 |
+
device: str = "cuda:1",
|
152 |
+
**kwargs: Any) -> None:
|
153 |
+
super().__init__(**kwargs)
|
154 |
+
|
155 |
+
# Set device
|
156 |
+
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
|
157 |
+
print(f"BaaiMultimodalEmbedding initializing on device: {self.device}")
|
158 |
+
|
159 |
+
# Import the visual_bge module
|
160 |
+
from visual_bge.modeling import Visualized_BGE
|
161 |
+
self._model = Visualized_BGE(
|
162 |
+
model_name_bge=model_name_bge,
|
163 |
+
model_weight=model_weight
|
164 |
+
)
|
165 |
+
self._model.to(self.device)
|
166 |
+
self._model.eval()
|
167 |
+
print(f"Successfully loaded BAAI Visualized_BGE with {model_name_bge}")
|
168 |
+
|
169 |
+
@classmethod
|
170 |
+
def class_name(cls) -> str:
|
171 |
+
return "baai_multimodal"
|
172 |
+
|
173 |
+
def _get_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]:
|
174 |
+
"""Get embedding for query with optional image"""
|
175 |
+
with torch.no_grad():
|
176 |
+
if hasattr(self._model, 'encode') and hasattr(self._model, 'preprocess_val'):
|
177 |
+
# Using visual_bge
|
178 |
+
if image_path and query:
|
179 |
+
# Combined text and image query
|
180 |
+
embedding = self._model.encode(image=image_path, text=query)
|
181 |
+
elif image_path:
|
182 |
+
# Image only
|
183 |
+
embedding = self._model.encode(image=image_path)
|
184 |
+
else:
|
185 |
+
# Text only
|
186 |
+
embedding = self._model.encode(text=query)
|
187 |
+
else:
|
188 |
+
# Fallback to sentence-transformers
|
189 |
+
if image_path:
|
190 |
+
from PIL import Image
|
191 |
+
image = Image.open(image_path)
|
192 |
+
embedding = self._model.encode(image)
|
193 |
+
else:
|
194 |
+
embedding = self._model.encode(query)
|
195 |
+
|
196 |
+
return embedding.cpu().numpy().tolist() if torch.is_tensor(embedding) else embedding.tolist()
|
197 |
+
|
198 |
+
def _get_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]:
|
199 |
+
"""Get embedding for text with optional image"""
|
200 |
+
return self._get_query_embedding(text, image_path)
|
201 |
+
|
202 |
+
def _get_text_embeddings(self, texts: List[str], image_paths: Optional[List[str]] = None) -> List[List[float]]:
|
203 |
+
"""Get embeddings for multiple texts with optional images"""
|
204 |
+
embeddings = []
|
205 |
+
image_paths = image_paths or [None] * len(texts)
|
206 |
+
|
207 |
+
for text, img_path in zip(texts, image_paths):
|
208 |
+
emb = self._get_text_embedding(text, img_path)
|
209 |
+
embeddings.append(emb)
|
210 |
+
return embeddings
|
211 |
+
|
212 |
+
async def _aget_query_embedding(self, query: str, image_path: Optional[str] = None) -> List[float]:
|
213 |
+
return self._get_query_embedding(query, image_path)
|
214 |
+
|
215 |
+
async def _aget_text_embedding(self, text: str, image_path: Optional[str] = None) -> List[float]:
|
216 |
+
return self._get_text_embedding(text, image_path)
|
217 |
+
|
218 |
+
|
219 |
+
class PixtralQuantizedLLM(CustomLLM):
|
220 |
+
"""
|
221 |
+
Pixtral 12B quantized model implementation for Kaggle compatibility.
|
222 |
+
Uses float8 quantization for memory efficiency.
|
223 |
+
"""
|
224 |
+
|
225 |
+
model_name: str = Field(default="mistralai/Pixtral-12B-2409")
|
226 |
+
context_window: int = Field(default=128000)
|
227 |
+
num_output: int = Field(default=512)
|
228 |
+
quantization: str = Field(default="fp8")
|
229 |
+
_model = PrivateAttr()
|
230 |
+
_processor = PrivateAttr()
|
231 |
+
|
232 |
+
def __init__(self, **kwargs):
|
233 |
+
super().__init__(**kwargs)
|
234 |
+
|
235 |
+
# Check if we're in a Kaggle environment or have limited resources
|
236 |
+
import psutil
|
237 |
+
available_memory = psutil.virtual_memory().available / (1024**3) # GB
|
238 |
+
|
239 |
+
if available_memory < 20: # Less than 20GB RAM
|
240 |
+
print(f"Limited memory detected ({available_memory:.1f}GB), using quantized version")
|
241 |
+
self._load_quantized_model()
|
242 |
+
else:
|
243 |
+
print("Sufficient memory available, attempting full model load")
|
244 |
+
try:
|
245 |
+
self._load_full_model()
|
246 |
+
except Exception as e:
|
247 |
+
print(f"Full model loading failed: {e}, falling back to quantized")
|
248 |
+
self._load_quantized_model()
|
249 |
+
|
250 |
+
def _load_quantized_model(self):
|
251 |
+
"""Load quantized Pixtral model for resource-constrained environments"""
|
252 |
+
try:
|
253 |
+
# Try to use a pre-quantized version from HuggingFace
|
254 |
+
quantized_models = [
|
255 |
+
"RedHatAI/pixtral-12b-FP8-dynamic" ]
|
256 |
+
|
257 |
+
model_loaded = False
|
258 |
+
for model_id in quantized_models:
|
259 |
+
try:
|
260 |
+
print(f"Attempting to load quantized model: {model_id}")
|
261 |
+
|
262 |
+
# Standard quantized model loading
|
263 |
+
from transformers import AutoModelForCausalLM, AutoProcessor
|
264 |
+
self._model = AutoModelForCausalLM.from_pretrained(
|
265 |
+
model_id,
|
266 |
+
torch_dtype=torch.float8,
|
267 |
+
device_map="auto",
|
268 |
+
trust_remote_code=True
|
269 |
+
)
|
270 |
+
self._processor = AutoProcessor.from_pretrained(model_id)
|
271 |
+
|
272 |
+
print(f"Successfully loaded quantized Pixtral: {model_id}")
|
273 |
+
model_loaded = True
|
274 |
+
break
|
275 |
+
|
276 |
+
except Exception as e:
|
277 |
+
print(f"Failed to load {model_id}: {e}")
|
278 |
+
continue
|
279 |
+
|
280 |
+
if not model_loaded:
|
281 |
+
print("All quantized models failed, using CPU-only fallback")
|
282 |
+
self._load_cpu_fallback()
|
283 |
+
|
284 |
+
except Exception as e:
|
285 |
+
print(f"Quantized loading failed: {e}")
|
286 |
+
self._load_cpu_fallback()
|
287 |
+
|
288 |
+
def _load_full_model(self):
|
289 |
+
"""Load full Pixtral model"""
|
290 |
+
from transformers import AutoModelForCausalLM, AutoProcessor
|
291 |
+
|
292 |
+
self._model = AutoModelForCausalLM.from_pretrained(
|
293 |
+
self.model_name,
|
294 |
+
torch_dtype=torch.bfloat16,
|
295 |
+
device_map="auto",
|
296 |
+
trust_remote_code=True
|
297 |
+
)
|
298 |
+
self._processor = AutoProcessor.from_pretrained(self.model_name)
|
299 |
+
|
300 |
+
def _load_cpu_fallback(self):
|
301 |
+
"""Fallback to CPU-only inference"""
|
302 |
+
try:
|
303 |
+
from transformers import AutoModelForCausalLM, AutoProcessor
|
304 |
+
|
305 |
+
self._model = AutoModelForCausalLM.from_pretrained(
|
306 |
+
"microsoft/DialoGPT-medium", # Smaller fallback model
|
307 |
+
torch_dtype=torch.float32,
|
308 |
+
device_map="cpu"
|
309 |
+
)
|
310 |
+
self._processor = AutoProcessor.from_pretrained("microsoft/DialoGPT-medium")
|
311 |
+
print("Using CPU fallback model (DialoGPT-medium)")
|
312 |
+
|
313 |
+
except Exception as e:
|
314 |
+
print(f"CPU fallback failed: {e}")
|
315 |
+
# Use a minimal implementation
|
316 |
+
self._model = None
|
317 |
+
self._processor = None
|
318 |
+
|
319 |
+
@property
|
320 |
+
def metadata(self) -> LLMMetadata:
|
321 |
+
return LLMMetadata(
|
322 |
+
context_window=self.context_window,
|
323 |
+
num_output=self.num_output,
|
324 |
+
model_name=f"{self.model_name}-{self.quantization}",
|
325 |
+
)
|
326 |
+
|
327 |
+
@llm_completion_callback()
|
328 |
+
def complete(
|
329 |
+
self,
|
330 |
+
prompt: str,
|
331 |
+
image_paths: Optional[List[str]] = None,
|
332 |
+
**kwargs: Any
|
333 |
+
) -> CompletionResponse:
|
334 |
+
|
335 |
+
if self._model is None:
|
336 |
+
return CompletionResponse(text="Model not available in current environment")
|
337 |
+
|
338 |
+
try:
|
339 |
+
# Prepare multimodal input if images provided
|
340 |
+
if image_paths and hasattr(self._processor, 'apply_chat_template'):
|
341 |
+
# Handle multimodal input
|
342 |
+
messages = [{"role": "user", "content": []}]
|
343 |
+
|
344 |
+
if image_paths:
|
345 |
+
for path in image_paths[:4]: # Limit to 4 images for memory
|
346 |
+
messages[0]["content"].append({"type": "image", "image": path})
|
347 |
+
|
348 |
+
messages[0]["content"].append({"type": "text", "text": prompt})
|
349 |
+
|
350 |
+
# Process the input
|
351 |
+
inputs = self._processor(messages, return_tensors="pt", padding=True)
|
352 |
+
inputs = {k: v.to(self._model.device) for k, v in inputs.items()}
|
353 |
+
|
354 |
+
# Generate
|
355 |
+
with torch.no_grad():
|
356 |
+
outputs = self._model.generate(
|
357 |
+
**inputs,
|
358 |
+
max_new_tokens=min(self.num_output, 256), # Limit for memory
|
359 |
+
do_sample=True,
|
360 |
+
temperature=0.7,
|
361 |
+
pad_token_id=self._processor.tokenizer.eos_token_id
|
362 |
+
)
|
363 |
+
|
364 |
+
# Decode response
|
365 |
+
response = self._processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
366 |
+
# Extract only the new generated part
|
367 |
+
if len(messages[0]["content"]) > 0:
|
368 |
+
response = response.split(prompt)[-1].strip()
|
369 |
+
|
370 |
+
else:
|
371 |
+
# Text-only fallback
|
372 |
+
inputs = self._processor(prompt, return_tensors="pt", padding=True)
|
373 |
+
inputs = {k: v.to(self._model.device) for k, v in inputs.items()}
|
374 |
+
|
375 |
+
with torch.no_grad():
|
376 |
+
outputs = self._model.generate(
|
377 |
+
**inputs,
|
378 |
+
max_new_tokens=min(self.num_output, 256),
|
379 |
+
do_sample=True,
|
380 |
+
temperature=0.7,
|
381 |
+
pad_token_id=self._processor.tokenizer.eos_token_id
|
382 |
+
)
|
383 |
+
|
384 |
+
response = self._processor.batch_decode(outputs, skip_special_tokens=True)[0]
|
385 |
+
response = response.replace(prompt, "").strip()
|
386 |
+
|
387 |
+
return CompletionResponse(text=response)
|
388 |
+
|
389 |
+
except Exception as e:
|
390 |
+
error_msg = f"Generation error: {str(e)}"
|
391 |
+
print(error_msg)
|
392 |
+
return CompletionResponse(text=error_msg)
|
393 |
+
|
394 |
+
@llm_completion_callback()
|
395 |
+
def stream_complete(
|
396 |
+
self,
|
397 |
+
prompt: str,
|
398 |
+
image_paths: Optional[List[str]] = None,
|
399 |
+
**kwargs: Any
|
400 |
+
) -> CompletionResponseGen:
|
401 |
+
# For quantized models, streaming might not be efficient
|
402 |
+
# Return the complete response as a single chunk
|
403 |
+
response = self.complete(prompt, image_paths, **kwargs)
|
404 |
+
yield response
|