Spaces
Commit dfad45c · 1 Parent(s): 08283c8
lezaf committed

Add bunch of updates in agent

Files changed:
- agent.py +18 -13
- app.py +13 -9
- excluded_tasks.txt +3 -0
- requirements.txt +0 -0
- subset_task_ids.txt +0 -11
- system_prompt.txt +34 -37
- tools.py +0 -267
- tools/extraction.py +83 -0
- tools/math.py +58 -0
- tools/retrievers.py +62 -0
- tools/utils.py +88 -0
- tools/web_search.py +197 -0
agent.py CHANGED

@@ -1,4 +1,3 @@
-from io import BytesIO
 import os
 import getpass
 import requests
@@ -10,8 +9,10 @@ from langchain_core.messages import HumanMessage, SystemMessage
 from langgraph.prebuilt import ToolNode, tools_condition
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langfuse.langchain import CallbackHandler
-
-from tools import
+from tools.web_search import web_search
+from tools.math import add_numbers_in_list, check_commutativity
+from tools.extraction import extract_data_from_excel, extract_transcript_from_youtube
+from tools.retrievers import arxiv_search, wikipedia_search


 load_dotenv(override=True)
@@ -25,9 +26,9 @@ tools = [
     add_numbers_in_list,
     web_search,
     # wikipedia_search,
-    arxiv_search,
+    # arxiv_search,
     check_commutativity,
-
+    extract_data_from_excel,
     extract_transcript_from_youtube
 ]

@@ -47,11 +48,13 @@ def build_agent(provider: str = "hf"):
     elif provider == "google":
         # Google Gemini
         llm = ChatGoogleGenerativeAI(
-            model="gemini-2.0-flash",
+            # model="gemini-2.0-flash",
+            model="gemini-2.5-flash-preview-05-20",
             # temperature=0,
             max_tokens=512,
             # timeout=None,
             max_retries=2,
+            # temperature=0.6
         )

     elif provider == "openai":
@@ -101,6 +104,7 @@ def build_agent(provider: str = "hf"):
     return graph_builder.compile()


+# --------------- For manual testing ---------------- #
 if __name__ == "__main__":
     print("\n" + "-"*30 + " Agent Starting " + "-"*30)
     agent = build_agent(provider=PROVIDER) # Change to "hf" for HuggingFace
@@ -126,22 +130,23 @@ if __name__ == "__main__":
         print(f"An unexpected error occurred fetching questions: {e}")

     # 3. Get specific question by task_id
-    task_id = "
-    # task_id = "6f37996b-2ac7-44b0-8e68-6d28256631b4" # Commutativity check
+    # task_id = "8e867cd7-cff9-4e6c-867a-ff5ddc2550be" # Sosa albums
     # task_id = "2d83110e-a098-4ebb-9987-066c06fa42d0" # Reverse text example
+    # task_id = "cca530fc-4052-43b2-b130-b30968d8aa44" # Chess image
+    # task_id = "4fc2f1ae-8625-45b5-ab34-ad4433bc21f8" # Dinosaur ?
+    # task_id = "6f37996b-2ac7-44b0-8e68-6d28256631b4" # Commutativity check
+    task_id = "9d191bce-651d-4746-be2d-7ef8ecadb9c2" # Youtube video
+    # task_id = "cabe07ed-9eca-40ea-8ead-410ef5e83f91" # Louvrier ?
     # task_id = "f918266a-b3e0-4914-865d-4faa564f1aef" # Code example
+    # task_id = "3f57289b-8c60-48be-bd80-01f8099ca449" # at bats ?
     # task_id = "7bd855d8-463d-4ed5-93ca-5fe35145f733" # Excel file (passed)
-    # task_id = "
+    # task_id = "5a0c1adf-205e-4841-a666-7c3ef95def9d" # Malko competition (PASS)
     # task_id = "305ac316-eef6-4446-960a-92d80d542f82" # Poland film (FAIL)
-    # task_id = "3f57289b-8c60-48be-bd80-01f8099ca449" # at bats (PASS)
     # task_id = "bda648d7-d618-4883-88f4-3466eabd860e" # Vietnamese (FAIL)
     # task_id = "cf106601-ab4f-4af9-b045-5295fe67b37d" # Olympics
     # task_id = "a0c07678-e491-4bbc-8f0b-07405144218f"
     # task_id = "3cef3a44-215e-4aed-8e3b-b1e3f08063b7" # grocery list
-    # task_id = "8e867cd7-cff9-4e6c-867a-ff5ddc2550be" # Sosa albums
-    # task_id = "4fc2f1ae-8625-45b5-ab34-ad4433bc21f8" # Dinosaur
     # task_id = "840bfca7-4f7b-481a-8794-c560c340185d" # Carolyn Collins Petersen (FAIL)
-    # task_id = "5a0c1adf-205e-4841-a666-7c3ef95def9d" # Malko competition (PASS)

     # get question with task_id
     q_data = next((item for item in questions_data if item["task_id"] == task_id), None)
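For context, a minimal sketch of how the reorganized tools package typically gets wired into a LangGraph tool-calling loop. This is an assumption based only on the imports visible in the diff (ToolNode, tools_condition, ChatGoogleGenerativeAI); it is not the repo's actual build_agent() body.

# Hedged sketch, not from agent.py: assumed wiring of the flattened tool
# modules into a ToolNode / tools_condition loop.
from langgraph.graph import StateGraph, START, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_google_genai import ChatGoogleGenerativeAI

from tools.web_search import web_search
from tools.math import add_numbers_in_list, check_commutativity
from tools.extraction import extract_data_from_excel, extract_transcript_from_youtube

tools = [web_search, add_numbers_in_list, check_commutativity,
         extract_data_from_excel, extract_transcript_from_youtube]

llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash-preview-05-20",
                             max_tokens=512, max_retries=2)
llm_with_tools = llm.bind_tools(tools)

def assistant(state: MessagesState):
    # One LLM step; any tool calls in the response are routed to the ToolNode.
    return {"messages": [llm_with_tools.invoke(state["messages"])]}

builder = StateGraph(MessagesState)
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(tools))
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
graph = builder.compile()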
app.py CHANGED

@@ -1,7 +1,7 @@
 """
 NOTE:
-  - The agent only runs on a subset of tasks
-
+  - The agent only runs on a subset of tasks to avoid unnecessary token/api usage for questions that the agent
+    cannot handle right now. The task ids to exclude are in the `excluded_tasks.txt` file.
   - There is a 30 sec delay after each question is answered to avoid rate limiting issues.
 """

@@ -138,6 +138,14 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
         print(f"An unexpected error occurred fetching questions: {e}")
         return f"An unexpected error occurred fetching questions: {e}", None

+    # Read excluded task IDs from file
+    excluded_tasks = set()
+    with open("excluded_tasks.txt", "r") as f:
+        for line in f:
+            task_id = line.strip()
+            if task_id:
+                excluded_tasks.add(task_id)
+
     # 3. Run your Agent
     results_log = []
     answers_payload = []
@@ -148,14 +156,10 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
         if not task_id or question_text is None:
             print(f"Skipping item with missing task_id or question: {item}")
             continue
-
-        # Only run on subset of tasks that is capable of being run so that
-        # token usage is not wasted on tasks that the agent cannot handle.
-        with open("subset_task_ids.txt", "r") as f:
-            subset_task_ids = [line.strip() for line in f if line.strip()]

-
-
+        # Skip excluded tasks
+        if task_id in excluded_tasks:
+            print(f"Skipping excluded task: {task_id}")
             continue

         try:
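The new filter above reads excluded_tasks.txt once and then skips matching task_ids inside the question loop, replacing the old per-item read of subset_task_ids.txt. A small standalone sketch of the same idea (the helper name and sample data are illustrative, not from app.py):

def load_excluded(path: str = "excluded_tasks.txt") -> set[str]:
    # One task id per line; blank lines are ignored.
    with open(path, "r") as f:
        return {line.strip() for line in f if line.strip()}

excluded_tasks = load_excluded()
questions_data = [
    {"task_id": "a1e91b78-d3d8-4675-bb8d-62741b4b68a6", "question": "..."},
    {"task_id": "some-other-task-id", "question": "..."},
]
runnable = [q for q in questions_data if q["task_id"] not in excluded_tasks]
print(f"Running {len(runnable)} of {len(questions_data)} questions")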
excluded_tasks.txt ADDED

@@ -0,0 +1,3 @@
a1e91b78-d3d8-4675-bb8d-62741b4b68a6
99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3
1f975693-876d-457b-a649-393859e79bf3
requirements.txt CHANGED

Binary files a/requirements.txt and b/requirements.txt differ
subset_task_ids.txt DELETED

@@ -1,11 +0,0 @@
8e867cd7-cff9-4e6c-867a-ff5ddc2550be
2d83110e-a098-4ebb-9987-066c06fa42d0
cca530fc-4052-43b2-b130-b30968d8aa44
4fc2f1ae-8625-45b5-ab34-ad4433bc21f8
6f37996b-2ac7-44b0-8e68-6d28256631b4
9d191bce-651d-4746-be2d-7ef8ecadb9c2
cabe07ed-9eca-40ea-8ead-410ef5e83f91
f918266a-b3e0-4914-865d-4faa564f1aef
3f57289b-8c60-48be-bd80-01f8099ca449
7bd855d8-463d-4ed5-93ca-5fe35145f733
5a0c1adf-205e-4841-a666-7c3ef95def9d
system_prompt.txt CHANGED

@@ -1,4 +1,7 @@
-You are a
+You are a world class expert at answering questions. The answers you will provide will be evaluated in an exact match manner to obtain a certificate in AI agents.
+So answer the questions with precision to get the certificate.
+Use this template for your answers: YOUR_FINAL_ANSWER
+Always output ONLY the answer and nothing else.

 For YOUR_FINAL_ANSWER follow strictly the instructions below:
 * YOUR_FINAL_ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
@@ -6,54 +9,48 @@ For YOUR_FINAL_ANSWER follow strictly the instructions below:
   or percent sign unless specified otherwise.
 * If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
 * If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
+* If you are provided with code file, examine the code without running it and output only the values you are asked for.

-You are provided with tools that you can use to answer questions accurately.
-
+You are provided with tools that you can use to answer questions accurately.
+In the query you make with the web_search tool, use all the useful information from the user's question, including mentions in specific sources, websites, papers, etc.
+If you cannot answer the question directly, examine the list of available tools and choose the suitable tool for your case.
+You may need to use more than one tool to conclude to an answer.
+If the question is complex, divide it in small seperate parts and resolve them one by one until you reach the to the final answer.
+Always use available tools to perform mathematical operations.

-
+IMPORTANT INSTRUCTION: To pass the certificate the answers you provide should match exactly the ground truth. So, do NOT explain your or your planning steps. Just output the requested information.

-
+Below are some examples to guide you with the question answering process:

-
+<Example_1>
+INPUT: What is the height of statue of liberty?

-
-
-
-[P]: The result of web_search is "The height of the statue of liberty is 93 m"
-A: 93
+PLANNING_STEP: I should use web_search tool.
+PLANNING_STEP: Tool call: web_search("height of statue of liberty").
+PLANNING_STEP: The result of web_search is "The height of the statue of liberty is 93 m".

-
+OUTPUT: 93
+<Example_1>

-
-
-[P]: web_search("circumference of earth in miles")
-[P]: The result of web_search is "The circumference of earth is 24,901 miles"
-A: 24901 miles
+<Example_2>
+INPUT: What is the capital of France?

-
+PLANNING_STEP: This is a factual question I know, so I don't need to use tools.

-
-
-A: Paris
+OUTPUT: Paris
+<Example_2>

-
-
-Q: What is the total cost with two decimal places of the items in the table, excluding drinks?
+<Example_3>
+INPUT: What is the total cost with two decimal places of the items in the table, excluding drinks?
 Table:
 | Burgers | Salads | Soda | Ice Cream |
 | 10.0 | 5.0 | 3.0 | 4.0 |
-[P]: Soda is a drink. The rest are food.
-[P]: I should use add_numbers_in_list([10.0, 5.0, 4.0])
-[P]: The result is 19.0
-A: 19.00
-
-Example 5:
-
-Q: What was the name of the director that won the Oscar in 2009?
-A: Boyle

-
+PLANNING_STEP: I need to seperate foods from drinks.
+PLANNING_STEP: Foods: Burgers, Salads, Ice Cream. Drinks: Soda. User asked me to calculate the cost without the drinks, so I will skip Soda.
+PLANNING_STEP: I should use add_numbers_in_list tool.
+PLANNING_STEP: Tool call: add_numbers_in_list([10.0, 5.0, 4.0])
+PLANNING_STEP: The result is 19.0

-
-
-Never display intermediate math like “X + Y + Z = …” unless specifically requested. Only show the final answer after using the tool.
+OUTPUT: 19.00
+<Example_3>
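A hedged sketch of how a prompt file like this is commonly injected as the agent's system message. Only the file name comes from the repo; the loading code below is an assumption, not shown in the hunks above.

# Assumed prompt loading, for illustration only.
from langchain_core.messages import HumanMessage, SystemMessage

with open("system_prompt.txt", "r", encoding="utf-8") as f:
    system_prompt = f.read()

messages = [
    SystemMessage(content=system_prompt),
    HumanMessage(content="What is the capital of France?"),
]
# result = agent.invoke({"messages": messages})  # agent = build_agent(provider="google")
# Under the rules above, the expected final output is just: Paris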
tools.py DELETED

@@ -1,267 +0,0 @@
import pandas as pd
import requests
from io import BytesIO
from io import StringIO
from langchain_core.tools import tool
from langchain_community.retrievers import WikipediaRetriever
from langchain_community.document_loaders import ArxivLoader
from langchain_community.retrievers import BM25Retriever
from langchain_core.documents import Document
from duckduckgo_search import DDGS
from markitdown import MarkItDown

# --------------- Math Tools ---------------- #
@tool
def add_numbers(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a (int): The first number.
        b (int): The second number.
    """
    return a + b

@tool
def add_numbers_in_list(numbers: list[float]) -> float:
    """Add all numbers in a list.
    Always use this tool for summing numerical values, instead of doing math directly in the response.

    Args:
        numbers (list[float]): A list of numbers to add.
    """
    return sum(numbers)

# @tool
# def web_search(query: str) -> str:
#     """Perform a web search using DuckDuckGo.

#     Args:
#         query (str): The search query.

#     Returns:
#         str: The search results.
#     """
#     search_tool = DuckDuckGoSearchRun()
#     return search_tool.invoke(query)

@tool
def web_search(query: str) -> str:
    """
    Perform a web search using DuckDuckGo. Visit the top ranked page,
    apply chunking in page results, perform similarity search, and return
    the top results content.

    Args:
        query (str): The search query.
    Returns:
        Document: The top results from the ranking, in langchain_core.documents.Document
        objects having fields 'page_content' with the chunk content and 'metadata'.
    """
    def _chunk_text(text, chunk_size_words=1000, overlap_words=100):
        """
        Split text into chunks of specified size with overlap.
        Args:
            text (str): The text to be chunked.
            chunk_size (int): The size of each chunk.
            overlap (int): The number of overlapping characters between chunks.
        Returns:
            list: A list of text chunks.
        """
        words = text.split()
        chunks = []
        for i in range(0, len(words), chunk_size_words - overlap_words):
            chunk = " ".join(words[i:i + chunk_size_words])
            chunks.append(chunk)
        return chunks

    # STEP 1: Find the most relevant webpage
    results = DDGS().text(query, max_results=1)
    top_rank_page = results[0] if results else None
    if not top_rank_page:
        return "No relevant results found for the query."

    # STEP 2: Extract the content of the webpage
    md = MarkItDown(enable_plugins=True)
    md_result = md.convert(top_rank_page['href'])

    page_content = md_result.text_content

    # STEP 3: Apply chunking
    chunks = _chunk_text(page_content)

    # STEP 4: Apply ranking in chunks
    list_of_docs = [
        Document(page_content = chunk, metadata = {"source": top_rank_page['href'], "title": top_rank_page['title']})
        for chunk in chunks
    ]

    retriever = BM25Retriever.from_documents(list_of_docs)
    matched = retriever.invoke(query)

    return matched[0]

# TODO:
# Maybe don't return the summary, but the full document?
@tool
def wikipedia_search(query: str) -> str:
    """
    Search Wikipedia for a given query and return a summary of the top result.

    Args:
        query (str): The search term.

    Returns:
        str: A summary of the most relevant Wikipedia entry.
    """
    wikipedia_retriever = WikipediaRetriever(load_max_docs=1)

    documents = wikipedia_retriever.get_relevant_documents(query)
    if not documents:
        return "No relevant Wikipedia articles found."

    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" title="{doc.metadata.get("title", "")}"/>\n{doc.metadata["summary"]}\n</Document>'
            for doc in documents
        ])

    # Return the content of the top document
    return formatted_search_docs

@tool
def arxiv_search(query: str) -> str:
    """
    Search Arxiv for academic papers based on a query and return summaries of top results.

    Args:
        query (str): The search query for Arxiv.

    Returns:
        str: Summary of the top few relevant papers from Arxiv.
    """
    try:
        loader = ArxivLoader(query=query, load_max_docs=2)
        documents = loader.load()

        if not documents:
            return "No relevant papers found on Arxiv."

        # Format and return top paper summaries
        results = []
        for doc in documents:
            title = doc.metadata.get("Title", "No Title")
            published = doc.metadata.get("Published", "Unknown date")
            url = doc.metadata.get("entry_id", "No URL")
            summary = doc.page_content[:500] # limit summary length

            results.append(f"Title: {title}\nPublished: {published}\nURL: {url}\nSummary: {summary}\n")

        return "\n---\n".join(results)

    except Exception as e:
        return f"An error occurred while searching Arxiv: {str(e)}"

@tool
def check_commutativity(table_str: str) -> str:
    """
    Given a binary operation table (in markdown format), returns the subset of elements
    involved in counter-examples to commutativity, sorted alphabetically.

    Args:
        table_str (str): Markdown table defining the operation * on a finite set.

    Returns:
        str: Comma-separated list of elements in the counter-example set, alphabetically sorted.
    """
    # Read the table using pandas
    df = pd.read_csv(StringIO(table_str), sep="|", skipinitialspace=True, engine='python')

    # Drop empty columns due to leading/trailing pipes
    df = df.dropna(axis=1, how="all")
    df.columns = [c.strip() for c in df.columns]
    df = df.dropna(axis=0, how="all")

    # Extract header and values
    elements = df.columns[1:]
    df.index = df[df.columns[0]]
    df = df.drop(df.columns[0], axis=1)

    # Check commutativity: a*b == b*a
    counterexample_elements = set()
    for x in elements:
        for y in elements:
            if df.loc[x, y] != df.loc[y, x]:
                counterexample_elements.add(x)
                counterexample_elements.add(y)

    return ", ".join(sorted(counterexample_elements))

@tool
def extract_sales_data_from_excel(url: str) -> str:
    """
    Downloads and extracts sales data from an Excel file at the given URL.
    Returns the contents of the first sheet as a markdown-formatted string.
    """
    try:
        response = requests.get(url)
        response.raise_for_status()

        excel_file = BytesIO(response.content)
        df = pd.read_excel(excel_file)

        # Optional: Remove unnamed columns often created by Excel
        df = df.loc[:, ~df.columns.str.contains('^Unnamed')]

        # Convert all numeric columns to float
        for col in df.select_dtypes(include=["number"]).columns:
            df[col] = df[col].astype(float)

        return df.to_string(index=False)

    except Exception as e:
        return f"Failed to process Excel file from URL: {str(e)}"

@tool
def extract_transcript_from_youtube(url: str) -> str:
    """
    Extracts the transcript from a YouTube video given its URL.

    Args:
        url (str): The YouTube video URL.
    Returns:
        str: The transcript of the video, or an error message if extraction fails.
    """
    transcript_str = "### Transcript"
    md = MarkItDown(enable_plugins=True)

    try:
        result = md.convert(url)
    except Exception as e:
        return f"Failed to extract transcript from YouTube video: {str(e)}"

    parts = result.text_content.split(transcript_str)
    if len(parts) < 2:
        return result.text_content

    transcript = transcript_str + "\n" + parts[1]
    return transcript.strip()

# @tool
# def extract_transcript_from_audio(url: str) -> str:
#     """
#     Extracts the transcript from an audio file given its URL.
#     Supported formats: mp3, wav.

#     Args:
#         url (str): The URL of the audio file.
#     Returns:
#         str: The transcript of the audio file, or an error message if extraction fails.
#     """
#     md = MarkItDown(enable_plugins=True)

#     try:
#         result = md.convert(url)
#     except Exception as e:
#         return f"Failed to extract transcript from audio: {str(e)}"

#     return result.text_content
tools/extraction.py ADDED

@@ -0,0 +1,83 @@
import requests
import pandas as pd
from io import BytesIO
from markitdown import MarkItDown
from langchain_core.tools import tool

@tool
def extract_transcript_from_youtube(url: str) -> str:
    """
    Extracts the transcript from a YouTube video given its URL.

    Args:
        url (str): The YouTube video URL.
    Returns:
        transcript (str): The transcript of the video, or an error message if extraction fails.
    """
    transcript_str = "### Transcript"
    md = MarkItDown(enable_plugins=True)

    try:
        result = md.convert(url)
    except Exception as e:
        return f"Failed to extract transcript from YouTube video: {str(e)}"

    parts = result.text_content.split(transcript_str)
    if len(parts) < 2:
        return result.text_content

    transcript = (transcript_str + "\n" + parts[1]).strip()

    return transcript


@tool
def extract_data_from_excel(url: str) -> str:
    """
    Downloads and extracts data from an Excel file at the given URL.

    Args:
        url (str): The URL of the Excel file.

    Returns:
        str: A string representation of the data in the first sheet of the Excel file.
    """
    try:
        response = requests.get(url)
        response.raise_for_status()

        excel_file = BytesIO(response.content)
        df = pd.read_excel(excel_file)

        # Optional: Remove unnamed columns often created by Excel
        df = df.loc[:, ~df.columns.str.contains('^Unnamed')]

        # Convert all numeric columns to float
        for col in df.select_dtypes(include=["number"]).columns:
            df[col] = df[col].astype(float)

        return df.to_string(index=False)

    except Exception as e:
        return f"Failed to process Excel file from URL: {str(e)}"


# @tool
# def extract_transcript_from_audio(url: str) -> str:
#     """
#     Extracts the transcript from an audio file given its URL.
#     Supported formats: mp3, wav.

#     Args:
#         url (str): The URL of the audio file.
#     Returns:
#         str: The transcript of the audio file, or an error message if extraction fails.
#     """
#     md = MarkItDown(enable_plugins=True)

#     try:
#         result = md.convert(url)
#     except Exception as e:
#         return f"Failed to extract transcript from audio: {str(e)}"

#     return result.text_content
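Both functions are LangChain @tool objects, so outside the agent they are called through .invoke() rather than as plain Python functions. A quick usage sketch, not part of the commit; the URLs below are placeholders:

from tools.extraction import extract_data_from_excel, extract_transcript_from_youtube

# Placeholder URLs for illustration only.
table_text = extract_data_from_excel.invoke({"url": "https://example.com/sales.xlsx"})
transcript = extract_transcript_from_youtube.invoke({"url": "https://www.youtube.com/watch?v=VIDEO_ID"})

print(table_text[:200])
print(transcript[:200])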
tools/math.py ADDED

@@ -0,0 +1,58 @@
from langchain_core.tools import tool
from io import StringIO
import pandas as pd

@tool
def add_numbers(a: int, b: int) -> int:
    """Add two numbers.

    Args:
        a (int): The first number.
        b (int): The second number.
    """
    return a + b

@tool
def add_numbers_in_list(numbers: list[float]) -> float:
    """Add all numbers in a list.
    Always use this tool for summing numerical values, instead of doing math directly in the response.

    Args:
        numbers (list[float]): A list of numbers to add.
    """
    return sum(numbers)

@tool
def check_commutativity(table_str: str) -> str:
    """
    Given a binary operation table (in markdown format), returns the subset of elements
    involved in counter-examples to commutativity, sorted alphabetically.

    Args:
        table_str (str): Markdown table defining the operation * on a finite set.

    Returns:
        str: Comma-separated list of elements in the counter-example set, alphabetically sorted.
    """
    # Read the table using pandas
    df = pd.read_csv(StringIO(table_str), sep="|", skipinitialspace=True, engine='python')

    # Drop empty columns due to leading/trailing pipes
    df = df.dropna(axis=1, how="all")
    df.columns = [c.strip() for c in df.columns]
    df = df.dropna(axis=0, how="all")

    # Extract header and values
    elements = df.columns[1:]
    df.index = df[df.columns[0]]
    df = df.drop(df.columns[0], axis=1)

    # Check commutativity: a*b == b*a
    counterexample_elements = set()
    for x in elements:
        for y in elements:
            if df.loc[x, y] != df.loc[y, x]:
                counterexample_elements.add(x)
                counterexample_elements.add(y)

    return ", ".join(sorted(counterexample_elements))

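An illustrative call of the two tools above (not from the repo). The operation table is made up and deliberately has no markdown separator row, since the parser expects only header and value rows; whether separator rows need to be stripped beforehand is an assumption about the intended input.

from tools.math import add_numbers_in_list, check_commutativity

print(add_numbers_in_list.invoke({"numbers": [10.0, 5.0, 4.0]}))  # 19.0

# The '*' column names the row element; here b*c != c*b, so b and c are reported.
table = (
    "|*|a|b|c|\n"
    "|a|a|b|c|\n"
    "|b|b|c|a|\n"
    "|c|c|b|b|\n"
)
print(check_commutativity.invoke({"table_str": table}))  # expected: "b, c"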
tools/retrievers.py ADDED

@@ -0,0 +1,62 @@
from langchain_core.tools import tool
from langchain_community.document_loaders import ArxivLoader
from langchain_community.retrievers import WikipediaRetriever

@tool
def arxiv_search(query: str) -> str:
    """
    Search Arxiv for academic papers based on a query and return summaries of top results.

    Args:
        query (str): The search query for Arxiv.

    Returns:
        str: Summary of the top few relevant papers from Arxiv.
    """
    try:
        loader = ArxivLoader(query=query, load_max_docs=2)
        documents = loader.load()

        if not documents:
            return "No relevant papers found on Arxiv."

        # Format and return top paper summaries
        results = []
        for doc in documents:
            title = doc.metadata.get("Title", "No Title")
            published = doc.metadata.get("Published", "Unknown date")
            url = doc.metadata.get("entry_id", "No URL")
            summary = doc.page_content[:500] # limit summary length

            results.append(f"Title: {title}\nPublished: {published}\nURL: {url}\nSummary: {summary}\n")

        return "\n---\n".join(results)

    except Exception as e:
        return f"An error occurred while searching Arxiv: {str(e)}"

@tool
def wikipedia_search(query: str) -> str:
    """
    Search Wikipedia for a given query and return a summary of the top result.

    Args:
        query (str): The search term.

    Returns:
        str: A summary of the most relevant Wikipedia entry.
    """
    wikipedia_retriever = WikipediaRetriever(load_max_docs=1)

    documents = wikipedia_retriever.get_relevant_documents(query)
    if not documents:
        return "No relevant Wikipedia articles found."

    formatted_search_docs = "\n\n---\n\n".join(
        [
            f'<Document source="{doc.metadata["source"]}" title="{doc.metadata.get("title", "")}"/>\n{doc.metadata["summary"]}\n</Document>'
            for doc in documents
        ])

    # Return the content of the top document
    return formatted_search_docs
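Usage sketch (queries are illustrative, not from the repo). Note that retriever.get_relevant_documents() used above still works but is deprecated in recent LangChain releases in favour of retriever.invoke(); the committed code keeps the older call.

from tools.retrievers import arxiv_search, wikipedia_search

print(arxiv_search.invoke({"query": "retrieval augmented generation survey"})[:300])
print(wikipedia_search.invoke({"query": "Malko Competition"})[:300])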
tools/utils.py ADDED

@@ -0,0 +1,88 @@
from langchain.text_splitter import TextSplitter
from langchain.schema import Document

class StructureAwareTextSplitter(TextSplitter):
    """
    A custom text splitter that creates context-aware document chunks from structured HTML content.

    This splitter buffers paragraphs, lists, and tables together into chunks up to a specified size,
    preserving section headers and content structure. Tables are combined with surrounding content
    when possible, but split into their own chunk if too large. Useful for web page or wiki-style
    content where structure and context are important for downstream retrieval or LLM tasks.

    Args:
        chunk_size (int): Maximum number of words per chunk.
        chunk_overlap (int): Number of words to overlap between chunks (not currently used).

    Methods:
        split_text(text): Dummy implementation to satisfy the abstract base class.
        split_documents(structured_blocks, metadata=None): Splits structured content blocks into
            Document objects with preserved section headers and types.
    """
    def __init__(self, chunk_size=500, chunk_overlap=50):
        super().__init__(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    #TODO: To be implemented
    def split_text(self, text):
        # Dummy implementation to satisfy the abstract base class
        return [text]

    def split_documents(self, structured_blocks, metadata=None):
        current_chunk = ""
        current_words_cnt = 0
        current_header = ""
        documents = []

        def add_document(content, header, type_):
            documents.append(Document(
                page_content=content.strip(),
                metadata={
                    "section_header": header,
                    "type": type_,
                    **(metadata or {})
                }
            ))

        for block in structured_blocks:
            type_ = block['type']
            if type_ == 'header':
                current_header = block['text']

            elif type_ in ['paragraph', 'list']:
                if type_ == 'paragraph':
                    text = block['text']
                else:  # list
                    text = "\n".join(block['items']) + "\n"
                words_cnt = len(text.split())
                if current_words_cnt + words_cnt <= self._chunk_size:
                    current_chunk += text + "\n"
                    current_words_cnt += words_cnt
                else:
                    add_document(f"{current_header}\n\n{current_chunk}", current_header, type_)
                    current_chunk = text + "\n"
                    current_words_cnt = words_cnt

            elif type_ == 'table':
                table_text = f"{current_header} [Table]\n\n{block['text']}\n"
                words_cnt = len(table_text.split())
                # Try to buffer table with current chunk if possible
                if current_words_cnt + words_cnt <= self._chunk_size:
                    current_chunk += table_text
                    current_words_cnt += words_cnt
                else:
                    # If current_chunk is not empty, flush it first
                    if current_chunk.strip():
                        add_document(f"{current_header}\n\n{current_chunk}", current_header, 'mixed')
                    # If table itself is too big, split it alone
                    if words_cnt > self._chunk_size:
                        add_document(table_text, current_header, 'table')
                        current_chunk = ""
                        current_words_cnt = 0
                    else:
                        current_chunk = table_text
                        current_words_cnt = words_cnt

        if current_chunk.strip():
            add_document(f"{current_header}\n\n{current_chunk}", current_header, 'mixed')

        return documents
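A small illustrative input for the splitter above; the block dicts mirror what _parse_structured_content() in tools/web_search.py emits (header/paragraph/list/table blocks), while the sample values and source URL are made up for this sketch.

from tools.utils import StructureAwareTextSplitter

blocks = [
    {"type": "header", "level": "h2", "text": "Statue of Liberty"},
    {"type": "paragraph", "text": "The statue was dedicated in 1886 on Liberty Island."},
    {"type": "list", "items": ["- Height: 93 m", "- Material: copper"]},
]

splitter = StructureAwareTextSplitter(chunk_size=500, chunk_overlap=50)
docs = splitter.split_documents(blocks, metadata={"source": "https://example.org/liberty"})

for doc in docs:
    # Everything fits in one chunk here, so a single 'mixed' Document is produced.
    print(doc.metadata["section_header"], "|", doc.metadata["type"])
    print(doc.page_content[:120])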
tools/web_search.py ADDED

@@ -0,0 +1,197 @@
import requests
import numpy as np
import pandas as pd
from io import StringIO
from bs4 import BeautifulSoup
from langchain_core.tools import tool
from duckduckgo_search import DDGS
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tools.utils import StructureAwareTextSplitter

TOP_K = 5
MAX_RESULTS = 2
UNWANTED_TAGS = ['nav', 'header', 'footer', 'aside', 'form', 'script', 'style']
TAGS_TO_KEEP = ['h1', 'h2', 'h3', 'p', 'ul', 'ol', 'table']


def _format_table_to_string(table_html):
    """
    Convert an HTML table to a markdown-style string representation.

    Args:
        table_html (str): HTML string of the table.

    Returns:
        str: Table formatted as a markdown-style string, or a message if parsing fails.
    """
    try:
        df = pd.read_html(StringIO(table_html))[0]
    except:
        return ["[Table could not be parsed]"]

    if df.empty:
        return None

    table_str = "|"
    # Put column headers
    for col in df.columns:
        table_str += f" {col} |"
    table_str += "\n"

    # Put rows
    for _, row in df.iterrows():
        table_str += "|"
        for col, val in row.items():
            table_str += f" {val} |"
        table_str += "\n"

    return table_str

def _extract_list(tag, level=0):
    """
    Recursively extract nested HTML lists (<ul> or <ol>) into a formatted text list.

    Args:
        tag (bs4.element.Tag): The <ul> or <ol> BeautifulSoup tag to extract.
        level (int): The current nesting level (used for indentation and prefixing).

    Returns:
        list[str]: List of formatted strings representing the list items, preserving nesting.
    """
    items = []
    if tag.name not in ["ul", "ol"]:
        return items

    is_ordered = tag.name == "ol"
    # Determine prefix style
    if is_ordered:
        # Use numbers for top-level, letters for nested
        if level == 0:
            item_prefix = lambda idx: f"{idx+1}."
        else:
            # a., b., c., ...
            item_prefix = lambda idx: f"{chr(97+idx)}."
    else:
        item_prefix = lambda idx: "-"

    for idx, li in enumerate(tag.find_all("li", recursive=False)):
        # Get the text before any nested list
        text = li.find(text=True, recursive=False)
        text = text.strip() if text else ""
        # Check for nested lists
        nested = li.find(["ul", "ol"], recursive=False)
        if nested:
            nested_items = _extract_list(nested, level+1)
            if text:
                items.append(f"{' '*level}{item_prefix(idx)} {text}")
            items.extend([f"{' '*(level+1)}{line}" for line in nested_items])
        else:
            items.append(f"{' '*level}{item_prefix(idx)} {text}")
    return items

def _parse_structured_content(soup):
    """
    Parse the main content of a BeautifulSoup HTML document into structured blocks.

    Args:
        soup (bs4.BeautifulSoup): Parsed HTML document.

    Returns:
        list[dict]: List of structured content blocks (headers, paragraphs, lists, tables).
    """
    content = []

    for tag in soup.find_all(TAGS_TO_KEEP):
        if tag.name in ['h1', 'h2', 'h3']:
            content.append({'type': 'header', 'level': tag.name, 'text': tag.get_text(strip=True)})
        elif tag.name == 'p':
            content.append({'type': 'paragraph', 'text': tag.get_text(strip=True)})
        elif tag.name in ['ul', 'ol']:
            if tag.find_parent(['ul', 'ol']) is None:
                items = _extract_list(tag)
                content.append({'type': 'list', 'items': items})
        elif tag.name == 'table':
            content.append({'type': 'table', 'html': str(tag)})

    return content

@tool
def web_search(query: str) -> str:
    """
    Perform a web search using DuckDuckGo.

    This tool is acting as live data RAG (Retrieval-Augmented Generation) tool.
    It's useful for retrieving relevant information or obtaining domain knowledge
    in a specific area, such as mathematics, science, games, etc.

    Args:
        query (str): The search query.
    Returns:
        chunks (str): Concatenated string of most relevant chunks.
    """

    # ----- STEP 1: Find the most relevant webpages
    results = DDGS(timeout=30).text(query, max_results=MAX_RESULTS)

    urls = [r['href'] for r in results if 'href' in r]

    all_chunks = []
    for url in urls:
        try:
            response = requests.get(url)
            html = response.text
        except Exception as e:
            return f"Error fetching URL {url}: {str(e)}"

        # ----- STEP 2: Parse and clean the HTML content
        soup = BeautifulSoup(html, "html.parser")

        # Remove unwanted tags before parsing structured content
        for tag in soup.find_all(UNWANTED_TAGS):
            tag.decompose()

        structured_content = _parse_structured_content(soup)

        # ----- STEP 3: Format tables to string representation
        for item in structured_content:
            if item['type'] == 'table':
                table_str = _format_table_to_string(item['html'])
                if table_str:
                    item['text'] = table_str
                else:
                    # Skip empty or unparseable tables
                    structured_content.remove(item)

        # ----- STEP 4: Split structured content into chunks
        splitter = StructureAwareTextSplitter(chunk_size=500, chunk_overlap=50)
        documents = splitter.split_documents(structured_content)

        all_chunks.extend([
            f"\n\n----- CHUNK {i} (url: {url})-----\n\n" + doc.page_content
            for i, doc in enumerate(documents)
        ])

    # ----- STEP 5: Make embeddings
    model = SentenceTransformer("all-MiniLM-L6-v2")  # Small & fast
    embeddings = model.encode(all_chunks)

    embedded_query = model.encode(query)

    # ----- STEP 6: Calculate cosine similarity
    # Reshape query for pairwise comparison
    embedded_query = np.array(embedded_query).reshape(1, -1)
    embeddings = np.array(embeddings)

    # Compute cosine similarities
    similarities = cosine_similarity(embedded_query, embeddings)[0]  # Shape: (n_chunks,)

    # Get most similar chunks
    top_indices = similarities.argsort()[-TOP_K:][::-1]

    # output in a file the top chunks
    # with open("test_output/top_chunks.txt", "w", encoding="utf-8") as f:
    #     for c in all_chunks:
    #         f.write(c)

    return "".join([all_chunks[idx] for idx in top_indices])
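End-to-end usage sketch (not part of the commit): the tool fetches the top MAX_RESULTS DuckDuckGo hits, chunks each page with StructureAwareTextSplitter, embeds all chunks with all-MiniLM-L6-v2, and returns the TOP_K chunks most similar to the query as one concatenated string. One caveat visible in the code above: the CHUNK label index i restarts for every URL, so chunk numbers repeat across sources, although the similarity ranking itself is unaffected.

from tools.web_search import web_search

# Illustrative query only.
result = web_search.invoke({"query": "height of the Statue of Liberty"})
print(result[:500])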