# agent.py
import logging # Import logging
import os # For file/directory operations
import json # For reading/writing JSON answer files
from typing import TypedDict, Annotated, Optional

from dotenv import load_dotenv # Import load_dotenv

from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

from dataset_helper import download_file  # Used in __call__ to fetch task input files


# Get the logger instance configured in app.py
logger = logging.getLogger("eval_logger")

# Load environment variables from .env file at the beginning
# This will load OPENAI_API_KEY if it's set in a .env file in the root directory.
if load_dotenv():
    logger.info(".env file loaded successfully by agent.py.")
else:
    logger.info(".env file not found or empty in agent.py, relying on system environment variables.")

# Import tools AFTER .env might have been loaded
from tools import TOOLS
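# The system prompt in _assistant_node assumes TOOLS exposes at least:
# web_search, wikipedia_tool, analyse_image, analyse_audio,
# execute_python_code_from_file, analyse_excel_file,
# execute_pandas_script_for_excel, analyse_youtube and deep_analysis_with_gemini.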

# --- Agent State Definition ---
class AgentState(TypedDict):
    task_id: str
    original_question: str
    input_file_path: Optional[str] # Path to the locally downloaded file, if any
    messages: Annotated[list[AnyMessage], add_messages]
    # Potentially add other fields like 'scratchpad' or 'intermediate_steps' if needed
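
# Note: add_messages is a reducer. When a node returns {"messages": [...]},
# LangGraph appends those messages to the existing list rather than replacing
# it, so each node only needs to return the messages it produced.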

# --- Tool Definitions ---
# vision_llm, extract_text_from_image, search_tool and the TOOLS list are defined in tools.py.

# --- LangGraph Agent Class ---
class LangGraphAgent:
    def __init__(self, api_url: str, answers_dir: str = "answers"):
        logger.info("LangGraphAgent initializing...")
        self.api_url = api_url # Needed for download_file, though not directly by graph
        self.answers_dir = answers_dir
        os.makedirs(self.answers_dir, exist_ok=True)
        logger.info(f"Answers will be stored in: {os.path.abspath(self.answers_dir)}")

        # Initialize LLM for the agent
        # Ensure OPENAI_API_KEY is set in your environment
        try:
            self.llm = ChatOpenAI(model="gpt-4o", temperature=0)
            # Bind tools imported from tools.py
            self.agent_llm = self.llm.bind_tools(TOOLS, parallel_tool_calls=False)  # One tool call per turn keeps the tool loop simple
        except Exception as e:
            logger.error(f"Failed to initialize agent LLM (ChatOpenAI with gpt-4o) or bind tools: {e}. Ensure OPENAI_API_KEY is set.", exc_info=True)
            self.llm = None
            self.agent_llm = None
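            # Downstream code (_assistant_node and __call__) checks for a None
            # agent_llm and fails gracefully instead of raising here.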
        
        # Build the graph
        self.graph = self._build_graph()
        logger.info("LangGraphAgent initialized successfully.")

    def _save_answer(self, task_id: str, question: str, answer: str):
        """Saves the generated answer to a JSON file."""
        answer_payload = {"task_id": task_id, "question": question, "answer": answer}
        answer_file_path = os.path.join(self.answers_dir, f"{task_id}.json")
        try:
            with open(answer_file_path, 'w') as f:
                json.dump(answer_payload, f, indent=4)
            logger.info(f"Answer for task_id {task_id} saved to {answer_file_path}")
        except IOError as e:
            logger.error(f"Error saving answer for task_id {task_id} to {answer_file_path}: {e}", exc_info=True)

    def _load_answer(self, task_id: str) -> str | None:
        """Loads an answer from a JSON file if it exists."""
        answer_file_path = os.path.join(self.answers_dir, f"{task_id}.json")
        if os.path.exists(answer_file_path):
            try:
                with open(answer_file_path, 'r') as f:
                    answer_data = json.load(f)
                logger.info(f"Loaded existing answer for task_id {task_id} from {answer_file_path}")
                return answer_data.get("answer")
            except (IOError, json.JSONDecodeError) as e:
                logger.error(f"Error loading answer for task_id {task_id} from {answer_file_path}: {e}", exc_info=True)
        return None
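
    # For reference, a cached answer file (answers/<task_id>.json) written by
    # _save_answer looks like (values hypothetical):
    # {
    #     "task_id": "abc123",
    #     "question": "What is ...?",
    #     "answer": "42"
    # }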

    # --- Graph Node Definitions ---
    def _assistant_node(self, state: AgentState):
        logger.info(f"_assistant_node called for task_id: {state['task_id']}. Current messages count: {len(state['messages'])}")
        if not self.agent_llm:
            logger.error("Agent LLM not initialized. Cannot proceed with assistant node.")
            # Return a message indicating error, which will be added to state by add_messages
            # This helps in debugging and ensures flow continues to an extent
            error_message = SystemMessage(content="Error: Agent LLM not initialized. Cannot generate response.")
            return {"messages": [error_message]}

        system_prompt_parts = [
            "You are a helpful AI assistant for the GAIA benchmark. Your goal is to answer the user's question accurately and concisely. ",
            f"The user's question is about task_id: {state['task_id']}.\n",
            f"The original question is: {state['original_question']}\n"
        ]
        
        input_file_path = state.get('input_file_path')
        original_question_text = state['original_question']

        if input_file_path:
            system_prompt_parts.append(f"A local file is available at path: {input_file_path}. ")
            file_extension = os.path.splitext(input_file_path)[1].lower()
            if file_extension in ['.png', '.jpg', '.jpeg', '.gif', '.webp']:
                system_prompt_parts.append(f"This file appears to be an image. You can use the 'analyse_image' tool to analyse it. This tool requires the 'img_path' (which is '{input_file_path}') and the 'question' (which is '{original_question_text}') to be passed as arguments. This tool works only for local image files. ")
            elif file_extension in ['.mp3', '.wav', '.aac', '.flac', '.ogg', '.opus']: # Common audio types for AssemblyAI
                system_prompt_parts.append(f"This file appears to be an audio file. You can use the 'analyse_audio' tool to analyse its content. This tool requires the 'audio_path' (which is '{input_file_path}') and the 'question' (which is '{original_question_text}') to be passed as arguments. This tool works only for local audio files and cannot process web URLs. ")
            elif file_extension == '.py':
                system_prompt_parts.append(f"This file appears to be a Python script. You can use the 'execute_python_code_from_file' tool to understand its content and answer questions about it (e.g., predict its output or describe its functionality). This tool requires the 'file_path' (which is '{input_file_path}') and the 'question' (which is '{original_question_text}') as arguments. This tool analyses the code textually; it does not execute it. ")
            elif file_extension in ['.xls', '.xlsx']:
                system_prompt_parts.append(f"This file appears to be an Excel file. To answer questions requiring calculations, data manipulation, or specific lookups: "
                                           f"1. You should generate a Python script using the pandas library. "
                                           f"2. Use the 'execute_pandas_script_for_excel' tool to run this script. "
                                           f"3. The script will have access to a variable 'excel_file_path' which holds the path: '{input_file_path}'. Use this variable in your script to load the Excel file (e.g., pd.read_excel(excel_file_path)). "
                                           f"4. Your generated Python script MUST end with a print() statement that outputs ONLY the final answer, precisely formatted. "
                                           f"5. If you first need to understand the structure of the Excel file (sheet names, columns), you can use the 'analyse_excel_file' tool which provides a textual (CSV) representation of the data. But for computation, use 'execute_pandas_script_for_excel'. "
                                           f"Pass the '{input_file_path}' as 'excel_file_path' and your generated script as 'python_code' to the 'execute_pandas_script_for_excel' tool. ")
            else:
                system_prompt_parts.append(f"The provided file '{input_file_path}' is not a supported image, audio, Python, or Excel type for direct analysis with available tools. Do not attempt to use 'analyse_image', 'analyse_audio', 'execute_python_code_from_file', or 'analyse_excel_file'/'execute_pandas_script_for_excel' for this file. You may need to rely on web search or the question text itself. ")
        else:
            system_prompt_parts.append("No local file was provided with this question. ")

        system_prompt_parts.append("If the question text itself contains a URL (e.g., a link to a YouTube video or other website), you should primarily use the 'web_search' tool to find information related to that URL and the question. For YouTube URLs, specifically rely on 'web_search' as direct transcript access is not available. ")
        system_prompt_parts.append("You also have access to a 'web_search' tool for general information or if the question implies online content (e.g., a URL mentioned in the question text). ")
        system_prompt_parts.append("If a tool fails or a file type is unsupported, do not try the same tool repeatedly on it. Use web_search or state you cannot answer if appropriate. ")
        system_prompt_parts.append("Prioritize answering the question. If after about 5-7 tool execution cycles you cannot find a definitive answer, you MUST provide the best possible answer based on the information you have gathered or state CLEARLY that you cannot answer the question. DO NOT get stuck in overly long loops of tool use. Be decisive and conclude your reasoning.")
        system_prompt_parts.append("When providing your final answer, it is crucial that it is ONLY the answer itself, with absolutely no additional conversation, explanations, or formatting like 'The answer is...' or 'Based on my findings...'. Be direct. ")
        system_prompt_parts.append("The final answer format must be one of the following: ")
        system_prompt_parts.append("1. A number (e.g., 42, 1000, 3.14). Do not use commas for thousands separators (e.g., write 1000 not 1,000). Do not use units like '$' or '%' unless the question explicitly asks for it in the answer format. ")
        system_prompt_parts.append("2. As few words as possible (e.g., 'Paris', 'Mount Everest'). Do not use articles (a, an, the) unless part of a proper name. Avoid abbreviations (e.g., use 'Los Angeles' not 'LA') unless the question implies it. Write digits in plain text (e.g., 'two' instead of '2') unless the question asks for a numerical digit. ")
        system_prompt_parts.append("3. A comma-separated list of numbers and/or strings (e.g., 'red,blue,green', '1,2,three', 'Tokyo,London,New York'). Apply the rules from 1 and 2 to each element in the list. Ensure there are no spaces after commas unless a list element itself naturally contains a space (e.g. a multi-word city name). ")
        system_prompt_parts.append("Adhere to these formatting rules strictly for the final output.")
        system_prompt_parts.append("You also have access to a 'wikipedia_tool' to get information from Wikipedia. It's good for general knowledge questions, facts, definitions, and summaries on a wide range of topics.")
        system_prompt_parts.append("For questions specifically about the visual content of a YouTube video, use the 'analyse_youtube' tool. Provide the 'youtube_url' and the 'question'. This tool uses a Gemini multimodal model. If this tool fails or cannot answer, you can fall back to 'web_search' for general information about the video.")
        system_prompt_parts.append("If you encounter a particularly complex question (e.g., historical queries with multiple constraints, or questions requiring deep, multi-step reasoning) and you are struggling to find a definitive answer after attempting with standard tools (like web_search, wikipedia_tool) for a few cycles (e.g., 2-3 attempts), you can use the 'deep_analysis_with_gemini' tool. Pass the original, full question to this tool. Use this as a strategic escalation for very challenging textual questions.")
        system_prompt_parts.append("If a tool fails or a file type is unsupported, do not try the same tool repeatedly on it. Use web_search or state you cannot answer if appropriate. ")
        
        system_prompt = "".join(system_prompt_parts)

        messages_for_llm = [SystemMessage(content=system_prompt)] + state["messages"]
        
        logger.debug(f"Messages being sent to LLM for task {state['task_id']}: {messages_for_llm}")
        response_message = self.agent_llm.invoke(messages_for_llm)
        logger.debug(f"LLM response for task {state['task_id']}: {response_message}")
        return {"messages": [response_message]} # LangGraph's add_messages will append this

    def _build_graph(self):
        """Builds the agent graph: an assistant node that may request tools,
        a ToolNode that executes them, and a loop back to the assistant."""
        logger.info("Building LangGraph...")
        builder = StateGraph(AgentState)
        builder.add_node("assistant", self._assistant_node)
        builder.add_node("tools", ToolNode(TOOLS))  # Executes whichever tool the assistant requested

        builder.add_edge(START, "assistant")
        # LangGraph's prebuilt tools_condition routes to "tools" if the last AI
        # message contains tool calls, and to END otherwise, so the assistant's
        # first tool-free reply is treated as the final answer.
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")  # Always return to the assistant after a tool run

        # Note: recursion_limit is a runtime option (passed via config to
        # invoke(), as in __call__ below), not a compile() parameter.
        graph = builder.compile(checkpointer=None)
        logger.info("LangGraph built successfully.")
        # Tip: graph.get_graph(xray=True).draw_mermaid_png() can render the
        # topology if IPython and graphviz are available.
        return graph

    def __call__(self, task_id: str, question: str, file_name: str | None) -> tuple[str, bool]:
        logger.info(f"LangGraphAgent __call__ for task_id: {task_id}")

        # 1. Check for cached answer first
        cached_answer = self._load_answer(task_id)
        if cached_answer is not None:
            logger.info(f"Returning cached answer for {task_id}.")
            return cached_answer, True

        if not self.graph or not self.agent_llm:
            logger.error("Agent graph or LLM not initialized. Cannot process question.")
            return "Error: Agent not properly initialized.", False

        # 2. Download file if provided
        local_file_path = None
        if file_name:
            logger.info(f"Associated file '{file_name}' for task {task_id}. Attempting download.")
            local_file_path = download_file(self.api_url, task_id, file_name, download_dir="downloads")  # Saved under ./downloads
            if local_file_path:
                logger.info(f"File '{file_name}' available at {local_file_path} for task {task_id}.")
            else:
                logger.error(f"Failed to download file '{file_name}' for task {task_id}.")
                # Agent might still try to answer or this could be a hard failure depending on the question

        # 3. Invoke the graph
        initial_state: AgentState = {
            "task_id": task_id,
            "original_question": question,
            "input_file_path": local_file_path,
            "messages": [HumanMessage(content=question)]
        }
        
        final_answer_content = f"Error: Agent did not produce a final answer for task {task_id}." # Default error
        try:
            logger.info(f"Invoking graph for task_id: {task_id} with initial state.")
            # Stream events for debugging if needed:
            # for event in self.graph.stream(initial_state, stream_mode="values"):
            #     logger.debug(f"Graph event for {task_id}: {event}")
            #     final_state = event

            final_state = self.graph.invoke(initial_state, config={'recursion_limit': 50})  # Upper bound on graph super-steps per task
            logger.info(f"Graph invocation complete for task_id: {task_id}.")

            if final_state and final_state.get("messages"):
                # The final answer should be the content of the last AI message that is not a tool call
                for msg in reversed(final_state["messages"]):
                    if msg.type == "ai" and not msg.tool_calls: # Check for AI message without tool calls
                        final_answer_content = msg.content
                        logger.info(f"Extracted final answer for {task_id}: '{final_answer_content[:100]}...' ")
                        break
                    elif msg.type == "system" and "Error: Agent LLM not initialized" in msg.content: # Check for our specific error
                        final_answer_content = msg.content
                        break
                else:  # Loop finished without break: no tool-free AI message was found
                    logger.warning(f"No suitable final AI message found for task {task_id}. Last messages: {final_state.get('messages')}")
                    # Fall back to the last message content if any; otherwise keep the default error message.
                    if final_state.get("messages"):
                        final_answer_content = final_state["messages"][-1].content  # Best guess
            else:
                logger.error(f"Graph did not return messages in final_state for task {task_id}. Final state: {final_state}")

        except Exception as e:
            logger.error(f"Error during LangGraph agent execution for task_id {task_id}: {e}", exc_info=True)
            final_answer_content = f"Error during agent execution: {str(e)}"

        # 4. Save and return the final answer
        self._save_answer(task_id, question, final_answer_content)
        return final_answer_content, False # False because it's newly generated/processed by graph
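

# --- Usage sketch (hypothetical; in practice app.py constructs the agent) ---
# The placeholder API URL below is illustrative, not a real endpoint.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo_agent = LangGraphAgent(api_url="https://example.com/api", answers_dir="answers")
    answer, was_cached = demo_agent(
        task_id="demo-task-001",                    # hypothetical task id
        question="What is the capital of France?",  # hypothetical question
        file_name=None,                             # no attached file
    )
    print(f"answer={answer!r} cached={was_cached}")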