Spaces:
Sleeping
Sleeping
Pycrolis
commited on
Commit
·
281ee36
1
Parent(s):
494ead3
feat: add ShrewdAgent implementation for AI-assisted tool integration
Browse files- ShrewdAgent.py +91 -0
- requirements.txt +8 -2
ShrewdAgent.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import TypedDict, Annotated, Optional, Any, Callable, Sequence, Union
|
3 |
+
|
4 |
+
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
|
5 |
+
from langchain_core.tools import BaseTool
|
6 |
+
from langchain_openai import ChatOpenAI
|
7 |
+
from langgraph.constants import START
|
8 |
+
from langgraph.errors import GraphRecursionError
|
9 |
+
from langgraph.graph import add_messages, StateGraph
|
10 |
+
from langgraph.prebuilt import ToolNode, tools_condition
|
11 |
+
from langgraph.pregel import PregelProtocol
|
12 |
+
from loguru import logger
|
13 |
+
from pydantic import SecretStr
|
14 |
+
|
15 |
+
|
16 |
+
class AgentState(TypedDict):
|
17 |
+
messages: Annotated[list[AnyMessage], add_messages]
|
18 |
+
|
19 |
+
|
20 |
+
class ShrewdAgent:
|
21 |
+
message_system = """You are a general AI assistant equipped with a suite of external tools. Your task is to
|
22 |
+
answer the following question as accurately and helpfully as possible by using the tools
|
23 |
+
provided. Do not write or execute code yourself. For any operation requiring computation,
|
24 |
+
data retrieval, or external access, explicitly invoke the appropriate tool.
|
25 |
+
|
26 |
+
Follow these guidelines:"
|
27 |
+
- Clearly explain your reasoning step by step.
|
28 |
+
- Justify your choice of tool(s) at each step.
|
29 |
+
- If multiple interpretations are possible, outline them and explain your reasoning for selecting one.
|
30 |
+
- If the answer requires external data or inference, retrieve or deduce it via the available tools.
|
31 |
+
|
32 |
+
Important: Your final output must be only a number or a short phrase, with no additional text or explanation."""
|
33 |
+
|
34 |
+
def __init__(self):
|
35 |
+
self.tools = []
|
36 |
+
self.llm = ChatOpenAI(
|
37 |
+
model="gpt-4o-mini",
|
38 |
+
temperature=0,
|
39 |
+
api_key=SecretStr(os.environ['OPENAI_API_KEY'])
|
40 |
+
).bind_tools(self.tools)
|
41 |
+
|
42 |
+
def assistant_node(state: AgentState):
|
43 |
+
return {
|
44 |
+
"messages": [self.llm.invoke(state["messages"])],
|
45 |
+
}
|
46 |
+
|
47 |
+
self.agent = _build_state_graph(AgentState, assistant_node, self.tools)
|
48 |
+
logger.info("ShrewdAgent initialized.")
|
49 |
+
|
50 |
+
def __call__(self, question: str) -> str:
|
51 |
+
logger.info(f"Agent received question: {question}")
|
52 |
+
accumulated_response = []
|
53 |
+
try:
|
54 |
+
for chunk in self.agent.stream(
|
55 |
+
{"messages": [
|
56 |
+
SystemMessage(self.message_system),
|
57 |
+
HumanMessage(question, )
|
58 |
+
]},
|
59 |
+
{"recursion_limit": 12},
|
60 |
+
):
|
61 |
+
assistant = chunk.get("assistant")
|
62 |
+
if assistant:
|
63 |
+
logger.debug(f"{assistant.get('messages')[0].pretty_repr()}")
|
64 |
+
tools = chunk.get("tools")
|
65 |
+
if tools:
|
66 |
+
logger.debug(f"{tools.get('messages')[0].pretty_repr()}")
|
67 |
+
accumulated_response.append(chunk)
|
68 |
+
|
69 |
+
except GraphRecursionError as e:
|
70 |
+
logger.error(f"GraphRecursionError: {e}")
|
71 |
+
|
72 |
+
final_answer = "I couldn't find the answer"
|
73 |
+
if accumulated_response[-1].get("assistant"):
|
74 |
+
final_answer = accumulated_response[-1]["assistant"]['messages'][-1].content
|
75 |
+
logger.info(f"Agent returning answer: {final_answer}")
|
76 |
+
return final_answer
|
77 |
+
|
78 |
+
|
79 |
+
def _build_state_graph(
|
80 |
+
state_schema: Optional[type[Any]],
|
81 |
+
assistant: Callable,
|
82 |
+
tools: Sequence[Union[BaseTool, Callable]]) -> PregelProtocol: # CompiledStateGraph:
|
83 |
+
|
84 |
+
return (StateGraph(state_schema)
|
85 |
+
.add_node("assistant", assistant)
|
86 |
+
.add_node("tools", ToolNode(tools))
|
87 |
+
.add_edge(START, "assistant")
|
88 |
+
.add_conditional_edges("assistant", tools_condition)
|
89 |
+
.add_edge("tools", "assistant")
|
90 |
+
.compile()
|
91 |
+
)
|
requirements.txt
CHANGED
@@ -1,2 +1,8 @@
|
|
1 |
-
gradio
|
2 |
-
requests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio~=5.29.1
|
2 |
+
requests~=2.32.3
|
3 |
+
pandas~=2.2.3
|
4 |
+
langchain-core~=0.3.60
|
5 |
+
langchain-openai~=0.3.17
|
6 |
+
langgraph~=0.4.5
|
7 |
+
loguru~=0.7.3
|
8 |
+
pydantic~=2.11.4
|