""" NL to SQL Agent implementation """ import json import os from typing import List, Dict, Any, Optional, Tuple import asyncio from openai import OpenAI from galileo import GalileoLogger from galileo.projects import Projects from galileo.stages import Stages from galileo.protect import Protect from galileo_core.schemas.protect.action import OverrideAction from galileo_core.schemas.protect.payload import Payload from galileo_core.schemas.protect.ruleset import Ruleset from galileo_core.schemas.protect.rule import Rule, RuleOperator from tools import ListTablesTool, FetchTableSchemaTool, ExecuteSQLTool from prompts.system_prompt import SYSTEM_PROMPT from data.database import setup_database class NLToSQLAgent: """ Agent that converts natural language queries to SQL using OpenAI. """ def __init__(self): """Initialize the NL to SQL agent.""" # Ensure database is set up setup_database() self.list_tables_tool = ListTablesTool() self.fetch_table_schema_tool = FetchTableSchemaTool() self.execute_sql_tool = ExecuteSQLTool() # Get API key from environment variable api_key = os.environ.get("OPENAI_API_KEY") if not api_key: raise ValueError("OPENAI_API_KEY environment variable is not set") # Initialize OpenAI client self.client = OpenAI(api_key=api_key) # This will log to the project and log stream specified in the logger constructor self.logger = GalileoLogger(project=os.getenv("GALILEO_PROJECT"), log_stream=os.getenv("GALILEO_LOG_STREAM")) session_id = self.logger.start_session(name="NLToSQLAgent") self.logger.set_session(session_id) project = Projects().get(name=os.getenv("GALILEO_PROJECT")) self.project_id = project.id stage = Stages().get(stage_name="protect-stage", project_id=self.project_id) if stage is None: stage = Stages().create(name="protect-stage", project_id=self.project_id) self.stage_id = stage.id def run_protect(self, query: str, sql_query: str) -> str: """ Run the protect stage. """ response = asyncio.run(Protect().ainvoke( payload=Payload(input=query, output=sql_query), prioritized_rulesets=[ Ruleset( rules=[ Rule( metric="input_pii", operator=RuleOperator.any, target_value=["email","phone_number","address","name","ssn","credit_card_info","account_info","username", "password"], ) ], action=OverrideAction( choices=["Sorry, the input contains PII. Please do not disclose this information."] ), ), Ruleset( rules=[ Rule( metric="prompt_injection", operator=RuleOperator.any, target_value=["impersonation","obfuscation","simple_instruction","few_shot","new_context"], ), ], action=OverrideAction( choices=["Sorry, the input contains prompt injection. I cannot answer this due to safety concerns. Please try again."] ), ) ], stage_id=self.stage_id, )) self.logger.add_protect_span( payload=Payload(input=query, output=sql_query), response=response, ) return response.text def generate_and_execute_sql(self, query: str, run_protect: bool = False) -> Tuple[str, List[Dict[str, Any]]]: """ Convert a natural language query to SQL and execute it. Args: query: The natural language query to convert. Returns: Tuple of (SQL query, query results) """ try: print("Starting to process query:", query) # Use LLM with tools to handle the workflow trace = self.logger.start_trace(query) sql_query, results = self.run_llm_with_tools(query, run_protect) print("SQL query generated:", sql_query) print("Query results:", results) return sql_query, results except Exception as e: print(f"Error generating SQL query: {e}") raise ValueError(f"Failed to generate SQL query: {e}") def generate_sql_query(self, query: str) -> str: """ Convert a natural language query to SQL (without executing it). Args: query: The natural language query to convert. Returns: The generated SQL query. """ sql_query, _ = self.generate_and_execute_sql(query) return sql_query def run_llm_with_tools(self, query: str, run_protect: bool = False) -> Tuple[str, List[Dict[str, Any]]]: """ Use the LLM with tools to handle the entire process. Args: query: The natural language query. Returns: Tuple of (SQL query, query results) """ try: print("Setting up LLM with tools...") # Define tools for the model tools = [ self.list_tables_tool.to_dict(), self.fetch_table_schema_tool.to_dict(), self.execute_sql_tool.to_dict() ] # Initial system and user messages messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": query} ] final_sql_query = "" query_results = [] continue_conversation = True while continue_conversation: print("Calling OpenAI API...") # Call OpenAI API response = self.client.chat.completions.create( model="gpt-4o-mini", # Use the appropriate model messages=messages, tools=tools, tool_choice="auto", temperature=1.0 ) print("Received response from OpenAI") response_message = response.choices[0].message print("Response message:", response_message) self.logger.add_llm_span( input=messages, output=response_message.model_dump(), tools=tools, model="gpt-4o-mini", ) # Add the assistant's message to the conversation messages.append(response_message.model_dump()) # Check if the model wants to call a tool if response_message.tool_calls: print("Tool call requested by LLM") # Handle each tool call for tool_call in response_message.tool_calls: function_name = tool_call.function.name function_args = json.loads(tool_call.function.arguments) # Execute the appropriate tool if function_name == "list_tables": print("Executing list_tables tool") tables = self.list_tables_tool.call() function_response = json.dumps(tables) elif function_name == "fetch_table_schema": table_name = function_args.get("table_name") print(f"Executing fetch_table_schema for {table_name}") schema = self.fetch_table_schema_tool.call(table_name) # Convert schema to a JSON-serializable format if schema: serialized_schema = { "table_name": schema.table_name, "columns": [ {"name": col.name, "type": col.type, "description": col.description} for col in schema.columns ] } function_response = json.dumps(serialized_schema) else: function_response = "Schema not found" elif function_name == "execute_sql": sql = function_args.get("query") print(f"Executing SQL query: {sql}") # Save the SQL query if not final_sql_query: final_sql_query = sql # Execute the query result = self.execute_sql_tool.call(sql) function_response = json.dumps(result) # Save the results if successful if result.get("success", False): query_results = result.get("results", []) else: function_response = "Function not found" self.logger.add_tool_span( name=function_name, input=tool_call.function.arguments, output=function_response, ) # Add the tool response to the conversation messages.append({ "role": "tool", "tool_call_id": tool_call.id, "content": function_response }) print("Tool response added to conversation") else: # If no tool calls, the model has generated the final response print("Final response received from LLM") # If we have a SQL query but no results, we need to extract and execute it if not final_sql_query and response_message.content: extracted_sql = self.clean_sql_response(response_message.content) if extracted_sql: final_sql_query = extracted_sql print(f"Executing extracted SQL: {final_sql_query}") result = self.execute_sql_tool.call(final_sql_query) if result.get("success", False): query_results = result.get("results", []) continue_conversation = False # Clean the SQL response if needed if not final_sql_query and response_message.content: final_sql_query = self.clean_sql_response(response_message.content) # Only run protection if the toggle is enabled if run_protect: final_sql_query = self.run_protect(query, final_sql_query) if len(query_results) > 0: self.logger.conclude(output=str(query_results)) else: self.logger.conclude(output=str(final_sql_query)) self.logger.flush() return final_sql_query, query_results except Exception as e: print(f"Error in run_llm_with_tools: {e}") raise ValueError(f"Failed to generate SQL with tools: {e}") def clean_sql_response(self, response: str) -> str: """ Clean up the LLM response to ensure we only return the SQL query. Args: response: The raw LLM response. Returns: The cleaned SQL query. """ if not response: return "" # Remove any markdown SQL code blocks if "```sql" in response: # Extract content between SQL code blocks import re sql_blocks = re.findall(r"```sql\n(.*?)\n```", response, re.DOTALL) if sql_blocks: return sql_blocks[0].strip() # If no code blocks, remove explanatory text and keep SQL sql_lines = [] in_sql = False for line in response.split("\n"): line_lower = line.lower().strip() # Skip explanatory lines at the beginning if not in_sql and not any(keyword in line_lower for keyword in ["select", "from", "where", "group", "order", "having", "join", "update", "delete", "insert"]): continue # Once we hit SQL, include all lines in_sql = True sql_lines.append(line) # If we filtered out everything, return the original response if not sql_lines and response: return response.strip() return "\n".join(sql_lines).strip()