Pycrolis commited on
Commit
281ee36
·
1 Parent(s): 494ead3

feat: add ShrewdAgent implementation for AI-assisted tool integration

Browse files
Files changed (2) hide show
  1. ShrewdAgent.py +91 -0
  2. requirements.txt +8 -2
ShrewdAgent.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import TypedDict, Annotated, Optional, Any, Callable, Sequence, Union
3
+
4
+ from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
5
+ from langchain_core.tools import BaseTool
6
+ from langchain_openai import ChatOpenAI
7
+ from langgraph.constants import START
8
+ from langgraph.errors import GraphRecursionError
9
+ from langgraph.graph import add_messages, StateGraph
10
+ from langgraph.prebuilt import ToolNode, tools_condition
11
+ from langgraph.pregel import PregelProtocol
12
+ from loguru import logger
13
+ from pydantic import SecretStr
14
+
15
+
16
+ class AgentState(TypedDict):
17
+ messages: Annotated[list[AnyMessage], add_messages]
18
+
19
+
20
+ class ShrewdAgent:
21
+ message_system = """You are a general AI assistant equipped with a suite of external tools. Your task is to
22
+ answer the following question as accurately and helpfully as possible by using the tools
23
+ provided. Do not write or execute code yourself. For any operation requiring computation,
24
+ data retrieval, or external access, explicitly invoke the appropriate tool.
25
+
26
+ Follow these guidelines:"
27
+ - Clearly explain your reasoning step by step.
28
+ - Justify your choice of tool(s) at each step.
29
+ - If multiple interpretations are possible, outline them and explain your reasoning for selecting one.
30
+ - If the answer requires external data or inference, retrieve or deduce it via the available tools.
31
+
32
+ Important: Your final output must be only a number or a short phrase, with no additional text or explanation."""
33
+
34
+ def __init__(self):
35
+ self.tools = []
36
+ self.llm = ChatOpenAI(
37
+ model="gpt-4o-mini",
38
+ temperature=0,
39
+ api_key=SecretStr(os.environ['OPENAI_API_KEY'])
40
+ ).bind_tools(self.tools)
41
+
42
+ def assistant_node(state: AgentState):
43
+ return {
44
+ "messages": [self.llm.invoke(state["messages"])],
45
+ }
46
+
47
+ self.agent = _build_state_graph(AgentState, assistant_node, self.tools)
48
+ logger.info("ShrewdAgent initialized.")
49
+
50
+ def __call__(self, question: str) -> str:
51
+ logger.info(f"Agent received question: {question}")
52
+ accumulated_response = []
53
+ try:
54
+ for chunk in self.agent.stream(
55
+ {"messages": [
56
+ SystemMessage(self.message_system),
57
+ HumanMessage(question, )
58
+ ]},
59
+ {"recursion_limit": 12},
60
+ ):
61
+ assistant = chunk.get("assistant")
62
+ if assistant:
63
+ logger.debug(f"{assistant.get('messages')[0].pretty_repr()}")
64
+ tools = chunk.get("tools")
65
+ if tools:
66
+ logger.debug(f"{tools.get('messages')[0].pretty_repr()}")
67
+ accumulated_response.append(chunk)
68
+
69
+ except GraphRecursionError as e:
70
+ logger.error(f"GraphRecursionError: {e}")
71
+
72
+ final_answer = "I couldn't find the answer"
73
+ if accumulated_response[-1].get("assistant"):
74
+ final_answer = accumulated_response[-1]["assistant"]['messages'][-1].content
75
+ logger.info(f"Agent returning answer: {final_answer}")
76
+ return final_answer
77
+
78
+
79
+ def _build_state_graph(
80
+ state_schema: Optional[type[Any]],
81
+ assistant: Callable,
82
+ tools: Sequence[Union[BaseTool, Callable]]) -> PregelProtocol: # CompiledStateGraph:
83
+
84
+ return (StateGraph(state_schema)
85
+ .add_node("assistant", assistant)
86
+ .add_node("tools", ToolNode(tools))
87
+ .add_edge(START, "assistant")
88
+ .add_conditional_edges("assistant", tools_condition)
89
+ .add_edge("tools", "assistant")
90
+ .compile()
91
+ )
requirements.txt CHANGED
@@ -1,2 +1,8 @@
1
- gradio
2
- requests
 
 
 
 
 
 
 
1
+ gradio~=5.29.1
2
+ requests~=2.32.3
3
+ pandas~=2.2.3
4
+ langchain-core~=0.3.60
5
+ langchain-openai~=0.3.17
6
+ langgraph~=0.4.5
7
+ loguru~=0.7.3
8
+ pydantic~=2.11.4