Pycrolis committed
Commit e049457
2 Parent(s): d8b8674 508a421

Merge branch 'feat/add-tools'
README.md CHANGED
@@ -15,10 +15,21 @@ short_description: Agent for the final hands-on assignment of the Agents course
 
 ## Requirements
 
-To run this project, you need to have an OpenAI API key. Set your `OPENAI_API_KEY` as an environment variable:
+### Prerequisites
+- Python 3.x
+- Virtual environment (recommended)
 
+### API Keys
+This project requires the following API keys:
+- OpenAI API key
+- Tavily API key
+
+Set them as environment variables:
 ```bash
 export OPENAI_API_KEY='your-api-key'
+export TAVILY_API_KEY='your-api-key'
 ```
 
-You can get your OpenAI API key from [here](https://platform.openai.com/account/api-keys).
+You can get the required API keys here:
+- OpenAI API key: [OpenAI Platform](https://platform.openai.com/account/api-keys)
+- Tavily API key: [Tavily](https://tavily.com)
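
A quick sanity check (not part of the commit) that both keys are actually visible to the Python process before running the agent:

```python
# Hypothetical helper, not in the repo: fail fast on missing keys.
import os

for key in ("OPENAI_API_KEY", "TAVILY_API_KEY"):
    if not os.environ.get(key):
        raise SystemExit(f"{key} is not set")
print("All required API keys are set.")
```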
ShrewdAgent.py CHANGED
@@ -4,6 +4,7 @@ from typing import TypedDict, Annotated, Optional, Any, Callable, Sequence, Unio
 from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
 from langchain_core.tools import BaseTool
 from langchain_openai import ChatOpenAI
+from langchain_tavily import TavilySearch
 from langgraph.constants import START
 from langgraph.errors import GraphRecursionError
 from langgraph.graph import add_messages, StateGraph
@@ -12,6 +13,11 @@ from langgraph.pregel import PregelProtocol
 from loguru import logger
 from pydantic import SecretStr
 
+from tools.produce_classifier import produce_classifier
+from tools.web_page_information_extractor import web_page_information_extractor
+from tools.wikipedia_search import wikipedia_search
+from tools.youtube_transcript import youtube_transcript
+
 
 class AgentState(TypedDict):
     messages: Annotated[list[AnyMessage], add_messages]
@@ -32,7 +38,13 @@ class ShrewdAgent:
     Important: Your final output must be only a number or a short phrase, with no additional text or explanation."""
 
     def __init__(self):
-        self.tools = []
+        self.tools = [
+            TavilySearch(),
+            wikipedia_search,
+            web_page_information_extractor,
+            youtube_transcript,
+            produce_classifier,
+        ]
         self.llm = ChatOpenAI(
             model="gpt-4o-mini",
             temperature=0,
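
The hunk above registers the tool list but does not show how it is wired into the graph. Below is a minimal sketch of the usual LangGraph pattern, using the standard `bind_tools`/`ToolNode` APIs; this wiring is an assumption, not shown in the commit:

```python
# Sketch (assumption, not from the commit): typical LangGraph tool wiring.
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import ToolNode

from tools.wikipedia_search import wikipedia_search
from tools.youtube_transcript import youtube_transcript

tools = [wikipedia_search, youtube_transcript]
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

llm_with_tools = llm.bind_tools(tools)  # lets the model emit tool calls
tool_node = ToolNode(tools)             # graph node that executes those calls
```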
requirements.txt CHANGED
@@ -5,4 +5,9 @@ langchain-core~=0.3.60
 langchain-openai~=0.3.17
 langgraph~=0.4.5
 loguru~=0.7.3
-pydantic~=2.11.4
+pydantic~=2.11.4
+html2text~=2025.4.15
+beautifulsoup4~=4.13.4
+readability-lxml~=0.8.4.1
+youtube-transcript-api~=1.0.3
+wikipedia~=1.4.0
tools/produce_classifier.py ADDED
@@ -0,0 +1,44 @@
+import os
+
+from langchain_core.tools import tool
+from langchain_openai import ChatOpenAI
+from loguru import logger
+from pydantic import SecretStr
+
+
+@tool("produce_classifier_tool", parse_docstring=True)
+def produce_classifier(food_name: str) -> str:
+    """
+    Classifies a food item as either 'fruit' or 'vegetable' from a botanical perspective.
+
+    Args:
+        food_name (str): The name of the food item to classify.
+
+    Returns:
+        str: The classification of the food item, either 'fruit' or 'vegetable'.
+    """
+    logger.info(f"use produce_classifier_tool with param: {food_name}")
+    chat = ChatOpenAI(
+        model="gpt-4o-mini",
+        temperature=0,
+        api_key=SecretStr(os.environ['OPENAI_API_KEY'])
+    )
+
+    prompt = (f"From a botanical perspective, classify {food_name} as either 'fruit' or 'vegetable'. "
+              # f"If it's not a produce name, classify as 'neither'. "
+              f"Respond only with a JSON in this exact format: "
+              f"{{\"name\": \"{food_name}\", \"kind\": \"[classification]\"}}, "
+              f"where [classification] should be replaced with just 'fruit' or 'vegetable'. "
+              f"No other text or explanation.")
+    return chat.invoke(prompt).content
+
+def _print_produce_kind(food_name: str):
+    print(f'{food_name}: {produce_classifier.invoke(food_name)}')
+
+if __name__ == "__main__":
+    # _print_produce_kind("orange")
+    # _print_produce_kind("sweet potatoes")
+    # _print_produce_kind("bell pepper")
+    # _print_produce_kind("egg")
+    # _print_produce_kind("table")
+    _print_produce_kind("zucchini")
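
For reference, the output this prompt asks for should look like the JSON below (illustrative only; the model's actual response can vary):

```python
# Illustrative usage; the exact model output is not guaranteed.
from tools.produce_classifier import produce_classifier

print(produce_classifier.invoke("zucchini"))
# e.g. {"name": "zucchini", "kind": "fruit"}  (botanically, zucchini is a fruit)
```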
tools/web_page_information_extractor.py ADDED
@@ -0,0 +1,106 @@
+import os
+from io import StringIO
+
+import html2text
+import pandas as pd
+import requests
+from bs4 import BeautifulSoup
+from langchain_core.messages import SystemMessage, HumanMessage
+from langchain_core.tools import tool
+from langchain_openai import ChatOpenAI
+from loguru import logger
+from pydantic import SecretStr
+from readability import Document
+
+
+@tool("web_page_information_extractor_tool", parse_docstring=True)
+def web_page_information_extractor(url: str, request: str) -> str:
+    """
+    Extracts specific information from a web page based on the user's request.
+
+    This function uses a language model to extract information from the content
+    of a web page specified by the URL. The user's request specifies the type of
+    information to be extracted. The function returns the extracted information as
+    a JSON string.
+
+    Args:
+        url (str): The URL of the web page to extract information from.
+        request (str): The user's request describing the information to extract.
+
+    Returns:
+        str: The extracted information in JSON format.
+    """
+    logger.info(f"use web_page_information_extractor with param: url:{url}, request:{request}")
+    response = requests.get(url, headers={"User-Agent": "Mozilla/5.0"})
+    response.raise_for_status()  # raises HTTPError for bad responses
+    html = response.text
+    doc = Document(html)
+    cleaned_html = doc.summary()
+
+    soup = BeautifulSoup(cleaned_html, "html.parser")
+
+    # Extract wiki-style tables, which html2text handles poorly
+    tables = soup.find_all('table', class_='wikitable')
+    tables_text = ""
+
+    for i, table in enumerate(tables, 1):
+        # Find the nearest preceding h2 or h3 header
+        header = table.find_previous(['h2', 'h3'])
+        section_title = header.get_text().strip() if header else "Untitled Section"
+        try:
+            # Convert the table to a pandas DataFrame via StringIO
+            table_html = str(table).replace('\n', '')  # remove newlines for better parsing
+            df = pd.read_html(StringIO(table_html))[0]
+
+            # Format the table with section title, context, and clean layout
+            tables_text += f"\nSection: {section_title}\n"
+            tables_text += "=" * 40 + "\n"
+            tables_text += df.to_string(index=False) + "\n\n"
+        except Exception as e:
+            tables_text += f"\nError processing table in section {section_title}: {str(e)}\n"
+            continue
+
+    # Convert the cleaned HTML to Markdown
+    markdown_converter = html2text.HTML2Text()
+    markdown_converter.ignore_links = False
+    markdown_converter.bypass_tables = False
+    markdown_converter.ignore_images = True  # optional
+    markdown_converter.body_width = 0  # don't wrap lines
+
+    text = markdown_converter.handle(cleaned_html)
+    if tables_text:
+        text += f'Tables:\n{tables_text}'
+
+    logger.debug(f"web_page_information_extractor text: {text}")
+
+    chat = ChatOpenAI(
+        model="gpt-4o-mini",
+        temperature=0,
+        api_key=SecretStr(os.environ['OPENAI_API_KEY'])
+    )
+
+    system_message = "You are an expert information extraction system. Respond ONLY with valid JSON based on the user's request."
+    extraction_user_prompt = f"""From the text below:\n\"\"\"\n{text}\n\"\"\"\n\nExtract the following: "{request}"."""
+
+    extracted_information = chat.invoke([
+        SystemMessage(system_message),
+        HumanMessage(extraction_user_prompt)
+    ])
+    return extracted_information.content
+
+
+if __name__ == "__main__":
+    # result = web_page_information_extractor.invoke(
+    #     {"url": "https://en.wikipedia.org/wiki/Python_(programming_language)",
+    #      "request": "What are the changes introduced in Python 3.11"})
+    # print(result)
+
+    result = web_page_information_extractor.invoke(
+        {"url": "https://en.wikipedia.org/wiki/1928_Summer_Olympics",
+         "request": "List of countries and number of athletes at the 1928 Summer Olympics"})
+    print(result)
+
+    # result = web_page_information_extractor.invoke(
+    #     {"url": "https://chem.libretexts.org/Courses/Chabot_College/Introduction_to_General_Organic_and_Biochemistry/01%3A_Chemistry_in_our_Lives/1.E%3A_Exercises",
+    #      "request": "What is the surname of the equine veterinarian mentioned"})
+    # print(result)
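
The heart of the tool is the readability-then-html2text cleaning step; here is a minimal self-contained sketch of just that pipeline (the sample HTML is invented for illustration):

```python
# Minimal sketch of the cleaning pipeline: readability isolates the main
# content, html2text flattens it to Markdown. The sample HTML is made up.
import html2text
from readability import Document

html = """
<html><head><title>Demo</title></head><body>
  <nav>site chrome that readability should drop</nav>
  <article><h1>Demo</h1><p>The <b>main</b> content.</p></article>
</body></html>
"""

cleaned = Document(html).summary()  # keeps only the article body
converter = html2text.HTML2Text()
converter.body_width = 0            # don't hard-wrap lines
print(converter.handle(cleaned))    # Markdown-ish plain text
```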
tools/wikipedia_search.py ADDED
@@ -0,0 +1,41 @@
+import wikipedia
+from langchain_core.tools import tool
+from loguru import logger
+
+
+@tool("wikipedia_search_tool", parse_docstring=True)
+def wikipedia_search(query: str) -> str:
+    """
+    Searches Wikipedia for the given query.
+
+    Args:
+        query (str): The search query to look up on Wikipedia.
+
+    Returns:
+        str: A formatted string listing each result's page title and URL.
+    """
+    logger.info(f"use wikipedia_search_tool with param: {query}")
+
+    search_results = wikipedia.search(query, results=5)
+
+    if not search_results:
+        return "No results found for the query."
+
+    result_text = ""
+    try:
+        for i, title in enumerate(search_results, 1):
+            page = wikipedia.page(title, auto_suggest=False)
+            result_text += f"{i}. [{title}]({page.url})\n"
+
+        return result_text
+
+    except wikipedia.DisambiguationError as e:
+        return "Disambiguation page found. Possible matches:\n" + "\n".join(e.options)
+    except wikipedia.PageError:
+        return "Page not found. Try another search term."
+    except Exception as e:
+        return f"An error occurred: {str(e)}"
+
+
+if __name__ == "__main__":
+    print(wikipedia_search.invoke("Mercedes Sosa discography"))
tools/youtube_transcript.py ADDED
@@ -0,0 +1,32 @@
+from langchain_core.tools import tool
+from loguru import logger
+from youtube_transcript_api import YouTubeTranscriptApi
+
+
+@tool("youtube_transcript_tool", parse_docstring=True)
+def youtube_transcript(video_id: str) -> str:
+    """
+    Fetches the transcript of a YouTube video using its video ID.
+
+    The video ID must be provided to successfully fetch the transcript.
+
+    Args:
+        video_id (str): The unique identifier of a YouTube video. You can retrieve the video_id from the URL of the video. For example, with the URL https://www.youtube.com/watch?v=12345 the video_id is 12345.
+
+    Returns:
+        str: The transcript of the video, formatted as one bullet point per caption entry.
+
+    Raises:
+        Any exception raised by YouTubeTranscriptApi when fetching
+        the transcript fails.
+    """
+    logger.info(f"use youtube_transcript with param: {video_id}")
+    transcript = YouTubeTranscriptApi().fetch(video_id).to_raw_data()
+
+    bullet_points = '\n'.join(f"- {entry['text']}" for entry in transcript)
+
+    return bullet_points
+
+if __name__ == "__main__":
+    transcript = youtube_transcript.invoke("1htKBjuUWec")
+    print(transcript)
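
`fetch()` raises when a video has no captions; below is a small defensive-usage sketch, assuming the exception classes that the youtube-transcript-api package exports (an assumption worth verifying against the installed version):

```python
# Sketch (not in the commit): handle the common "no transcript" failure modes.
from youtube_transcript_api import (
    NoTranscriptFound,
    TranscriptsDisabled,
    YouTubeTranscriptApi,
)

try:
    transcript = YouTubeTranscriptApi().fetch("1htKBjuUWec")
except (TranscriptsDisabled, NoTranscriptFound) as err:
    print(f"No transcript available: {err}")
```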