# demo-agent-sql-gpt / src/nl_to_sql_agent.py
# Last commit b8ede5c by Vatsal Goel: "add protect span" (unverified)
"""
NL to SQL Agent implementation
"""
import json
import os
from typing import List, Dict, Any, Optional, Tuple
import asyncio
from openai import OpenAI
from galileo import GalileoLogger
from galileo.projects import Projects
from galileo.stages import Stages
from galileo.protect import Protect
from galileo_core.schemas.protect.action import OverrideAction
from galileo_core.schemas.protect.payload import Payload
from galileo_core.schemas.protect.ruleset import Ruleset
from galileo_core.schemas.protect.rule import Rule, RuleOperator
from tools import ListTablesTool, FetchTableSchemaTool, ExecuteSQLTool
from prompts.system_prompt import SYSTEM_PROMPT
from data.database import setup_database
class NLToSQLAgent:
    """
    Agent that converts natural language queries to SQL using OpenAI.

    The model is given three tools (list tables, fetch a table schema,
    execute SQL) and is driven in a tool-calling loop until it produces a
    final message. Every LLM call and tool call is logged to Galileo, and an
    optional Galileo Protect stage can screen the (query, SQL) pair.
    """

    # Chat model used for every completion and recorded on each LLM span.
    MODEL_NAME = "gpt-4o-mini"

    def __init__(self):
        """Initialize the NL to SQL agent.

        Raises:
            ValueError: If OPENAI_API_KEY is unset, or the Galileo project named
                by GALILEO_PROJECT cannot be found.
        """
        # Ensure database is set up
        setup_database()
        self.list_tables_tool = ListTablesTool()
        self.fetch_table_schema_tool = FetchTableSchemaTool()
        self.execute_sql_tool = ExecuteSQLTool()

        # Get API key from environment variable
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set")
        # Initialize OpenAI client
        self.client = OpenAI(api_key=api_key)

        # This will log to the project and log stream specified in the logger constructor
        project_name = os.getenv("GALILEO_PROJECT")
        self.logger = GalileoLogger(project=project_name, log_stream=os.getenv("GALILEO_LOG_STREAM"))
        session_id = self.logger.start_session(name="NLToSQLAgent")
        self.logger.set_session(session_id)

        project = Projects().get(name=project_name)
        if project is None:
            # Fail with an actionable message instead of AttributeError on `.id`.
            raise ValueError(
                f"Galileo project {project_name!r} not found; check GALILEO_PROJECT"
            )
        self.project_id = project.id

        # Reuse the "protect-stage" stage if it exists, otherwise create it.
        stage = Stages().get(stage_name="protect-stage", project_id=self.project_id)
        if stage is None:
            stage = Stages().create(name="protect-stage", project_id=self.project_id)
        self.stage_id = stage.id

    def run_protect(self, query: str, sql_query: str) -> str:
        """
        Run the Galileo Protect stage on a (query, SQL) pair and log a protect span.

        Two prioritized rulesets are applied: PII detected in the input, and
        prompt injection. Each overrides the output with a fixed refusal message.

        Args:
            query: The user's natural language query (Protect payload input).
            sql_query: The generated SQL (Protect payload output).
        Returns:
            ``response.text`` from Protect. Presumably this is the original output
            when no ruleset triggers and the override choice when one does; confirm
            against the Galileo Protect docs.
        """
        payload = Payload(input=query, output=sql_query)
        # NOTE: asyncio.run raises RuntimeError if an event loop is already
        # running in this thread; this method must be called from sync code.
        response = asyncio.run(Protect().ainvoke(
            payload=payload,
            prioritized_rulesets=[
                Ruleset(
                    rules=[
                        Rule(
                            metric="input_pii",
                            operator=RuleOperator.any,
                            target_value=["email", "phone_number", "address", "name", "ssn", "credit_card_info", "account_info", "username", "password"],
                        )
                    ],
                    action=OverrideAction(
                        choices=["Sorry, the input contains PII. Please do not disclose this information."]
                    ),
                ),
                Ruleset(
                    rules=[
                        Rule(
                            metric="prompt_injection",
                            operator=RuleOperator.any,
                            target_value=["impersonation", "obfuscation", "simple_instruction", "few_shot", "new_context"],
                        ),
                    ],
                    action=OverrideAction(
                        choices=["Sorry, the input contains prompt injection. I cannot answer this due to safety concerns. Please try again."]
                    ),
                ),
            ],
            stage_id=self.stage_id,
        ))
        self.logger.add_protect_span(
            payload=payload,
            response=response,
        )
        return response.text

    def generate_and_execute_sql(self, query: str, run_protect: bool = False) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Convert a natural language query to SQL and execute it.

        Opens a Galileo trace for the query. ``run_llm_with_tools`` adds spans
        to it and concludes and flushes it.

        Args:
            query: The natural language query to convert.
            run_protect: If True, pass the final SQL through the Protect stage.
        Returns:
            Tuple of (SQL query, query results)
        Raises:
            ValueError: If generation or execution fails (original error chained).
        """
        try:
            print("Starting to process query:", query)
            # Use LLM with tools to handle the workflow
            self.logger.start_trace(query)
            sql_query, results = self.run_llm_with_tools(query, run_protect)
            print("SQL query generated:", sql_query)
            print("Query results:", results)
            return sql_query, results
        except Exception as e:
            print(f"Error generating SQL query: {e}")
            raise ValueError(f"Failed to generate SQL query: {e}") from e

    def generate_sql_query(self, query: str) -> str:
        """
        Convert a natural language query to SQL and return only the SQL.

        Note: this delegates to ``generate_and_execute_sql``, so the SQL is
        still executed; only the results are discarded.

        Args:
            query: The natural language query to convert.
        Returns:
            The generated SQL query.
        """
        sql_query, _ = self.generate_and_execute_sql(query)
        return sql_query

    def run_llm_with_tools(
        self,
        query: str,
        run_protect: bool = False,
        max_iterations: int = 10,
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Use the LLM with tools to handle the entire process.

        Repeatedly calls the model, running any requested tools and feeding
        their output back, until the model answers without tool calls or
        ``max_iterations`` rounds have elapsed. Concludes and flushes the
        current Galileo trace before returning.

        Args:
            query: The natural language query.
            run_protect: If True, replace the final SQL with the Protect result.
            max_iterations: Upper bound on model calls, so a model that keeps
                requesting tools cannot loop forever.
        Returns:
            Tuple of (SQL query, query results). The results come from executing
            the returned SQL when any execution succeeded.
        Raises:
            ValueError: On any failure (original error chained).
        """
        try:
            print("Setting up LLM with tools...")
            # Define tools for the model
            tools = [
                self.list_tables_tool.to_dict(),
                self.fetch_table_schema_tool.to_dict(),
                self.execute_sql_tool.to_dict(),
            ]
            # Initial system and user messages
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": query},
            ]
            final_sql_query = ""
            query_results: List[Dict[str, Any]] = []

            for _ in range(max_iterations):
                print("Calling OpenAI API...")
                # NOTE(review): a lower temperature is usually preferable for SQL generation.
                response = self.client.chat.completions.create(
                    model=self.MODEL_NAME,
                    messages=messages,
                    tools=tools,
                    tool_choice="auto",
                    temperature=1.0,
                )
                print("Received response from OpenAI")
                response_message = response.choices[0].message
                print("Response message:", response_message)
                self.logger.add_llm_span(
                    # Snapshot: `messages` keeps growing below, and the logger may
                    # hold on to the object it was given.
                    input=list(messages),
                    output=response_message.model_dump(),
                    tools=tools,
                    model=self.MODEL_NAME,
                )
                # Add the assistant's message to the conversation
                messages.append(response_message.model_dump())

                if not response_message.tool_calls:
                    # No tool calls: the model has generated its final response.
                    print("Final response received from LLM")
                    # If no SQL was executed via tools, try to pull SQL out of the text and run it.
                    if not final_sql_query and response_message.content:
                        extracted_sql = self.clean_sql_response(response_message.content)
                        if extracted_sql:
                            final_sql_query = extracted_sql
                            print(f"Executing extracted SQL: {final_sql_query}")
                            result = self.execute_sql_tool.call(final_sql_query)
                            if result.get("success", False):
                                query_results = result.get("results", [])
                    break

                print("Tool call requested by LLM")
                for tool_call in response_message.tool_calls:
                    function_response, executed_sql, result = self._execute_tool_call(tool_call)
                    if result is not None:
                        final_sql_query, query_results = self._merge_sql_execution(
                            final_sql_query, query_results, executed_sql, result
                        )
                    self.logger.add_tool_span(
                        name=tool_call.function.name,
                        input=tool_call.function.arguments,
                        output=function_response,
                    )
                    # Add the tool response to the conversation
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "content": function_response,
                    })
                    print("Tool response added to conversation")
            else:
                print(f"Stopping after {max_iterations} LLM rounds without a final answer")

            # Only run protection if the toggle is enabled
            if run_protect:
                final_sql_query = self.run_protect(query, final_sql_query)

            # Trace output is the rows when there are any, otherwise the SQL / protect text.
            if query_results:
                self.logger.conclude(output=str(query_results))
            else:
                self.logger.conclude(output=str(final_sql_query))
            self.logger.flush()
            return final_sql_query, query_results
        except Exception as e:
            print(f"Error in run_llm_with_tools: {e}")
            raise ValueError(f"Failed to generate SQL with tools: {e}") from e

    def _execute_tool_call(self, tool_call: Any) -> Tuple[str, Optional[str], Optional[Dict[str, Any]]]:
        """Run one model-requested tool.

        Returns:
            (content for the tool message, SQL passed to execute_sql or None,
            execute_sql result dict or None).
        """
        function_name = tool_call.function.name
        try:
            function_args = json.loads(tool_call.function.arguments or "{}")
        except json.JSONDecodeError as exc:
            # Report bad arguments back to the model instead of aborting the whole request.
            return f"Invalid tool arguments: {exc}", None, None

        if function_name == "list_tables":
            print("Executing list_tables tool")
            return json.dumps(self.list_tables_tool.call()), None, None

        if function_name == "fetch_table_schema":
            table_name = function_args.get("table_name")
            print(f"Executing fetch_table_schema for {table_name}")
            schema = self.fetch_table_schema_tool.call(table_name)
            if not schema:
                return "Schema not found", None, None
            return json.dumps(self._serialize_schema(schema)), None, None

        if function_name == "execute_sql":
            sql = function_args.get("query")
            print(f"Executing SQL query: {sql}")
            result = self.execute_sql_tool.call(sql)
            return json.dumps(result), sql, result

        return "Function not found", None, None

    @staticmethod
    def _serialize_schema(schema: Any) -> Dict[str, Any]:
        """Convert a schema object (table_name + columns) into a JSON-serializable dict."""
        return {
            "table_name": schema.table_name,
            "columns": [
                {"name": col.name, "type": col.type, "description": col.description}
                for col in schema.columns
            ],
        }

    @staticmethod
    def _merge_sql_execution(
        current_sql: str,
        current_results: List[Dict[str, Any]],
        sql: Optional[str],
        result: Dict[str, Any],
    ) -> Tuple[str, List[Dict[str, Any]]]:
        """Fold one execute_sql outcome into the running (sql, results) pair.

        A successful execution replaces both values, so the returned SQL is
        always the one that produced the returned rows. A failed execution is
        recorded only when no SQL has been seen yet.
        """
        if result.get("success", False):
            return (sql or current_sql), result.get("results", [])
        if not current_sql and sql:
            return sql, current_results
        return current_sql, current_results

    def clean_sql_response(self, response: str) -> str:
        """
        Clean up the LLM response to ensure we only return the SQL query.

        The first ```sql fenced block (case-insensitive) wins. Otherwise,
        leading lines are dropped until one contains a SQL keyword as a whole
        word, and everything from there on is kept.

        Args:
            response: The raw LLM response.
        Returns:
            The cleaned SQL query. If no line looks like SQL, this is the
            stripped original response; for an empty response it is "".
        """
        import re

        if not response:
            return ""

        # Tolerate "```SQL", a missing newline before the closing fence, and \r\n.
        fenced = re.search(r"```sql\b\s*(.*?)\s*```", response, re.IGNORECASE | re.DOTALL)
        if fenced:
            return fenced.group(1).strip()

        # Whole-word match so prose like "nowhere" or "selected" is not taken as SQL.
        keyword_re = re.compile(
            r"\b(?:select|from|where|group|order|having|join|update|delete|insert)\b",
            re.IGNORECASE,
        )
        sql_lines = []
        in_sql = False
        for line in response.split("\n"):
            # Skip explanatory lines at the beginning
            if not in_sql and not keyword_re.search(line):
                continue
            # Once we hit SQL, include all lines
            in_sql = True
            sql_lines.append(line)

        # If we filtered out everything, return the original response
        if not sql_lines:
            return response.strip()
        return "\n".join(sql_lines).strip()