Spaces:
Sleeping
Sleeping
| """ | |
| NL to SQL Agent implementation | |
| """ | |
| import json | |
| import os | |
| from typing import List, Dict, Any, Optional, Tuple | |
| import asyncio | |
| from openai import OpenAI | |
| from galileo import GalileoLogger | |
| from galileo.projects import Projects | |
| from galileo.stages import Stages | |
| from galileo.protect import Protect | |
| from galileo_core.schemas.protect.action import OverrideAction | |
| from galileo_core.schemas.protect.payload import Payload | |
| from galileo_core.schemas.protect.ruleset import Ruleset | |
| from galileo_core.schemas.protect.rule import Rule, RuleOperator | |
| from tools import ListTablesTool, FetchTableSchemaTool, ExecuteSQLTool | |
| from prompts.system_prompt import SYSTEM_PROMPT | |
| from data.database import setup_database | |
class NLToSQLAgent:
    """
    Agent that converts natural language queries to SQL using OpenAI.
    """

    def __init__(self) -> None:
        """Initialize the NL to SQL agent.

        Sets up the local database, the SQL tools exposed to the LLM, the
        OpenAI client, and the Galileo logging/protect infrastructure.

        Raises:
            ValueError: If the OPENAI_API_KEY environment variable is not set.
        """
        # Ensure database is set up
        setup_database()
        # Tools the LLM can call for table discovery, schema lookup, and execution.
        self.list_tables_tool = ListTablesTool()
        self.fetch_table_schema_tool = FetchTableSchemaTool()
        self.execute_sql_tool = ExecuteSQLTool()
        # Get API key from environment variable
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set")
        # Initialize OpenAI client
        self.client = OpenAI(api_key=api_key)
        # This will log to the project and log stream specified in the logger constructor
        self.logger = GalileoLogger(project=os.getenv("GALILEO_PROJECT"), log_stream=os.getenv("GALILEO_LOG_STREAM"))
        # All traces produced by this agent instance share one Galileo session.
        session_id = self.logger.start_session(name="NLToSQLAgent")
        self.logger.set_session(session_id)
        # Resolve the Galileo project id; needed to scope the Protect stage below.
        project = Projects().get(name=os.getenv("GALILEO_PROJECT"))
        self.project_id = project.id
        # Reuse the existing Protect stage when present, otherwise create it.
        stage = Stages().get(stage_name="protect-stage", project_id=self.project_id)
        if stage is None:
            stage = Stages().create(name="protect-stage", project_id=self.project_id)
        self.stage_id = stage.id
| def run_protect(self, query: str, sql_query: str) -> str: | |
| """ | |
| Run the protect stage. | |
| """ | |
| response = asyncio.run(Protect().ainvoke( | |
| payload=Payload(input=query, output=sql_query), | |
| prioritized_rulesets=[ | |
| Ruleset( | |
| rules=[ | |
| Rule( | |
| metric="input_pii", | |
| operator=RuleOperator.any, | |
| target_value=["email","phone_number","address","name","ssn","credit_card_info","account_info","username", "password"], | |
| ) | |
| ], | |
| action=OverrideAction( | |
| choices=["Sorry, the input contains PII. Please do not disclose this information."] | |
| ), | |
| ), | |
| Ruleset( | |
| rules=[ | |
| Rule( | |
| metric="prompt_injection", | |
| operator=RuleOperator.any, | |
| target_value=["impersonation","obfuscation","simple_instruction","few_shot","new_context"], | |
| ), | |
| ], | |
| action=OverrideAction( | |
| choices=["Sorry, the input contains prompt injection. I cannot answer this due to safety concerns. Please try again."] | |
| ), | |
| ) | |
| ], | |
| stage_id=self.stage_id, | |
| )) | |
| self.logger.add_protect_span( | |
| payload=Payload(input=query, output=sql_query), | |
| response=response, | |
| ) | |
| return response.text | |
| def generate_and_execute_sql(self, query: str, run_protect: bool = False) -> Tuple[str, List[Dict[str, Any]]]: | |
| """ | |
| Convert a natural language query to SQL and execute it. | |
| Args: | |
| query: The natural language query to convert. | |
| Returns: | |
| Tuple of (SQL query, query results) | |
| """ | |
| try: | |
| print("Starting to process query:", query) | |
| # Use LLM with tools to handle the workflow | |
| trace = self.logger.start_trace(query) | |
| sql_query, results = self.run_llm_with_tools(query, run_protect) | |
| print("SQL query generated:", sql_query) | |
| print("Query results:", results) | |
| return sql_query, results | |
| except Exception as e: | |
| print(f"Error generating SQL query: {e}") | |
| raise ValueError(f"Failed to generate SQL query: {e}") | |
| def generate_sql_query(self, query: str) -> str: | |
| """ | |
| Convert a natural language query to SQL (without executing it). | |
| Args: | |
| query: The natural language query to convert. | |
| Returns: | |
| The generated SQL query. | |
| """ | |
| sql_query, _ = self.generate_and_execute_sql(query) | |
| return sql_query | |
    def run_llm_with_tools(self, query: str, run_protect: bool = False) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Use the LLM with tools to handle the entire process.

        Runs an OpenAI chat-completion loop: the model may call the
        list_tables / fetch_table_schema / execute_sql tools any number of
        times; each tool result is appended to the conversation, and the loop
        ends when the model replies without a tool call.

        Args:
            query: The natural language query.
            run_protect: When True, pass the final SQL through the Galileo
                Protect rulesets (run_protect) before returning it.

        Returns:
            Tuple of (SQL query, query results)

        Raises:
            ValueError: If any step of the tool-calling loop fails.
        """
        try:
            print("Setting up LLM with tools...")
            # Define tools for the model
            tools = [
                self.list_tables_tool.to_dict(),
                self.fetch_table_schema_tool.to_dict(),
                self.execute_sql_tool.to_dict()
            ]
            # Initial system and user messages
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": query}
            ]
            final_sql_query = ""
            query_results = []
            continue_conversation = True
            # Loop until the model answers without requesting a tool call.
            while continue_conversation:
                print("Calling OpenAI API...")
                # Call OpenAI API
                response = self.client.chat.completions.create(
                    model="gpt-4o-mini",  # Use the appropriate model
                    messages=messages,
                    tools=tools,
                    tool_choice="auto",
                    temperature=1.0
                )
                print("Received response from OpenAI")
                response_message = response.choices[0].message
                print("Response message:", response_message)
                # Log this LLM call (inputs, output, tool definitions) to Galileo.
                self.logger.add_llm_span(
                    input=messages,
                    output=response_message.model_dump(),
                    tools=tools,
                    model="gpt-4o-mini",
                )
                # Add the assistant's message to the conversation
                messages.append(response_message.model_dump())
                # Check if the model wants to call a tool
                if response_message.tool_calls:
                    print("Tool call requested by LLM")
                    # Handle each tool call
                    for tool_call in response_message.tool_calls:
                        function_name = tool_call.function.name
                        function_args = json.loads(tool_call.function.arguments)
                        # Execute the appropriate tool
                        if function_name == "list_tables":
                            print("Executing list_tables tool")
                            tables = self.list_tables_tool.call()
                            function_response = json.dumps(tables)
                        elif function_name == "fetch_table_schema":
                            table_name = function_args.get("table_name")
                            print(f"Executing fetch_table_schema for {table_name}")
                            schema = self.fetch_table_schema_tool.call(table_name)
                            # Convert schema to a JSON-serializable format
                            if schema:
                                serialized_schema = {
                                    "table_name": schema.table_name,
                                    "columns": [
                                        {"name": col.name, "type": col.type, "description": col.description}
                                        for col in schema.columns
                                    ]
                                }
                                function_response = json.dumps(serialized_schema)
                            else:
                                function_response = "Schema not found"
                        elif function_name == "execute_sql":
                            sql = function_args.get("query")
                            print(f"Executing SQL query: {sql}")
                            # Save the SQL query (only the first one executed is kept)
                            if not final_sql_query:
                                final_sql_query = sql
                            # Execute the query
                            result = self.execute_sql_tool.call(sql)
                            function_response = json.dumps(result)
                            # Save the results if successful
                            if result.get("success", False):
                                query_results = result.get("results", [])
                        else:
                            function_response = "Function not found"
                        # Log the tool invocation to Galileo.
                        self.logger.add_tool_span(
                            name=function_name,
                            input=tool_call.function.arguments,
                            output=function_response,
                        )
                        # Add the tool response to the conversation
                        messages.append({
                            "role": "tool",
                            "tool_call_id": tool_call.id,
                            "content": function_response
                        })
                        print("Tool response added to conversation")
                else:
                    # If no tool calls, the model has generated the final response
                    print("Final response received from LLM")
                    # If we have a SQL query but no results, we need to extract and execute it
                    if not final_sql_query and response_message.content:
                        extracted_sql = self.clean_sql_response(response_message.content)
                        if extracted_sql:
                            final_sql_query = extracted_sql
                            print(f"Executing extracted SQL: {final_sql_query}")
                            result = self.execute_sql_tool.call(final_sql_query)
                            if result.get("success", False):
                                query_results = result.get("results", [])
                    continue_conversation = False
            # Clean the SQL response if needed
            if not final_sql_query and response_message.content:
                final_sql_query = self.clean_sql_response(response_message.content)
            # Only run protection if the toggle is enabled
            if run_protect:
                # Protect may replace the SQL with an override message.
                final_sql_query = self.run_protect(query, final_sql_query)
            # Conclude the trace with the results when available, else the SQL text.
            if len(query_results) > 0:
                self.logger.conclude(output=str(query_results))
            else:
                self.logger.conclude(output=str(final_sql_query))
            self.logger.flush()
            return final_sql_query, query_results
        except Exception as e:
            print(f"Error in run_llm_with_tools: {e}")
            raise ValueError(f"Failed to generate SQL with tools: {e}")
| def clean_sql_response(self, response: str) -> str: | |
| """ | |
| Clean up the LLM response to ensure we only return the SQL query. | |
| Args: | |
| response: The raw LLM response. | |
| Returns: | |
| The cleaned SQL query. | |
| """ | |
| if not response: | |
| return "" | |
| # Remove any markdown SQL code blocks | |
| if "```sql" in response: | |
| # Extract content between SQL code blocks | |
| import re | |
| sql_blocks = re.findall(r"```sql\n(.*?)\n```", response, re.DOTALL) | |
| if sql_blocks: | |
| return sql_blocks[0].strip() | |
| # If no code blocks, remove explanatory text and keep SQL | |
| sql_lines = [] | |
| in_sql = False | |
| for line in response.split("\n"): | |
| line_lower = line.lower().strip() | |
| # Skip explanatory lines at the beginning | |
| if not in_sql and not any(keyword in line_lower for keyword in | |
| ["select", "from", "where", "group", "order", "having", | |
| "join", "update", "delete", "insert"]): | |
| continue | |
| # Once we hit SQL, include all lines | |
| in_sql = True | |
| sql_lines.append(line) | |
| # If we filtered out everything, return the original response | |
| if not sql_lines and response: | |
| return response.strip() | |
| return "\n".join(sql_lines).strip() |