Spaces:
Sleeping
Sleeping
File size: 13,771 Bytes
2e83155 e31483c 2e83155 e31483c 6070aec 2e83155 d9b35ea fbfd891 e31483c 6070aec e31483c 6070aec e31483c 6070aec e31483c 6070aec e31483c b8ede5c 6070aec 2e83155 6070aec 2e83155 a3c5bb6 2e83155 a3c5bb6 2e83155 c2de90c 322fa63 2e83155 a3c5bb6 9a15884 2e83155 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 |
"""
NL to SQL Agent implementation
"""
import json
import os
from typing import List, Dict, Any, Optional, Tuple
import asyncio
from openai import OpenAI
from galileo import GalileoLogger
from galileo.projects import Projects
from galileo.stages import Stages
from galileo.protect import Protect
from galileo_core.schemas.protect.action import OverrideAction
from galileo_core.schemas.protect.payload import Payload
from galileo_core.schemas.protect.ruleset import Ruleset
from galileo_core.schemas.protect.rule import Rule, RuleOperator
from tools import ListTablesTool, FetchTableSchemaTool, ExecuteSQLTool
from prompts.system_prompt import SYSTEM_PROMPT
from data.database import setup_database
class NLToSQLAgent:
    """
    Agent that converts natural language queries to SQL using OpenAI
    tool-calling, executes the SQL against a local demo database, and
    records traces/guardrail checks via Galileo.
    """

    def __init__(self):
        """Initialize the NL to SQL agent.

        Sets up the demo database, the database tools, the OpenAI client,
        a Galileo logging session, and the Galileo Protect stage.

        Raises:
            ValueError: If the OPENAI_API_KEY environment variable is not set.
        """
        # Ensure database is set up before any tool can query it
        setup_database()

        self.list_tables_tool = ListTablesTool()
        self.fetch_table_schema_tool = FetchTableSchemaTool()
        self.execute_sql_tool = ExecuteSQLTool()

        # Get API key from environment variable
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set")

        # Initialize OpenAI client
        self.client = OpenAI(api_key=api_key)

        # This will log to the project and log stream specified in the logger constructor
        self.logger = GalileoLogger(project=os.getenv("GALILEO_PROJECT"), log_stream=os.getenv("GALILEO_LOG_STREAM"))
        session_id = self.logger.start_session(name="NLToSQLAgent")
        self.logger.set_session(session_id)

        project = Projects().get(name=os.getenv("GALILEO_PROJECT"))
        self.project_id = project.id

        # Reuse the Protect stage if it already exists; otherwise create it.
        # NOTE(review): assumes Stages().get returns None (not raises) when
        # the stage is missing — confirm against the Galileo SDK version in use.
        stage = Stages().get(stage_name="protect-stage", project_id=self.project_id)
        if stage is None:
            stage = Stages().create(name="protect-stage", project_id=self.project_id)
        self.stage_id = stage.id

    def run_protect(self, query: str, sql_query: str) -> str:
        """
        Run the Galileo Protect stage against the input/output pair.

        Checks the user query and generated SQL against PII and
        prompt-injection rulesets. When a ruleset fires, its OverrideAction
        replaces the output with a canned refusal message.

        Args:
            query: The original natural language query (checked as input).
            sql_query: The generated SQL (checked as output).

        Returns:
            The (possibly overridden) output text from Protect.
        """
        response = asyncio.run(Protect().ainvoke(
            payload=Payload(input=query, output=sql_query),
            prioritized_rulesets=[
                Ruleset(
                    rules=[
                        Rule(
                            metric="input_pii",
                            operator=RuleOperator.any,
                            target_value=["email","phone_number","address","name","ssn","credit_card_info","account_info","username", "password"],
                        )
                    ],
                    action=OverrideAction(
                        choices=["Sorry, the input contains PII. Please do not disclose this information."]
                    ),
                ),
                Ruleset(
                    rules=[
                        Rule(
                            metric="prompt_injection",
                            operator=RuleOperator.any,
                            target_value=["impersonation","obfuscation","simple_instruction","few_shot","new_context"],
                        ),
                    ],
                    action=OverrideAction(
                        choices=["Sorry, the input contains prompt injection. I cannot answer this due to safety concerns. Please try again."]
                    ),
                )
            ],
            stage_id=self.stage_id,
        ))
        self.logger.add_protect_span(
            payload=Payload(input=query, output=sql_query),
            response=response,
        )
        return response.text

    def generate_and_execute_sql(self, query: str, run_protect: bool = False) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Convert a natural language query to SQL and execute it.

        Args:
            query: The natural language query to convert.
            run_protect: If True, run the Galileo Protect guardrails on the
                generated SQL before returning it.

        Returns:
            Tuple of (SQL query, query results).

        Raises:
            ValueError: If SQL generation or execution fails.
        """
        try:
            print("Starting to process query:", query)
            # Use LLM with tools to handle the workflow
            self.logger.start_trace(query)
            sql_query, results = self.run_llm_with_tools(query, run_protect)
            print("SQL query generated:", sql_query)
            print("Query results:", results)
            return sql_query, results
        except Exception as e:
            print(f"Error generating SQL query: {e}")
            # Best-effort: conclude the trace we started so it is not leaked.
            # Telemetry cleanup must never mask the original error.
            try:
                self.logger.conclude(output=f"Error: {e}")
                self.logger.flush()
            except Exception:
                pass
            raise ValueError(f"Failed to generate SQL query: {e}") from e

    def generate_sql_query(self, query: str) -> str:
        """
        Convert a natural language query to SQL and return only the SQL.

        Note: this runs the full generate-and-execute pipeline and discards
        the execution results.

        Args:
            query: The natural language query to convert.

        Returns:
            The generated SQL query.
        """
        sql_query, _ = self.generate_and_execute_sql(query)
        return sql_query

    def _call_tool(self, function_name: str, function_args: Dict[str, Any]) -> Tuple[str, Optional[str], Optional[List[Dict[str, Any]]]]:
        """
        Dispatch a single tool call requested by the model.

        Args:
            function_name: Name of the tool the model requested.
            function_args: Parsed JSON arguments for the tool.

        Returns:
            Tuple of (response string to feed back to the conversation,
            executed SQL statement or None, result rows on success or None).
        """
        if function_name == "list_tables":
            print("Executing list_tables tool")
            tables = self.list_tables_tool.call()
            return json.dumps(tables), None, None
        if function_name == "fetch_table_schema":
            table_name = function_args.get("table_name")
            print(f"Executing fetch_table_schema for {table_name}")
            schema = self.fetch_table_schema_tool.call(table_name)
            if schema:
                # Convert schema to a JSON-serializable format
                serialized_schema = {
                    "table_name": schema.table_name,
                    "columns": [
                        {"name": col.name, "type": col.type, "description": col.description}
                        for col in schema.columns
                    ]
                }
                return json.dumps(serialized_schema), None, None
            return "Schema not found", None, None
        if function_name == "execute_sql":
            sql = function_args.get("query")
            print(f"Executing SQL query: {sql}")
            result = self.execute_sql_tool.call(sql)
            # Only surface rows when the tool reports success
            results = result.get("results", []) if result.get("success", False) else None
            return json.dumps(result), sql, results
        return "Function not found", None, None

    def run_llm_with_tools(self, query: str, run_protect: bool = False) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Use the LLM with tools to handle the entire process.

        Loops over chat-completion rounds, dispatching any tool calls via
        _call_tool, until the model produces a final (tool-free) response.
        Concludes and flushes the Galileo trace on success.

        Args:
            query: The natural language query.
            run_protect: If True, run Galileo Protect on the final SQL.

        Returns:
            Tuple of (SQL query, query results).

        Raises:
            ValueError: If any step of the tool-calling loop fails.
        """
        try:
            print("Setting up LLM with tools...")
            # Define tools for the model
            tools = [
                self.list_tables_tool.to_dict(),
                self.fetch_table_schema_tool.to_dict(),
                self.execute_sql_tool.to_dict()
            ]
            # Initial system and user messages
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": query}
            ]
            final_sql_query = ""
            query_results: List[Dict[str, Any]] = []
            continue_conversation = True

            while continue_conversation:
                print("Calling OpenAI API...")
                # Call OpenAI API
                response = self.client.chat.completions.create(
                    model="gpt-4o-mini",  # Use the appropriate model
                    messages=messages,
                    tools=tools,
                    tool_choice="auto",
                    temperature=1.0
                )
                print("Received response from OpenAI")
                response_message = response.choices[0].message
                print("Response message:", response_message)
                self.logger.add_llm_span(
                    input=messages,
                    output=response_message.model_dump(),
                    tools=tools,
                    model="gpt-4o-mini",
                )
                # Add the assistant's message to the conversation
                messages.append(response_message.model_dump())

                # Check if the model wants to call a tool
                if response_message.tool_calls:
                    print("Tool call requested by LLM")
                    for tool_call in response_message.tool_calls:
                        function_name = tool_call.function.name
                        function_args = json.loads(tool_call.function.arguments)
                        function_response, executed_sql, results = self._call_tool(function_name, function_args)
                        # Keep the first SQL statement the model executed
                        if executed_sql and not final_sql_query:
                            final_sql_query = executed_sql
                        if results is not None:
                            query_results = results
                        self.logger.add_tool_span(
                            name=function_name,
                            input=tool_call.function.arguments,
                            output=function_response,
                        )
                        # Add the tool response to the conversation
                        messages.append({
                            "role": "tool",
                            "tool_call_id": tool_call.id,
                            "content": function_response
                        })
                        print("Tool response added to conversation")
                else:
                    # If no tool calls, the model has generated the final response
                    print("Final response received from LLM")
                    # If we have no SQL yet, try to extract and execute it from
                    # the model's free-text answer
                    if not final_sql_query and response_message.content:
                        extracted_sql = self.clean_sql_response(response_message.content)
                        if extracted_sql:
                            final_sql_query = extracted_sql
                            print(f"Executing extracted SQL: {final_sql_query}")
                            result = self.execute_sql_tool.call(final_sql_query)
                            if result.get("success", False):
                                query_results = result.get("results", [])
                    continue_conversation = False

            # Clean the SQL response if needed
            if not final_sql_query and response_message.content:
                final_sql_query = self.clean_sql_response(response_message.content)

            # Only run protection if the toggle is enabled
            if run_protect:
                final_sql_query = self.run_protect(query, final_sql_query)

            # Conclude the trace with results when we have them, else the SQL
            if len(query_results) > 0:
                self.logger.conclude(output=str(query_results))
            else:
                self.logger.conclude(output=str(final_sql_query))
            self.logger.flush()
            return final_sql_query, query_results
        except Exception as e:
            print(f"Error in run_llm_with_tools: {e}")
            raise ValueError(f"Failed to generate SQL with tools: {e}") from e

    def clean_sql_response(self, response: str) -> str:
        """
        Clean up the LLM response to ensure we only return the SQL query.

        Prefers the first ```sql fenced code block; otherwise strips leading
        explanatory lines until the first line containing a SQL keyword.

        Args:
            response: The raw LLM response.

        Returns:
            The cleaned SQL query, or the original stripped response if no
            SQL-looking lines were found.
        """
        if not response:
            return ""
        # Remove any markdown SQL code blocks
        if "```sql" in response:
            # Extract content between SQL code blocks
            import re
            sql_blocks = re.findall(r"```sql\n(.*?)\n```", response, re.DOTALL)
            if sql_blocks:
                return sql_blocks[0].strip()
        # If no code blocks, remove explanatory text and keep SQL
        sql_lines = []
        in_sql = False
        for line in response.split("\n"):
            line_lower = line.lower().strip()
            # Skip explanatory lines at the beginning (substring match, so a
            # word like "selection" also triggers inclusion — acceptable here)
            if not in_sql and not any(keyword in line_lower for keyword in
                                      ["select", "from", "where", "group", "order", "having",
                                       "join", "update", "delete", "insert"]):
                continue
            # Once we hit SQL, include all lines
            in_sql = True
            sql_lines.append(line)
        # If we filtered out everything, return the original response
        if not sql_lines and response:
            return response.strip()
        return "\n".join(sql_lines).strip()