# (Hugging Face upload-page residue, preserved as comments so the file parses)
# priyamarwaha's picture / Upload 30 files / a94fa9b verified
# agent.py
import logging # Import logging
import os # For file/directory operations
import json # For reading/writing JSON answer files
# import base64 # No longer needed here
from typing import TypedDict, Annotated, Optional, List
from dotenv import load_dotenv # Import load_dotenv
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from dataset_helper import download_file # For potential use in file handling
# Get the logger instance configured in app.py
logger = logging.getLogger("eval_logger")

# Load environment variables from a .env file at the beginning.
# This will load OPENAI_API_KEY if it's set in a .env file in the root directory.
if load_dotenv():
    logger.info(".env file loaded successfully by agent.py.")
else:
    logger.info(".env file not found or empty in agent.py, relying on system environment variables.")

# Import tools AFTER .env might have been loaded, so that tool modules that
# read API keys at import time see the values from .env.
from tools import TOOLS
# --- Agent State Definition ---
class AgentState(TypedDict):
    """State carried through the LangGraph for a single GAIA task."""

    task_id: str                     # GAIA task identifier for this run
    original_question: str           # The question text exactly as received
    input_file_path: Optional[str]   # Path to the locally downloaded file, if any
    # Conversation history; the add_messages reducer appends new messages
    # returned by nodes instead of overwriting the list.
    messages: Annotated[list[AnyMessage], add_messages]
    # Potentially add other fields like 'scratchpad' or 'intermediate_steps' if needed
# --- Tool Definitions --- MOVED TO tools.py ---
# vision_llm, extract_text_from_image, search_tool, TOOLS list are now in tools.py
# --- LangGraph Agent Class ---
class LangGraphAgent:
    """GAIA-benchmark agent built on a LangGraph assistant/tools loop.

    The agent caches answers on disk (one JSON file per task_id), downloads
    any task-attached file, and runs a gpt-4o assistant that can call the
    tools defined in tools.py.
    """

    def __init__(self, api_url: str, answers_dir: str = "answers"):
        """Set up the LLM, bind tools, build the graph, and prepare the answer cache.

        Args:
            api_url: Base URL of the evaluation API; used by download_file,
                not directly by the graph.
            answers_dir: Directory where per-task answer JSON files are cached.
        """
        logger.info("LangGraphAgent initializing...")
        self.api_url = api_url  # Needed for download_file, though not directly by graph
        self.answers_dir = answers_dir
        os.makedirs(self.answers_dir, exist_ok=True)
        logger.info(f"Answers will be stored in: {os.path.abspath(self.answers_dir)}")
        # Initialize LLM for the agent.
        # Ensure OPENAI_API_KEY is set in your environment.
        try:
            self.llm = ChatOpenAI(model="gpt-4o", temperature=0)
            # Bind tools imported from tools.py
            self.agent_llm = self.llm.bind_tools(TOOLS, parallel_tool_calls=False)  # parallel_tool_calls=False as per example
        except Exception as e:
            # Keep the agent constructible even without an API key; downstream
            # code checks self.agent_llm before use.
            logger.error(f"Failed to initialize agent LLM (ChatOpenAI with gpt-4o) or bind tools: {e}. Ensure OPENAI_API_KEY is set.", exc_info=True)
            self.llm = None
            self.agent_llm = None
        # Build the graph
        self.graph = self._build_graph()
        logger.info("LangGraphAgent initialized successfully.")
def _save_answer(self, task_id: str, question: str, answer: str):
"""Saves the generated answer to a JSON file."""
answer_payload = {"task_id": task_id, "question": question, "answer": answer}
answer_file_path = os.path.join(self.answers_dir, f"{task_id}.json")
try:
with open(answer_file_path, 'w') as f:
json.dump(answer_payload, f, indent=4)
logger.info(f"Answer for task_id {task_id} saved to {answer_file_path}")
except IOError as e:
logger.error(f"Error saving answer for task_id {task_id} to {answer_file_path}: {e}", exc_info=True)
def _load_answer(self, task_id: str) -> str | None:
"""Loads an answer from a JSON file if it exists."""
answer_file_path = os.path.join(self.answers_dir, f"{task_id}.json")
if os.path.exists(answer_file_path):
try:
with open(answer_file_path, 'r') as f:
answer_data = json.load(f)
logger.info(f"Loaded existing answer for task_id {task_id} from {answer_file_path}")
return answer_data.get("answer")
except (IOError, json.JSONDecodeError) as e:
logger.error(f"Error loading answer for task_id {task_id} from {answer_file_path}: {e}", exc_info=True)
return None
# --- Graph Node Definitions ---
def _assistant_node(self, state: AgentState):
logger.info(f"_assistant_node called for task_id: {state['task_id']}. Current messages count: {len(state['messages'])}")
if not self.agent_llm:
logger.error("Agent LLM not initialized. Cannot proceed with assistant node.")
# Return a message indicating error, which will be added to state by add_messages
# This helps in debugging and ensures flow continues to an extent
error_message = SystemMessage(content="Error: Agent LLM not initialized. Cannot generate response.")
return {"messages": [error_message]}
system_prompt_parts = [
f"You are a helpful AI assistant for the GAIA benchmark. Your goal is to answer the user's question accurately and concisely. ",
f"The user's question is about task_id: {state['task_id']}.\n",
f"The original question is: {state['original_question']}\n"
]
input_file_path = state.get('input_file_path')
original_question_text = state['original_question']
if input_file_path:
system_prompt_parts.append(f"A local file is available at path: {input_file_path}. ")
file_extension = os.path.splitext(input_file_path)[1].lower()
if file_extension in ['.png', '.jpg', '.jpeg', '.gif', '.webp']:
system_prompt_parts.append(f"This file appears to be an image. You can use the 'analyse_image' tool to analyse it. This tool requires the 'img_path' (which is '{input_file_path}') and the 'question' (which is '{original_question_text}') to be passed as arguments. This tool works only for local image files. ")
elif file_extension in ['.mp3', '.wav', '.aac', '.flac', '.ogg', '.opus']: # Common audio types for AssemblyAI
system_prompt_parts.append(f"This file appears to be an audio file. You can use the 'analyse_audio' tool to analyse its content. This tool requires the 'audio_path' (which is '{input_file_path}') and the 'question' (which is '{original_question_text}') to be passed as arguments. This tool works only for local audio files and cannot process web URLs. ")
elif file_extension == '.py':
system_prompt_parts.append(f"This file appears to be a Python script. You can use the 'execute_python_code_from_file' tool to understand its content and answer questions about it (e.g., predict its output or describe its functionality). This tool requires the 'file_path' (which is '{input_file_path}') and the 'question' (which is '{original_question_text}') as arguments. This tool analyses the code textually; it does not execute it. ")
elif file_extension in ['.xls', '.xlsx']:
system_prompt_parts.append(f"This file appears to be an Excel file. To answer questions requiring calculations, data manipulation, or specific lookups: "
f"1. You should generate a Python script using the pandas library. "
f"2. Use the 'execute_pandas_script_for_excel' tool to run this script. "
f"3. The script will have access to a variable 'excel_file_path' which holds the path: '{input_file_path}'. Use this variable in your script to load the Excel file (e.g., pd.read_excel(excel_file_path)). "
f"4. Your generated Python script MUST end with a print() statement that outputs ONLY the final answer, precisely formatted. "
f"5. If you first need to understand the structure of the Excel file (sheet names, columns), you can use the 'analyse_excel_file' tool which provides a textual (CSV) representation of the data. But for computation, use 'execute_pandas_script_for_excel'. "
f"Pass the '{input_file_path}' as 'excel_file_path' and your generated script as 'python_code' to the 'execute_pandas_script_for_excel' tool. ")
else:
system_prompt_parts.append(f"The provided file '{input_file_path}' is not a supported image, audio, Python, or Excel type for direct analysis with available tools. Do not attempt to use 'analyse_image', 'analyse_audio', 'execute_python_code_from_file', or 'analyse_excel_file'/'execute_pandas_script_for_excel' for this file. You may need to rely on web search or the question text itself. ")
else:
system_prompt_parts.append("No local file was provided with this question. ")
system_prompt_parts.append("If the question text itself contains a URL (e.g., a link to a YouTube video or other website), you should primarily use the 'web_search' tool to find information related to that URL and the question. For YouTube URLs, specifically rely on 'web_search' as direct transcript access is not available. ")
system_prompt_parts.append("You also have access to a 'web_search' tool for general information or if the question implies online content (e.g., a URL mentioned in the question text). ")
system_prompt_parts.append("If a tool fails or a file type is unsupported, do not try the same tool repeatedly on it. Use web_search or state you cannot answer if appropriate. ")
system_prompt_parts.append("Prioritize answering the question. If after about 5-7 tool execution cycles you cannot find a definitive answer, you MUST provide the best possible answer based on the information you have gathered or state CLEARLY that you cannot answer the question. DO NOT get stuck in overly long loops of tool use. Be decisive and conclude your reasoning.")
system_prompt_parts.append("When providing your final answer, it is crucial that it is ONLY the answer itself, with absolutely no additional conversation, explanations, or formatting like 'The answer is...' or 'Based on my findings...'. Be direct. ")
system_prompt_parts.append("The final answer format must be one of the following: ")
system_prompt_parts.append("1. A number (e.g., 42, 1000, 3.14). Do not use commas for thousands separators (e.g., write 1000 not 1,000). Do not use units like '$' or '%' unless the question explicitly asks for it in the answer format. ")
system_prompt_parts.append("2. As few words as possible (e.g., 'Paris', 'Mount Everest'). Do not use articles (a, an, the) unless part of a proper name. Avoid abbreviations (e.g., use 'Los Angeles' not 'LA') unless the question implies it. Write digits in plain text (e.g., 'two' instead of '2') unless the question asks for a numerical digit. ")
system_prompt_parts.append("3. A comma-separated list of numbers and/or strings (e.g., 'red,blue,green', '1,2,three', 'Tokyo,London,New York'). Apply the rules from 1 and 2 to each element in the list. Ensure there are no spaces after commas unless a list element itself naturally contains a space (e.g. a multi-word city name). ")
system_prompt_parts.append("Adhere to these formatting rules strictly for the final output.")
system_prompt_parts.append("You also have access to a 'wikipedia_tool' to get information from Wikipedia. It's good for general knowledge questions, facts, definitions, and summaries on a wide range of topics.")
system_prompt_parts.append("For questions specifically about the visual content of a YouTube video, use the 'analyse_youtube' tool. Provide the 'youtube_url' and the 'question'. This tool uses a Gemini multimodal model. If this tool fails or cannot answer, you can fall back to 'web_search' for general information about the video.")
system_prompt_parts.append("If you encounter a particularly complex question (e.g., historical queries with multiple constraints, or questions requiring deep, multi-step reasoning) and you are struggling to find a definitive answer after attempting with standard tools (like web_search, wikipedia_tool) for a few cycles (e.g., 2-3 attempts), you can use the 'deep_analysis_with_gemini' tool. Pass the original, full question to this tool. Use this as a strategic escalation for very challenging textual questions.")
system_prompt_parts.append("If a tool fails or a file type is unsupported, do not try the same tool repeatedly on it. Use web_search or state you cannot answer if appropriate. ")
system_prompt = "".join(system_prompt_parts)
messages_for_llm = [SystemMessage(content=system_prompt)] + state["messages"]
logger.debug(f"Messages being sent to LLM for task {state['task_id']}: {messages_for_llm}")
response_message = self.agent_llm.invoke(messages_for_llm)
logger.debug(f"LLM response for task {state['task_id']}: {response_message}")
return {"messages": [response_message]} # LangGraph's add_messages will append this
def _build_graph(self) -> StateGraph:
logger.info("Building LangGraph...")
builder = StateGraph(AgentState)
builder.add_node("assistant", self._assistant_node)
tool_node = ToolNode(TOOLS) # Create a ToolNode with all our tools
builder.add_node("tools", tool_node)
builder.add_edge(START, "assistant")
builder.add_conditional_edges(
"assistant",
tools_condition, # LangGraph's prebuilt tools_condition
# END # If no tool call, end. (Modified below to ensure final processing)
)
# builder.add_edge("tools", "assistant") # Loop back from tools to assistant
# Modified flow: Tools execute, then always go back to assistant for summarization/final answer
# If assistant decided no tool, tools_condition might route to END if not handled
# We want the assistant to make the final decision to END.
# If assistant calls a tool, route to tools.
# If assistant does not call a tool, it should be the final answer.
# tools_condition will route to END if no tool calls are present in the AI message.
# So, if tools_condition routes to END, it means the assistant provided the final answer.
builder.add_edge("tools", "assistant") # Always go back to assistant after a tool run
# graph = builder.compile(checkpointer=None, recursion_limit=35) # Incorrect parameter
graph = builder.compile(checkpointer=None) # Corrected: remove recursion_limit
logger.info("LangGraph built successfully.")
# try:
# # For debugging: display graph structure if possible (requires graphviz)
# # from IPython.display import Image, display
# # display(Image(graph.get_graph(xray=True).draw_mermaid_png()))
# logger.info("Graph visualization (mermaid PNG) can be generated if IPython and graphviz are available.")
# except Exception as e:
# logger.warning(f"Could not generate graph visualization: {e}")
return graph
def __call__(self, task_id: str, question: str, file_name: str | None) -> tuple[str, bool]:
logger.info(f"LangGraphAgent __call__ for task_id: {task_id}")
# 1. Check for cached answer first
cached_answer = self._load_answer(task_id)
if cached_answer is not None:
logger.info(f"Returning cached answer for {task_id}.")
return cached_answer, True
if not self.graph or not self.agent_llm:
logger.error("Agent graph or LLM not initialized. Cannot process question.")
return "Error: Agent not properly initialized.", False
# 2. Download file if provided
local_file_path = None
if file_name:
logger.info(f"Associated file '{file_name}' for task {task_id}. Attempting download.")
local_file_path = download_file(self.api_url, task_id, file_name, download_dir="downloads") # Ensure 'downloads' dir
if local_file_path:
logger.info(f"File '{file_name}' available at {local_file_path} for task {task_id}.")
else:
logger.error(f"Failed to download file '{file_name}' for task {task_id}.")
# Agent might still try to answer or this could be a hard failure depending on the question
# 3. Invoke the graph
initial_state: AgentState = {
"task_id": task_id,
"original_question": question,
"input_file_path": local_file_path,
"messages": [HumanMessage(content=question)]
}
final_answer_content = f"Error: Agent did not produce a final answer for task {task_id}." # Default error
try:
logger.info(f"Invoking graph for task_id: {task_id} with initial state.")
# Stream events for debugging if needed:
# for event in self.graph.stream(initial_state, stream_mode="values"):
# logger.debug(f"Graph event for {task_id}: {event}")
# final_state = event
final_state = self.graph.invoke(initial_state, config={'recursion_limit': 50}) # Increased to 50
logger.info(f"Graph invocation complete for task_id: {task_id}.")
if final_state and final_state.get("messages"):
# The final answer should be the content of the last AI message that is not a tool call
for msg in reversed(final_state["messages"]):
if msg.type == "ai" and not msg.tool_calls: # Check for AI message without tool calls
final_answer_content = msg.content
logger.info(f"Extracted final answer for {task_id}: '{final_answer_content[:100]}...' ")
break
elif msg.type == "system" and "Error: Agent LLM not initialized" in msg.content: # Check for our specific error
final_answer_content = msg.content
break
else: # If loop finishes without break (no suitable AI message found)
logger.warning(f"No suitable final AI message found for task {task_id}. Last messages: {final_state.get('messages')}")
# Fallback or specific error message.
# For now, use the last message content if any, or keep the default error.
if final_state.get("messages"):
final_answer_content = final_state["messages"][-1].content # Best guess
else:
logger.error(f"Graph did not return messages in final_state for task {task_id}. Final state: {final_state}")
except Exception as e:
logger.error(f"Error during LangGraph agent execution for task_id {task_id}: {e}", exc_info=True)
final_answer_content = f"Error during agent execution: {str(e)}"
# 4. Save and return the final answer
self._save_answer(task_id, question, final_answer_content)
return final_answer_content, False # False because it's newly generated/processed by graph