# env variables needed: HF_TOKEN, OPENAI_API_KEY, BRAVE_SEARCH_API_KEY
import os
import json
from typing import Literal

from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState, StateGraph, START, END
from langchain_core.messages import SystemMessage, HumanMessage, ToolMessage
from langchain_community.tools import DuckDuckGoSearchResults, BraveSearch
from custom_tools import (
    multiply, add, subtract, divide, modulus, power,
    query_image, automatic_speech_recognition, get_webpage_content,
    python_repl_tool, get_youtube_transcript, wikipedia_search,
    html_table_query, spreadsheet_tool,
)
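# Note: custom_tools is assumed to expose plain LangChain tools built with the
# @tool decorator; a minimal sketch of one of them (illustrative only):
#
#   from langchain_core.tools import tool
#
#   @tool
#   def multiply(a: float, b: float) -> float:
#       """Multiply two numbers."""
#       return a * b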
class LangGraphAgent:
    def __init__(self,
                 model_name="o4-mini",
                 show_tools_desc=True,
                 show_prompt=True):

        with open('system_prompt.txt', 'r') as file:
            system_prompt = file.read()

        # =========== LLM definition ===========
        if model_name.startswith('o'):
            # reasoning model (no temperature setting)
            llm = ChatOpenAI(api_key=os.getenv("OPENAI_API_KEY"), model=model_name)  # needs OPENAI_API_KEY in env
        else:
            llm = ChatOpenAI(api_key=os.getenv("OPENAI_API_KEY"), model=model_name, temperature=0)
        print(f'LangGraphAgent initialized with model "{model_name}"')
        # =========== Augment the LLM with tools ===========
        community_tools = [
            BraveSearch.from_api_key(  # Web search (more performant than DuckDuckGo)
                api_key=os.getenv("BRAVE_SEARCH_API_KEY"),  # needs BRAVE_SEARCH_API_KEY in env
                search_kwargs={"count": 5}),
        ]
        custom_tools = [
            multiply, add, subtract, divide, modulus, power,  # Basic arithmetic
            query_image,                   # Ask anything about an image using a VLM
            automatic_speech_recognition,  # Transcribe an audio file to text
            get_webpage_content,           # Load a web page and return it as markdown
            python_repl_tool,              # Python code interpreter
            get_youtube_transcript,        # Get the transcript of a YouTube video
            wikipedia_search,              # Search Wikipedia and return the top-k page extracts (intro only) as markdown
            html_table_query,              # Fetch an HTML table and return it as JSON or markdown
            spreadsheet_tool,
        ]
        tools = community_tools + custom_tools
        tools_by_name = {tool.name: tool for tool in tools}
        llm_with_tools = llm.bind_tools(tools)
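        # bind_tools() returns a wrapped model that sends the tools' JSON schemas
        # with every request, so the model can answer with structured tool_calls
        # instead of plain text.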
        # =========== Agent definition ===========

        # Nodes
        def llm_call(state: MessagesState):
            """LLM decides whether to call a tool or not"""
            return {
                "messages": [
                    llm_with_tools.invoke(
                        [SystemMessage(content=system_prompt)] + state["messages"]
                    )
                ]
            }
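        # The returned AIMessage either carries the final text answer or a
        # tool_calls list of dicts like
        #   [{"name": "multiply", "args": {"a": 3, "b": 4}, "id": "call_..."}]
        # which should_continue() below uses for routing.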
        def tool_node(state: dict):
            """Performs the tool call"""
            result = []
            for tool_call in state["messages"][-1].tool_calls:
                tool = tools_by_name[tool_call["name"]]
                observation = tool.invoke(tool_call["args"])
                result.append(ToolMessage(content=observation, tool_call_id=tool_call["id"]))
            return {"messages": result}
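        # Each ToolMessage echoes the tool_call_id of the request it answers,
        # which is how the model matches observations back to its own tool
        # calls on the next llm_call turn.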
        # Conditional edge function to route to the tool node or end based upon whether the LLM made a tool call
        def should_continue(state: MessagesState) -> Literal["Action", END]:
            """Decide if we should continue the loop or stop based upon whether the LLM made a tool call"""
            last_message = state["messages"][-1]
            # If the LLM makes a tool call, then perform an action
            if last_message.tool_calls:
                return "Action"
            # Otherwise, we stop (reply to the user)
            return END
        # Build workflow
        agent_builder = StateGraph(MessagesState)

        # Add nodes
        agent_builder.add_node("llm_call", llm_call)
        agent_builder.add_node("environment", tool_node)

        # Add edges to connect nodes
        agent_builder.add_edge(START, "llm_call")
        agent_builder.add_conditional_edges(
            "llm_call",
            should_continue,
            {
                # Name returned by should_continue : Name of next node to visit
                "Action": "environment",
                END: END,
            },
        )
        agent_builder.add_edge("environment", "llm_call")

        # Compile the agent
        self.agent = agent_builder.compile()
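        # Optional debugging aid (sketch only, needs extra rendering deps):
        #   self.agent.get_graph().draw_mermaid_png()
        # renders the llm_call <-> environment loop as a diagram.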
        if show_tools_desc:
            for i, tool in enumerate(llm_with_tools.kwargs['tools']):
                print("\n" + "=" * 30 + f" Tool {i + 1} " + "=" * 30)
                print(json.dumps(tool[tool['type']], indent=4))

        if show_prompt:
            print("\n" + "=" * 30 + " System prompt " + "=" * 30)
            print(system_prompt)
    def __call__(self, question: str) -> str:
        print("\n\n" + "*" * 50)
        print(f"Agent received question: {question}")
        print("*" * 50)

        # Invoke
        messages = [HumanMessage(content=question)]
        messages = self.agent.invoke(
            {"messages": messages},
            {"recursion_limit": 30},  # maximum number of steps before hitting a stop condition
        )
        for m in messages["messages"]:
            m.pretty_print()
        # post-process the response (keep only what's after "FINAL ANSWER:" for the exact match)
        # note: str.split() never raises, so an explicit membership check is
        # needed to actually detect a missing marker
        response = str(messages["messages"][-1].content)
        if "FINAL ANSWER:" in response:
            response = response.split("FINAL ANSWER:")[-1].strip()
        else:
            print('Could not split response on "FINAL ANSWER:"')
print("\n\n"+"-"*50) | |
print(f"Agent returning with answer: {response}") | |
return response | |
if __name__ == "__main__":
    agent = LangGraphAgent()
    agent("Review the chess position provided in the image. It is black's turn. "
          "Provide the correct next move for black which guarantees a win. "
          "Please provide your response in algebraic notation. "
          "File url: https://agents-course-unit4-scoring.hf.space/files/cca530fc-4052-43b2-b130-b30968d8aa44")