# super_agent/agent.py
# Author: lezaf
# Last commit: "Perform some cleanup" (799013a)
import os
import requests
from dotenv import load_dotenv
from langgraph.graph import StateGraph, MessagesState, START
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_google_genai import ChatGoogleGenerativeAI
from langfuse.langchain import CallbackHandler
from tools.web_search import web_search
from tools.math import add_numbers_in_list, check_commutativity
from tools.extraction import extract_data_from_excel, extract_transcript_from_youtube, extract_transcript_from_audio
from rate_limiters import safe_invoke_with_retry_gemini
# Load environment variables from .env, overriding any already-set values
load_dotenv(override=True)
# LLM backend used when running this module directly: "hf", "google", or "openai"
PROVIDER="google"
# Langfuse tracing callback, attached to agent invocations in the __main__ block
langfuse_handler = CallbackHandler()
# Tools exposed to the LLM via bind_tools() inside build_agent()
tools = [
add_numbers_in_list,
web_search,
check_commutativity,
extract_data_from_excel,
extract_transcript_from_youtube,
extract_transcript_from_audio
]
# --------------- Define the agent structure ---------------- #
def build_agent(provider: str = "hf"):
    """Build and compile the LangGraph tool-calling agent.

    Args:
        provider: LLM backend to use: "hf" (HuggingFace endpoint),
            "google" (Gemini) or "openai". Defaults to "hf".

    Returns:
        A compiled LangGraph graph operating on ``MessagesState``.

    Raises:
        ValueError: If ``provider`` is unsupported, or if rate limiting
            is enabled for a provider that has no rate-limiter
            implementation.
    """
    # Optional client-side rate limiting (only implemented for Google below)
    USE_RATE_LIMITER = os.getenv("USE_RATE_LIMITER", "false").lower() == "true"

    print(f"Building agent with provider: {provider}")

    if provider == "hf":
        llm = HuggingFaceEndpoint(
            repo_id="mistralai/Mixtral-8x7B-Instruct-v0.1",
            task="text-generation",
            temperature=0.0,
            provider="hf-inference"
        )
        llm = ChatHuggingFace(llm=llm)
    elif provider == "google":
        # Google Gemini
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash-preview-05-20",
            max_tokens=2048,
            max_retries=2,
        )
    elif provider == "openai":
        llm = ChatOpenAI(
            model="gpt-3.5-turbo",
            temperature=0,
            api_key=os.getenv("OPENAI_API_KEY"),
            max_tokens=512
        )
    else:
        raise ValueError(f"Unsupported provider: {provider}")

    # Expose the module-level tools to the LLM so it can emit tool calls
    llm_with_tools = llm.bind_tools(tools)

    # Load the system prompt (path is relative to the current working directory)
    with open("system_prompt.txt", "r", encoding="utf-8") as f:
        system_prompt = f.read()

    # System message prepended to every model invocation
    sys_msg = SystemMessage(content=system_prompt)

    # --------------- Define nodes ---------------- #
    def assistant(state: MessagesState):
        """Node for the assistant to respond to user input."""
        if USE_RATE_LIMITER:
            if provider == "google":
                # Retry with a fixed wait to stay within Gemini quota limits
                response = safe_invoke_with_retry_gemini(
                    llm_with_tools,
                    [sys_msg] + state["messages"],
                    max_retries=2,
                    wait_seconds=60
                )
            else:
                raise ValueError(f"Rate limiting is not implemented for provider {provider}.")
        else:
            response = llm_with_tools.invoke([sys_msg] + state["messages"])

        return {"messages": [response]}

    tool_node = ToolNode(tools=tools)

    # --------------- Build the state graph ---------------- #
    # assistant -> tools (when a tool call is requested, via tools_condition)
    # -> assistant again, until the model answers without a tool call.
    graph_builder = StateGraph(MessagesState)
    graph_builder.add_node("assistant", assistant)
    graph_builder.add_node("tools", tool_node)
    graph_builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    graph_builder.add_edge("tools", "assistant")
    graph_builder.add_edge(START, "assistant")

    return graph_builder.compile()
# --------------- For manual testing ---------------- #
if __name__ == "__main__":
    print("\n" + "-"*30 + " Agent Starting " + "-"*30)
    # Print run variables
    print(f"Provider: {PROVIDER}")
    # Default to 'false' so a missing USE_DDGS env var does not crash with
    # AttributeError on None.
    print(f"Search engine used: {'DDGS' if os.getenv('USE_DDGS', 'false').lower() == 'true' else 'Tavily'}")

    agent = build_agent(provider=PROVIDER)  # Change to "hf" for HuggingFace
    print("Agent built successfully.")
    print("-"*70)

    # Scoring-server endpoints
    DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    files_url = f"{api_url}/files/"  # Needs task_id appended

    # 2. Fetch Questions. Abort on failure: everything below depends on
    # questions_data being populated (continuing used to raise NameError).
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            print("Fetched questions list is empty.")
        print(f"Fetched {len(questions_data)} questions.")
    except Exception as e:
        raise SystemExit(f"An unexpected error occurred fetching questions: {e}")

    # 3. Get specific question by task_id (uncomment one to test it)
    # task_id = "8e867cd7-cff9-4e6c-867a-ff5ddc2550be" # Sosa albums
    # task_id = "2d83110e-a098-4ebb-9987-066c06fa42d0" # Reverse text example
    # task_id = "cca530fc-4052-43b2-b130-b30968d8aa44" # Chess image
    # task_id = "4fc2f1ae-8625-45b5-ab34-ad4433bc21f8" # Dinosaur ?
    # task_id = "6f37996b-2ac7-44b0-8e68-6d28256631b4" # Commutativity check
    # task_id = "9d191bce-651d-4746-be2d-7ef8ecadb9c2" # Youtube video
    # task_id = "cabe07ed-9eca-40ea-8ead-410ef5e83f91" # Louvrier ?
    # task_id = "f918266a-b3e0-4914-865d-4faa564f1aef" # Code example
    # task_id = "3f57289b-8c60-48be-bd80-01f8099ca449" # at bats ?
    task_id = "7bd855d8-463d-4ed5-93ca-5fe35145f733" # Excel file
    # task_id = "5a0c1adf-205e-4841-a666-7c3ef95def9d" # Malko competition
    # task_id = "305ac316-eef6-4446-960a-92d80d542f82" # Poland film
    # task_id = "bda648d7-d618-4883-88f4-3466eabd860e" # Vietnamese
    # task_id = "cf106601-ab4f-4af9-b045-5295fe67b37d" # Olympics
    # task_id = "a0c07678-e491-4bbc-8f0b-07405144218f" # pitchers
    # task_id = "3cef3a44-215e-4aed-8e3b-b1e3f08063b7" # grocery list
    # task_id = "840bfca7-4f7b-481a-8794-c560c340185d" # Carolyn Collins Petersen
    # task_id = "1f975693-876d-457b-a649-393859e79bf3" # Audio (pages)
    # task_id = "99c9cc74-fdc8-46c6-8f8d-3ce2d3bfeea3" # Audio (recipe)

    # Look up the selected question; guard against an unknown task_id
    # (q_data used to be dereferenced unchecked and could be None).
    q_data = next((item for item in questions_data if item["task_id"] == task_id), None)
    if q_data is None:
        raise SystemExit(f"No question found with task_id {task_id}.")

    # Build the multimodal message content, starting with the question text
    content = [
        {"type": "text", "text": q_data["question"]}
    ]
    if q_data["file_name"] != "":
        file_url = f"{files_url}{task_id}"
        if q_data["file_name"].endswith((".png", ".jpg", ".jpeg")):
            # Images are passed by URL for the model to fetch
            content.append({"type": "image_url", "image_url": {"url": file_url}})
        elif q_data["file_name"].endswith(".py"):
            # For code files, we can just send the text content
            try:
                response = requests.get(file_url, timeout=15)
                response.raise_for_status()
                code_content = response.text
                content.append({"type": "text", "text": code_content})
            except Exception as e:
                print(f"Error fetching code file: {e}")
        elif q_data["file_name"].endswith((".xlsx", ".xls")):
            # Spreadsheets/audio are handled by tools, so only the URL is sent
            content.append({"type": "text", "text": "Excel file url: " + file_url})
        elif q_data["file_name"].endswith((".mp3", ".wav")):
            content.append({"type": "text", "text": "Audio file url: " + file_url})
        else:
            content.append({"type": "text", "text": f"File URL: {file_url} (file type not supported)"})

    human_msg = HumanMessage(content=content)
    human_msg.pretty_print()

    try:
        result = agent.invoke(
            {"messages": [human_msg]},
            config={"callbacks": [langfuse_handler]}
        )
        # Print the full conversation, including tool calls and results
        for message in result["messages"]:
            message.pretty_print()
    except Exception as e:
        print(f"Error: {e}")