import os
from textwrap import dedent
from typing import TypedDict, Annotated, Optional, Any, Callable, Sequence, Union

from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
from langchain_core.tools import BaseTool
from langchain_openai import ChatOpenAI
from langchain_tavily import TavilySearch
from langgraph.constants import START
from langgraph.errors import GraphRecursionError
from langgraph.graph import add_messages, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.pregel import PregelProtocol
from loguru import logger
from pydantic import SecretStr

from tools.excel_to_text import excel_to_text
from tools.execute_python_code_from_file import execute_python_code_from_file
from tools.maths import add_integers
from tools.produce_classifier import produce_classifier
from tools.sort_words_alphabetically import sort_words_alphabetically
from tools.transcribe_audio import transcribe_audio
from tools.web_page_information_extractor import web_page_information_extractor
from tools.wikipedia_search import wikipedia_search
from tools.youtube_transcript import youtube_transcript


class AgentState(TypedDict):
    """Graph state: the running conversation.

    ``add_messages`` is registered as the reducer, so the messages a node
    returns are merged into the existing list instead of replacing it.
    """

    messages: Annotated[list[AnyMessage], add_messages]


class ShrewdAgent:
    """Tool-using LLM agent that answers one question per call.

    Internally this is a LangGraph loop. An ``assistant`` node (an OpenAI chat
    model bound to ``self.tools``) alternates with a ``tools`` node that runs
    the requested tool calls. The loop ends when the model replies without
    requesting a tool, or when ``recursion_limit`` is exceeded.
    """

    # System prompt sent ahead of every question.
    message_system = dedent("""
        You are a general AI assistant equipped with a suite of external tools. Your task is to answer the following question as accurately and helpfully as possible by using the tools provided. Do not write or execute code yourself. For any operation requiring computation, data retrieval, or external access, explicitly invoke the appropriate tool.

        Follow these guidelines:
        - Clearly explain your reasoning step by step.
        - Justify your choice of tool(s) at each step.
        - If multiple interpretations are possible, outline them and explain your reasoning for selecting one.
        - If the answer requires external data or inference, retrieve or deduce it via the available tools.

        Important: Your final output MUST be only a number or a word with no additional text or explanation, unless the response format is explicitly specified in the question. Do not include reasoning, commentary, or any other content beyond the requested answer.""")

    # Maximum number of graph super-steps per question. It is passed to
    # LangGraph as ``recursion_limit``; exceeding it raises GraphRecursionError,
    # which __call__ catches and turns into the fallback answer.
    recursion_limit: int = 18

    def __init__(self):
        """Create the tool list and the tool-bound chat model, then compile the graph.

        Raises:
            KeyError: if the ``OPENAI_API_KEY`` environment variable is unset.
        """
        self.tools = [
            TavilySearch(),
            wikipedia_search,
            web_page_information_extractor,
            youtube_transcript,
            produce_classifier,
            sort_words_alphabetically,
            excel_to_text,
            execute_python_code_from_file,
            add_integers,
            transcribe_audio,
        ]
        self.llm = ChatOpenAI(
            model="gpt-4.1",
            temperature=0,
            api_key=SecretStr(os.environ["OPENAI_API_KEY"]),
        ).bind_tools(self.tools)

        def assistant_node(state: AgentState) -> dict[str, Any]:
            # One model turn over the whole conversation so far; the reply is
            # merged into the state by the add_messages reducer.
            return {"messages": [self.llm.invoke(state["messages"])]}

        self.agent = _build_state_graph(AgentState, assistant_node, self.tools)
        logger.info(f"Agent initialized with tools: {[tool.name for tool in self.tools]}")
        logger.debug(f"system message:\n{self.message_system}")

    def __call__(self, question: str) -> str:
        """Answer *question* and return the model's final reply.

        Each graph update is logged at debug level as it streams. If the run
        ends without a final assistant reply, the fallback text
        "I couldn't find the answer" is returned. That happens when the stream
        yields nothing, when the recursion limit is hit mid tool loop, or when
        the reply is empty.
        """
        logger.info(f"Agent received question:\n{question}")
        # Only the most recent update is needed to extract the answer.
        last_update: Optional[dict[str, Any]] = None
        try:
            for update in self.agent.stream(
                {"messages": [SystemMessage(self.message_system), HumanMessage(question)]},
                {"recursion_limit": self.recursion_limit},
            ):
                # Updates are keyed by the graph node that produced them.
                for node_name in ("assistant", "tools"):
                    node_update = update.get(node_name)
                    if node_update:
                        logger.debug(f"\n{node_update.get('messages')[0].pretty_repr()}")
                last_update = update
        except GraphRecursionError as e:
            logger.error(f"GraphRecursionError: {e}")

        final_answer = _final_answer(last_update) or "I couldn't find the answer"
        logger.info(f"Agent returning answer: {final_answer}")
        return final_answer


def _final_answer(update: Optional[dict[str, Any]]) -> Optional[str]:
    """Return the final reply carried by the last graph update, or None.

    Only an ``assistant`` update whose last message requests no tool calls is a
    final answer. An assistant turn that still asks for tools is not a final
    answer; a run cut short by the recursion limit can stop on such a turn.
    """
    if not update:
        return None
    assistant_update = update.get("assistant")
    if not assistant_update:
        return None
    message = assistant_update["messages"][-1]
    if getattr(message, "tool_calls", None):
        return None
    return message.content or None


def _build_state_graph(
    state_schema: Optional[type[Any]],
    assistant: Callable,
    tools: Sequence[Union[BaseTool, Callable]],
) -> PregelProtocol:  # concretely a CompiledStateGraph
    """Compile the assistant <-> tools loop.

    Edges:
    - START goes to ``assistant``.
    - ``tools_condition`` routes ``assistant`` to ``tools`` when its last
      message requests tool calls, and to END otherwise.
    - ``tools`` always hands control back to ``assistant``.
    """
    builder = StateGraph(state_schema)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()