"""
LLM Agent Graph Implementation
==============================
This module defines a graph-based LLM agent workflow with various tools and retrieval capabilities.

The agent can:
- Perform mathematical operations
- Search Wikipedia, web, and arXiv
- Retrieve similar questions from a vector database
- Process user queries using different LLM providers

Components:
- Tool definitions: Math operations, search tools
- Vector database retrieval
- Graph construction with different LLM options
- Workflow management with LangGraph
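
Example usage (a minimal sketch; assumes this module is importable as
``agent`` and the required API keys are available, e.g. via a .env file):

    from agent import build_graph
    from langchain_core.messages import HumanMessage

    graph = build_graph(provider="groq")
    result = graph.invoke({"messages": [HumanMessage(content="What is 6 * 7?")]})
    result["messages"][-1].pretty_print()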
"""

import os
import logging
from typing import Dict, List

from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
from langchain_community.vectorstores import SupabaseVectorStore
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.tools import tool
from langchain.tools.retriever import create_retriever_tool
from supabase.client import Client, create_client

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()


# ===================
# Math Operation Tools
# ===================

@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the result.
    
    Args:
        a: First integer to multiply
        b: Second integer to multiply
        
    Returns:
        The product of a and b
    """
    return a * b


@tool
def add(a: int, b: int) -> int:
    """Add two integers and return the result.
    
    Args:
        a: First integer to add
        b: Second integer to add
        
    Returns:
        The sum of a and b
    """
    return a + b


@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first and return the result.
    
    Args:
        a: Integer to subtract from
        b: Integer to subtract
        
    Returns:
        The difference (a - b)
    """
    return a - b


@tool
def divide(a: int, b: int) -> float:
    """Divide the first integer by the second and return the result.
    
    Args:
        a: Numerator (dividend)
        b: Denominator (divisor)
        
    Returns:
        The quotient (a / b) as a float
        
    Raises:
        ValueError: If b is zero (division by zero)
    """
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a / b


@tool
def modulus(a: int, b: int) -> int:
    """Calculate the remainder when the first integer is divided by the second.
    
    Args:
        a: Dividend
        b: Divisor
        
    Returns:
        The remainder of a divided by b
        
    Raises:
        ValueError: If b is zero (modulo by zero)
    """
    if b == 0:
        raise ValueError("Cannot calculate modulus with divisor zero.")
    return a % b
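
# Usage sketch: @tool-decorated functions become LangChain StructuredTool
# objects, so they are invoked through the tool interface rather than called
# directly:
#     multiply.invoke({"a": 6, "b": 7})   # -> 42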


# ===================
# Search Tools
# ===================
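
# Each search tool wraps its results in pseudo-XML <Document> tags so the LLM
# can attribute retrieved content to a source when answering.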

@tool
def wiki_search(query: str) -> Dict[str, str]:
    """Search Wikipedia for a query and return formatted results.
    
    Args:
        query: The search term to look up on Wikipedia
        
    Returns:
        Dictionary with formatted Wikipedia search results
    """
    logger.info(f"Searching Wikipedia for: {query}")
    
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        
        if not search_docs:
            return {"wiki_results": "No Wikipedia results found for this query."}
        
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
                for doc in search_docs
            ]
        )
        
        logger.info(f"Found {len(search_docs)} Wikipedia results")
        return {"wiki_results": formatted_search_docs}
        
    except Exception as e:
        logger.error(f"Error searching Wikipedia: {e}", exc_info=True)
        return {"wiki_results": f"Error searching Wikipedia: {str(e)}"}


@tool
def web_search(query: str) -> Dict[str, str]:
    """Search the web using Tavily for a query and return formatted results.
    
    Args:
        query: The search term to look up on the web
        
    Returns:
        Dictionary with formatted web search results
    """
    logger.info(f"Searching the web for: {query}")
    
    try:
        search_results = TavilySearchResults(max_results=3).invoke({"query": query})
        
        if not search_results:
            return {"web_results": "No web results found for this query."}
        
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'<Document source="{result["url"]}">\n{result["content"]}\n</Document>'
                for result in search_results
            ]
        )
        
        logger.info(f"Found {len(search_results)} web search results")
        return {"web_results": formatted_search_docs}
        
    except Exception as e:
        logger.error(f"Error searching the web: {e}", exc_info=True)
        return {"web_results": f"Error searching the web: {str(e)}"}


@tool
def arxiv_search(query: str) -> Dict[str, str]:
    """Search arXiv for academic papers and return formatted results.
    
    Args:
        query: The search term to look up on arXiv
        
    Returns:
        Dictionary with formatted arXiv search results
    """
    logger.info(f"Searching arXiv for: {query}")
    
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()
        
        if not search_docs:
            return {"arxiv_results": "No arXiv results found for this query."}
        
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'<Document source="{doc.metadata.get("entry_id", "")}" title="{doc.metadata.get("Title", "")}">\n{doc.page_content[:1000]}\n</Document>'
                for doc in search_docs
            ]
        )
        
        logger.info(f"Found {len(search_docs)} arXiv results")
        return {"arxiv_results": formatted_search_docs}
        
    except Exception as e:
        logger.error(f"Error searching arXiv: {e}", exc_info=True)
        return {"arxiv_results": f"Error searching arXiv: {str(e)}"}


# ===================
# Vector Store Setup
# ===================

def setup_vector_store() -> SupabaseVectorStore:
    """
    Set up and configure the Supabase vector store for question retrieval.
    
    Returns:
        Configured SupabaseVectorStore instance
    
    Raises:
        ValueError: If required environment variables are missing
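    
    Note:
        Assumes the Supabase project already contains a ``documents`` table and
        a ``match_documents_langchain`` SQL function, along the lines of the
        LangChain SupabaseVectorStore setup guide.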
    """
    # Check for required environment variables
    supabase_url = os.environ.get("SUPABASE_URL")
    supabase_key = os.environ.get("SUPABASE_SERVICE_KEY")
    
    if not supabase_url or not supabase_key:
        raise ValueError(
            "Missing required environment variables: SUPABASE_URL and/or SUPABASE_SERVICE_KEY"
        )
    
    # Initialize embeddings model
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    
    # Initialize Supabase client
    supabase_client: Client = create_client(supabase_url, supabase_key)
    
    # Create vector store
    vector_store = SupabaseVectorStore(
        client=supabase_client,
        embedding=embeddings,
        table_name="documents",
        query_name="match_documents_langchain",
    )
    
    logger.info("Vector store initialized successfully")
    return vector_store


# ===================
# LLM Provider Setup
# ===================

def get_llm(provider: str = "google"):
    """
    Initialize and return an LLM based on the specified provider.
    
    Args:
        provider: The LLM provider to use ('google', 'groq', or 'huggingface')
        
    Returns:
        Initialized LLM instance
        
    Raises:
        ValueError: If an invalid provider is specified
    """
    if provider == "google":
        logger.info("Using Google Gemini as LLM provider")
        return ChatGoogleGenerativeAI(model="gemini-2.5-flash-preview-04-17", temperature=0)
    
    elif provider == "groq":
        logger.info("Using Groq as LLM provider with qwen-qwq-32b model")
        return ChatGroq(model="qwen-qwq-32b", temperature=0)
    
    elif provider == "huggingface":
        logger.info("Using Hugging Face as LLM provider with Llama-2-7b-chat-hf model")
        return ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                # Gated repo; requires a HUGGINGFACEHUB_API_TOKEN with access.
                repo_id="meta-llama/Llama-2-7b-chat-hf",
                temperature=0,
            ),
        )
    
    else:
        available_providers = ['google', 'groq', 'huggingface']
        raise ValueError(f"Invalid provider: '{provider}'. Choose from {available_providers}")


# ===================
# Graph Building
# ===================

def build_graph(provider: str = "groq"):
    """
    Build and compile the agent workflow graph.
    
    This function creates a LangGraph workflow that includes:
    - A retriever node to find similar questions
    - An assistant node that uses an LLM to generate responses
    - A tools node for executing various tools
    
    Args:
        provider: The LLM provider to use ('google', 'groq', or 'huggingface')
        
    Returns:
        Compiled StateGraph ready for execution
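    
    Graph topology:
        START -> retriever -> assistant
        assistant --(tool call)--> tools -> assistant
        assistant --(no tool call)--> END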
    """
    logger.info(f"Building agent graph with {provider} as LLM provider")
    
    # Load system prompt
    try:
        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            system_prompt = f.read()
            logger.info("Loaded system prompt from file")
    except FileNotFoundError:
        system_prompt = """You are a helpful AI assistant that answers questions accurately and concisely.
Use the available tools when appropriate to find information or perform calculations.
Always cite your sources when you use search tools."""
        logger.warning("system_prompt.txt not found, using default system prompt")
    
    # Initialize system message
    sys_msg = SystemMessage(content=system_prompt)
    
    # Set up vector store and retriever tool
    try:
        vector_store = setup_vector_store()
        retriever_tool = create_retriever_tool(
            retriever=vector_store.as_retriever(),
            name="Question Search",
            description="A tool to retrieve similar questions from a vector store.",
        )
        logger.info("Vector store retrieval tool initialized")
    except Exception as e:
        logger.error(f"Failed to set up vector store: {e}", exc_info=True)
        vector_store = None
        retriever_tool = None
    
    # Define available tools
    tools = [
        multiply,
        add,
        subtract,
        divide,
        modulus,
        wiki_search,
        web_search,
        arxiv_search,
    ]
    
    # Add retriever tool if available
    if retriever_tool:
        tools.append(retriever_tool)
    
    # Get LLM and bind tools
    llm = get_llm(provider)
    llm_with_tools = llm.bind_tools(tools)
    
    # Define graph nodes
    def assistant(state: MessagesState) -> Dict[str, List]:
        """
        Assistant node that processes messages with the LLM.
        
        Args:
            state: Current message state
            
        Returns:
            Updated message state with LLM response
        """
        return {"messages": [llm_with_tools.invoke(state["messages"])]}
    
    def retriever(state: MessagesState) -> Dict[str, List]:
        """
        Retriever node that finds similar questions from the vector store.
        
        Args:
            state: Current message state
            
        Returns:
            Updated message state with retrieved examples
        """
        # Only use retrieval if vector_store is available
        if vector_store:
            try:
                similar_questions = vector_store.similarity_search(state["messages"][0].content)
                if similar_questions:
                    example_msg = HumanMessage(
                        content=f"For reference, here is a similar question and answer:\n\n{similar_questions[0].page_content}",
                    )
                    return {"messages": [sys_msg] + state["messages"] + [example_msg]}
            except Exception as e:
                logger.error(f"Error in retriever node: {e}", exc_info=True)
        
        # If vector_store is unavailable or retrieval fails, just add system message
        return {"messages": [sys_msg] + state["messages"]}
    
    # Build graph
    builder = StateGraph(MessagesState)
    
    # Add nodes
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    
    # Add edges
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")
    
    # Compile graph
    compiled_graph = builder.compile()
    logger.info("Agent graph compiled successfully")
    
    return compiled_graph


# ===================
# Testing
# ===================

if __name__ == "__main__":
    test_question = "When was the wiki entry of Boethius on De Philosophiae Consolatione first added?"
    
    # Build the graph
    logger.info("Starting test run")
    graph = build_graph(provider="groq")
    
    # Run the graph
    logger.info(f"Testing with question: {test_question}")
    messages = [HumanMessage(content=test_question)]
    result_messages = graph.invoke({"messages": messages})
    
    # Display results
    logger.info("Test completed, printing messages:")
    for message in result_messages["messages"]:
        message.pretty_print()