File size: 4,961 Bytes
281ee36
7e3b825
281ee36
 
 
 
 
5e2ccaa
281ee36
 
 
 
 
 
 
 
f3b4b95
64a08a8
28960ba
ee92f11
83001c2
4c07abc
fe989a0
508a421
eb08b08
fe989a0
281ee36
 
 
 
 
 
7e3b825
 
 
 
 
281ee36
7e3b825
 
 
 
 
9c4e9e4
 
 
 
281ee36
 
fe989a0
 
508a421
fe989a0
eb08b08
ee92f11
f3b4b95
 
64a08a8
28960ba
4c07abc
fe989a0
281ee36
331b5a4
281ee36
 
 
 
 
 
 
 
 
 
7e3b825
 
281ee36
 
28960ba
281ee36
 
 
 
 
 
 
d189934
281ee36
 
 
28960ba
281ee36
 
28960ba
281ee36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
from textwrap import dedent
from typing import TypedDict, Annotated, Optional, Any, Callable, Sequence, Union

from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langchain_tavily import TavilySearch
from langgraph.constants import START
from langgraph.errors import GraphRecursionError
from langgraph.graph import add_messages, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.pregel import PregelProtocol
from loguru import logger
from pydantic import SecretStr

from tools.excel_to_text import excel_to_text
from tools.execute_python_code_from_file import execute_python_code_from_file
from tools.maths import add_integers
from tools.produce_classifier import produce_classifier
from tools.sort_words_alphabetically import sort_words_alphabetically
from tools.transcribe_audio import transcribe_audio
from tools.web_page_information_extractor import web_page_information_extractor
from tools.wikipedia_search import wikipedia_search
from tools.youtube_transcript import youtube_transcript


class AgentState(TypedDict):
    """State carried between the nodes of the agent graph.

    The only channel is the conversation itself. Because it is annotated with
    LangGraph's ``add_messages`` reducer, messages returned by a node are merged
    into the existing list rather than replacing it.
    """

    # Conversation history (system, human, AI and tool messages), reduced by add_messages.
    messages: Annotated[list[AnyMessage], add_messages]


class ShrewdAgent:
    """Tool-using agent that answers one question per call with a terse final answer.

    The agent is a compiled LangGraph loop (built by ``_build_state_graph``): an
    ``assistant`` node invokes an OpenAI chat model with ``self.tools`` bound, and
    a ``tools`` node executes the tool calls the model emits, until the model
    replies without requesting tools or the recursion limit is reached.
    """

    message_system = dedent("""
        You are a general AI assistant equipped with a suite of external tools. Your task is to
        answer the following question as accurately and helpfully as possible by using the tools
        provided. Do not write or execute code yourself. For any operation requiring computation,
        data retrieval, or external access, explicitly invoke the appropriate tool.
        
        Follow these guidelines:
        - Clearly explain your reasoning step by step.
        - Justify your choice of tool(s) at each step.
        - If multiple interpretations are possible, outline them and explain your reasoning for selecting one.
        - If the answer requires external data or inference, retrieve or deduce it via the available tools.

        Important: Your final output MUST be only a number or a word with no additional text or explanation, 
        unless the response format is explicitly specified in the question. Do not include reasoning, 
        commentary, or any other content beyond the requested answer.""")

    # Returned whenever the run does not end on an assistant message that is a final answer.
    _NO_ANSWER = "I couldn't find the answer"

    def __init__(self, model: str = "gpt-4.1", recursion_limit: int = 18):
        """Instantiate the tools, bind them to the chat model and compile the graph.

        Args:
            model: OpenAI chat model name passed to ``ChatOpenAI``.
            recursion_limit: Maximum number of graph steps per question, passed to
                LangGraph's ``stream`` config; exceeding it ends the run early.

        Raises:
            KeyError: if the ``OPENAI_API_KEY`` environment variable is not set.
        """
        self.recursion_limit = recursion_limit
        self.tools = [
            TavilySearch(),
            wikipedia_search,
            web_page_information_extractor,
            youtube_transcript,
            produce_classifier,
            sort_words_alphabetically,
            excel_to_text,
            execute_python_code_from_file,
            add_integers,
            transcribe_audio,
        ]
        self.llm = ChatOpenAI(
            model=model,
            temperature=0,
            api_key=SecretStr(os.environ['OPENAI_API_KEY'])
        ).bind_tools(self.tools)

        def assistant_node(state: AgentState):
            # One LLM turn over the whole conversation; the add_messages reducer on
            # AgentState merges the returned reply into the state.
            return {
                "messages": [self.llm.invoke(state["messages"])],
            }

        self.agent = _build_state_graph(AgentState, assistant_node, self.tools)
        logger.info(f"Agent initialized with tools: {[tool.name for tool in self.tools]}")
        logger.debug(f"system message:\n{self.message_system}")

    def __call__(self, question: str) -> str:
        """Run the agent on ``question`` and return its final answer.

        Each streamed assistant/tool update is logged at DEBUG level. If the
        recursion limit is exceeded, the error is logged and the answer is taken
        from whatever the graph produced before stopping.

        Returns:
            The content of the last assistant message, or
            ``"I couldn't find the answer"`` when nothing was streamed, the run
            ended on a tool step, or the last assistant message still requested
            tool calls.
        """
        logger.info(f"Agent received question:\n{question}")
        last_chunk = None
        try:
            for chunk in self.agent.stream(
                    {"messages": [
                        SystemMessage(self.message_system),
                        HumanMessage(question),
                    ]},
                    {"recursion_limit": self.recursion_limit},
            ):
                assistant = chunk.get("assistant")
                if assistant:
                    logger.debug(f"\n{assistant.get('messages')[0].pretty_repr()}")
                tools = chunk.get("tools")
                if tools:
                    logger.debug(f"\n{tools.get('messages')[0].pretty_repr()}")
                # Only the most recent update is needed to extract the answer.
                last_chunk = chunk

        except GraphRecursionError as e:
            logger.error(f"GraphRecursionError: {e}")

        final_answer = self._extract_final_answer(last_chunk)
        logger.info(f"Agent returning answer: {final_answer}")
        return final_answer

    @staticmethod
    def _extract_final_answer(last_chunk: Optional[dict]) -> str:
        """Return the answer carried by the last streamed graph update.

        Args:
            last_chunk: The final ``{node_name: update}`` dict yielded by
                ``stream``, or ``None`` if nothing was streamed.

        Returns:
            The last assistant message's content, or the no-answer fallback.
        """
        if not last_chunk:
            return ShrewdAgent._NO_ANSWER
        assistant = last_chunk.get("assistant")
        if not assistant:
            return ShrewdAgent._NO_ANSWER
        message = assistant["messages"][-1]
        # A message that still requests tools (e.g. the recursion limit cut the run
        # off right after it) is an intermediate step, not a final answer.
        if getattr(message, "tool_calls", None):
            return ShrewdAgent._NO_ANSWER
        return message.content


def _build_state_graph(
        state_schema: Optional[type[Any]],
        assistant: Callable,
        tools: Sequence[Union[BaseTool, Callable]]) -> PregelProtocol:
    """Assemble and compile the assistant <-> tools loop.

    Topology:
        START -> "assistant"
        "assistant" -> routed by ``tools_condition`` ("tools" when the last
                       message requests tool calls, otherwise the end)
        "tools" -> "assistant"

    Args:
        state_schema: State type for the graph (e.g. a TypedDict).
        assistant: Callable used as the "assistant" node.
        tools: Tools wrapped in a ``ToolNode`` as the "tools" node.

    Returns:
        The compiled graph (concretely a ``CompiledStateGraph``).
    """
    builder = StateGraph(state_schema)

    # Nodes: the LLM turn and the tool executor.
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # Edges: enter at the assistant, branch on tool calls, loop tools back.
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()