File size: 13,771 Bytes
2e83155
 
 
 
 
 
e31483c
2e83155
 
e31483c
 
 
 
6070aec
 
 
 
2e83155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d9b35ea
fbfd891
e31483c
6070aec
 
e31483c
6070aec
e31483c
6070aec
 
 
 
 
 
e31483c
6070aec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e31483c
b8ede5c
 
 
 
6070aec
2e83155
 
6070aec
2e83155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3c5bb6
2e83155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3c5bb6
2e83155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2de90c
322fa63
2e83155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3c5bb6
 
 
 
 
9a15884
 
 
 
 
2e83155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
"""
NL to SQL Agent implementation
"""
import json
import os
from typing import List, Dict, Any, Optional, Tuple
import asyncio

from openai import OpenAI
from galileo import GalileoLogger
from galileo.projects import Projects
from galileo.stages import Stages
from galileo.protect import Protect
from galileo_core.schemas.protect.action import OverrideAction
from galileo_core.schemas.protect.payload import Payload
from galileo_core.schemas.protect.ruleset import Ruleset
from galileo_core.schemas.protect.rule import Rule, RuleOperator
from tools import ListTablesTool, FetchTableSchemaTool, ExecuteSQLTool
from prompts.system_prompt import SYSTEM_PROMPT
from data.database import setup_database


class NLToSQLAgent:
    """
    Agent that converts natural language queries to SQL using OpenAI.

    Drives a tool-calling loop (list_tables / fetch_table_schema / execute_sql)
    through the OpenAI chat completions API, logs each LLM and tool step to
    Galileo, and can optionally run a Galileo Protect stage over the result.
    """

    def __init__(self):
        """Initialize the NL to SQL agent.

        Sets up the database, the tool wrappers, the OpenAI client, the
        Galileo logger/session, and the Protect stage.

        Raises:
            ValueError: If the OPENAI_API_KEY environment variable is not set.
        """
        # Ensure database is set up before any tool can query it
        setup_database()

        self.list_tables_tool = ListTablesTool()
        self.fetch_table_schema_tool = FetchTableSchemaTool()
        self.execute_sql_tool = ExecuteSQLTool()

        # Get API key from environment variable; fail fast if missing
        api_key = os.environ.get("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable is not set")

        # Initialize OpenAI client
        self.client = OpenAI(api_key=api_key)

        # This will log to the project and log stream specified in the logger constructor
        self.logger = GalileoLogger(project=os.getenv("GALILEO_PROJECT"), log_stream=os.getenv("GALILEO_LOG_STREAM"))
        session_id = self.logger.start_session(name="NLToSQLAgent")
        self.logger.set_session(session_id)
        project = Projects().get(name=os.getenv("GALILEO_PROJECT"))
        self.project_id = project.id

        # Reuse the protect stage when it already exists; create it otherwise
        stage = Stages().get(stage_name="protect-stage", project_id=self.project_id)
        if stage is None:
            stage = Stages().create(name="protect-stage", project_id=self.project_id)
        self.stage_id = stage.id

    def run_protect(self, query: str, sql_query: str) -> str:
        """
        Run the Galileo Protect stage over the query/SQL pair.

        Applies two prioritized rulesets: one that overrides the output when
        the input contains PII, and one that overrides it on prompt injection.

        Args:
            query: The original natural language query (checked as input).
            sql_query: The generated SQL (checked as output).

        Returns:
            The protect response text: either the original SQL or an
            override message when a rule fired.
        """
        response = asyncio.run(Protect().ainvoke(
            payload=Payload(input=query, output=sql_query),
            prioritized_rulesets=[
                Ruleset(
                    rules=[
                        Rule(
                            metric="input_pii",
                            operator=RuleOperator.any,
                            target_value=["email","phone_number","address","name","ssn","credit_card_info","account_info","username", "password"],
                        )
                    ],
                    action=OverrideAction(
                        choices=["Sorry, the input contains PII. Please do not disclose this information."]
                    ),
                ),
                Ruleset(
                    rules=[
                        Rule(
                            metric="prompt_injection",
                            operator=RuleOperator.any,
                            target_value=["impersonation","obfuscation","simple_instruction","few_shot","new_context"],
                        ),
                    ],
                    action=OverrideAction(
                        choices=["Sorry, the input contains prompt injection. I cannot answer this due to safety concerns. Please try again."]
                    ),
                )
            ],
            stage_id=self.stage_id,
        ))
        self.logger.add_protect_span(
            payload=Payload(input=query, output=sql_query),
            response=response,
        )
        return response.text

    def generate_and_execute_sql(self, query: str, run_protect: bool = False) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Convert a natural language query to SQL and execute it.

        Args:
            query: The natural language query to convert.
            run_protect: When True, run the Galileo Protect stage over the
                generated SQL before returning it.

        Returns:
            Tuple of (SQL query, query results).

        Raises:
            ValueError: If SQL generation fails; chained to the original error.
        """
        try:
            print("Starting to process query:", query)

            # Start a Galileo trace for this query; spans are added inside
            # run_llm_with_tools and concluded there.
            self.logger.start_trace(query)

            sql_query, results = self.run_llm_with_tools(query, run_protect)
            print("SQL query generated:", sql_query)
            print("Query results:", results)

            return sql_query, results
        except Exception as e:
            print(f"Error generating SQL query: {e}")
            # Chain the original exception so the traceback is preserved
            raise ValueError(f"Failed to generate SQL query: {e}") from e

    def generate_sql_query(self, query: str) -> str:
        """
        Convert a natural language query to SQL.

        NOTE: despite the name, this delegates to generate_and_execute_sql,
        so the query IS executed against the database; only the results are
        discarded.

        Args:
            query: The natural language query to convert.

        Returns:
            The generated SQL query.
        """
        sql_query, _ = self.generate_and_execute_sql(query)
        return sql_query

    def run_llm_with_tools(self, query: str, run_protect: bool = False) -> Tuple[str, List[Dict[str, Any]]]:
        """
        Use the LLM with tools to handle the entire process.

        Loops on chat completions, dispatching any requested tool calls
        (list_tables / fetch_table_schema / execute_sql) and feeding their
        results back until the model produces a final answer.

        Args:
            query: The natural language query.
            run_protect: When True, run the Galileo Protect stage over the
                final SQL before returning it.

        Returns:
            Tuple of (SQL query, query results).

        Raises:
            ValueError: If the tool-calling loop fails; chained to the
                original error.
        """
        try:
            print("Setting up LLM with tools...")

            # Define tools for the model
            tools = [
                self.list_tables_tool.to_dict(),
                self.fetch_table_schema_tool.to_dict(),
                self.execute_sql_tool.to_dict()
            ]

            # Initial system and user messages
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": query}
            ]

            final_sql_query = ""
            query_results = []
            continue_conversation = True

            while continue_conversation:
                print("Calling OpenAI API...")

                # Call OpenAI API
                response = self.client.chat.completions.create(
                    model="gpt-4o-mini",  # Use the appropriate model
                    messages=messages,
                    tools=tools,
                    tool_choice="auto",
                    temperature=1.0
                )

                print("Received response from OpenAI")
                response_message = response.choices[0].message
                print("Response message:", response_message)

                self.logger.add_llm_span(
                    input=messages,
                    output=response_message.model_dump(),
                    tools=tools,
                    model="gpt-4o-mini",
                )

                # Add the assistant's message to the conversation
                messages.append(response_message.model_dump())

                # Check if the model wants to call a tool
                if response_message.tool_calls:
                    print("Tool call requested by LLM")

                    # Handle each tool call
                    for tool_call in response_message.tool_calls:
                        function_name = tool_call.function.name
                        function_args = json.loads(tool_call.function.arguments)

                        # Execute the appropriate tool
                        if function_name == "list_tables":
                            print("Executing list_tables tool")
                            tables = self.list_tables_tool.call()
                            function_response = json.dumps(tables)

                        elif function_name == "fetch_table_schema":
                            table_name = function_args.get("table_name")
                            print(f"Executing fetch_table_schema for {table_name}")
                            schema = self.fetch_table_schema_tool.call(table_name)

                            # Convert schema to a JSON-serializable format
                            if schema:
                                serialized_schema = {
                                    "table_name": schema.table_name,
                                    "columns": [
                                        {"name": col.name, "type": col.type, "description": col.description}
                                        for col in schema.columns
                                    ]
                                }
                                function_response = json.dumps(serialized_schema)
                            else:
                                function_response = "Schema not found"

                        elif function_name == "execute_sql":
                            sql = function_args.get("query")
                            print(f"Executing SQL query: {sql}")

                            # Save the first SQL query the model executes
                            if not final_sql_query:
                                final_sql_query = sql

                            # Execute the query
                            result = self.execute_sql_tool.call(sql)
                            function_response = json.dumps(result)

                            # Save the results if successful
                            if result.get("success", False):
                                query_results = result.get("results", [])

                        else:
                            function_response = "Function not found"

                        self.logger.add_tool_span(
                            name=function_name,
                            input=tool_call.function.arguments,
                            output=function_response,
                        )
                        # Add the tool response to the conversation
                        messages.append({
                            "role": "tool",
                            "tool_call_id": tool_call.id,
                            "content": function_response
                        })
                        print("Tool response added to conversation")
                else:
                    # If no tool calls, the model has generated the final response
                    print("Final response received from LLM")

                    # If we have a SQL query but no results, we need to extract and execute it
                    if not final_sql_query and response_message.content:
                        extracted_sql = self.clean_sql_response(response_message.content)
                        if extracted_sql:
                            final_sql_query = extracted_sql
                            print(f"Executing extracted SQL: {final_sql_query}")
                            result = self.execute_sql_tool.call(final_sql_query)
                            if result.get("success", False):
                                query_results = result.get("results", [])

                    continue_conversation = False

            # Last-resort cleanup if no SQL was captured during the loop
            if not final_sql_query and response_message.content:
                final_sql_query = self.clean_sql_response(response_message.content)

            # Only run protection if the toggle is enabled
            if run_protect:
                final_sql_query = self.run_protect(query, final_sql_query)

            # Conclude the trace with results when available, else the SQL
            if len(query_results) > 0:
                self.logger.conclude(output=str(query_results))
            else:
                self.logger.conclude(output=str(final_sql_query))
            self.logger.flush()

            return final_sql_query, query_results
        except Exception as e:
            print(f"Error in run_llm_with_tools: {e}")
            # Chain the original exception so the traceback is preserved
            raise ValueError(f"Failed to generate SQL with tools: {e}") from e

    def clean_sql_response(self, response: str) -> str:
        """
        Clean up the LLM response to ensure we only return the SQL query.

        Prefers the first ```sql fenced block; otherwise strips leading
        explanatory text until the first line containing a SQL keyword.

        Args:
            response: The raw LLM response.

        Returns:
            The cleaned SQL query, or the stripped original response when
            no SQL-looking lines were found.
        """
        if not response:
            return ""

        # Remove any markdown SQL code blocks
        if "```sql" in response:
            # Extract content between SQL code blocks
            import re
            sql_blocks = re.findall(r"```sql\n(.*?)\n```", response, re.DOTALL)
            if sql_blocks:
                return sql_blocks[0].strip()

        # If no code blocks, remove explanatory text and keep SQL
        sql_lines = []
        in_sql = False

        for line in response.split("\n"):
            line_lower = line.lower().strip()

            # Skip explanatory lines at the beginning
            if not in_sql and not any(keyword in line_lower for keyword in
                                    ["select", "from", "where", "group", "order", "having",
                                     "join", "update", "delete", "insert"]):
                continue

            # Once we hit SQL, include all lines
            in_sql = True
            sql_lines.append(line)

        # If we filtered out everything, return the original response
        if not sql_lines and response:
            return response.strip()

        return "\n".join(sql_lines).strip()