import functools
import os

import gradio as gr
from dotenv import load_dotenv
from groq import Groq
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity


# Load variables from a local .env file into the process environment first,
# so that the lookup below can see them.
load_dotenv()

# Groq API key read from the "groq_api_key" variable; None when it is not set.
api = os.environ.get("groq_api_key")


@functools.lru_cache(maxsize=1)
def create_metadata_embeddings():
    """Build the table-schema descriptions and their sentence embeddings.

    The schema texts are constants, so loading the SentenceTransformer model
    and encoding them gives the same result every time. The result is
    therefore memoised: the work is done on the first call and later calls
    return the very same tuple (callers must treat its contents as
    read-only).

    Returns:
        A tuple ``(embeddings, model, student, employee, course)``:
        ``embeddings`` is what ``model.encode`` returns for the three schema
        descriptions in the order student, employee, course; ``model`` is
        the loaded ``SentenceTransformer``; the last three items are the
        schema description strings themselves.
    """
    student = """
Table: student
Columns:
- student_id: an integer representing the unique ID of a student.
- first_name: a string containing the first name of the student.
- last_name: a string containing the last name of the student.
- date_of_birth: a date representing the student's birthdate.
- email: a string for the student's email address.
- phone_number: a string for the student's contact number.
- major: a string representing the student's major field of study.
- year_of_enrollment: an integer for the year the student enrolled.
"""

    employee = """
Table: employee
Columns:
- employee_id: an integer representing the unique ID of an employee.
- first_name: a string containing the first name of the employee.
- last_name: a string containing the last name of the employee.
- email: a string for the employee's email address.
- department: a string for the department the employee works in.
- position: a string representing the employee's job title.
- salary: a float representing the employee's salary.
- date_of_joining: a date for when the employee joined the college.
"""

    course = """
Table: course_info
Columns:
- course_id: an integer representing the unique ID of the course.
- course_name: a string containing the course's name.
- course_code: a string for the course's unique code.
- instructor_id: an integer for the ID of the instructor teaching the course.
- department: a string for the department offering the course.
- credits: an integer representing the course credits.
- semester: a string for the semester when the course is offered.
"""

    # The order here defines the row order of `embeddings`, which
    # find_best_fit relies on (0 = student, 1 = employee, 2 = course).
    metadata_list = [student, employee, course]

    model = SentenceTransformer('all-MiniLM-L6-v2')
    embeddings = model.encode(metadata_list)

    return embeddings, model, student, employee, course


def find_best_fit(embeddings, model, user_query, student, employee, course):
    """Return the table description that best matches the user's query.

    The query is encoded with ``model`` and compared against ``embeddings``
    by cosine similarity. The position of the highest score selects the
    description: 0 -> ``student``, 1 -> ``employee``, any other position ->
    ``course``.
    """
    scores = cosine_similarity(model.encode([user_query]), embeddings)
    winner = int(scores.argmax())

    # `course` doubles as the catch-all, exactly like a trailing `else`.
    return {0: student, 1: employee}.get(winner, course)


# Fixed instructions sent as the system message; identical for every request.
_SYSTEM_PROMPT = """
You are a SQL query generator specialized in generating SQL queries for a single table at a time. Your task is to accurately convert natural language queries into SQL statements based on the user's intent and the provided table metadata.

Rules:
Single Table Only: Assume all queries are related to a single table provided in the metadata. Ignore any references to other tables.
Metadata-Based Validation: Always ensure the generated query matches the table name, columns, and data types provided in the metadata.
User Intent: Accurately capture the user's requirements, such as filters, sorting, or aggregations, as expressed in natural language.
SQL Syntax: Use standard SQL syntax that is compatible with most relational database systems.

Input Format:
User Query: The user's natural language request.
Table Metadata: The structure of the relevant table, including the table name, column names, and data types.

Output Format:
SQL Query: A valid SQL query formatted for readability.
Do not output anything else except the SQL query.Not even a single word extra.Ouput the whole query in a single line only.
You are ready to generate SQL queries based on the user input and table metadata.
"""


def create_prompt(user_query, table_metadata):
    """Assemble the system and user messages for the SQL-generation request.

    Args:
        user_query: The natural-language request, inserted verbatim.
        table_metadata: Description of the table the query should target,
            inserted verbatim.

    Returns:
        A ``(system_prompt, user_prompt)`` tuple of strings. The system
        prompt is constant; the user prompt carries the two arguments on
        labelled lines.
    """
    user_message = (
        f"\nUser Query: {user_query}\n"
        f"Table Metadata: {table_metadata}\n"
    )
    return _SYSTEM_PROMPT, user_message


def generate_output(system_prompt, user_prompt, model="llama3-70b-8192"):
    """Ask the Groq chat model for a SQL query and sanity-check the reply.

    Args:
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        model: Groq model ID to query. Defaults to the ID this script has
            always used.
            NOTE(review): confirm this ID is still served by Groq.

    Returns:
        The model's reply with surrounding whitespace removed when it
        starts with ``SELECT`` (case-insensitive); otherwise the fixed
        message ``"Can't perform the task at the moment."``.

    Errors raised by the Groq client are not caught here.
    """
    client = Groq(api_key=api)
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        model=model,
    )

    # The reply content can be None, and models often emit a leading
    # newline or spaces; normalise both so the prefix check below neither
    # crashes nor rejects an otherwise valid query.
    reply = (chat_completion.choices[0].message.content or "").strip()

    # Only SELECT statements are passed through to the UI.
    if reply[:6].lower() == "select":
        return reply
    return "Can't perform the task at the moment."


def response(user_query):
    """Turn a natural-language question into the text shown in the UI.

    Steps: obtain the schema embeddings, pick the schema closest to
    ``user_query``, build the prompts for that schema, and return whatever
    ``generate_output`` produces for them.
    """
    # The last three items of the tuple are the student, employee and
    # course descriptions, in that order.
    embeddings, model, *schemas = create_metadata_embeddings()
    table_metadata = find_best_fit(embeddings, model, user_query, *schemas)
    return generate_output(*create_prompt(user_query, table_metadata))


# Description text passed to the Gradio interface below.
desc = """

There are three tables in the database:


Student Table:
The table contains the student's unique ID, first name, last name, date of birth, email address, phone number, major field of study, and year of enrollment.


Employee Table:
The table includes the employee's unique ID, first name, last name, email address, department, job position, salary, and date of joining.


Course Info Table:
The table holds information about the course's unique ID, name, course code, instructor ID, department offering the course, number of credits, and the semester in which the course is offered.

"""

# One text box in, one text box out; `response` does all the work.
demo = gr.Interface(
    fn=response,
    inputs=gr.Textbox(label="Please provide the natural language query"),
    outputs=gr.Textbox(label="SQL Query"),
    title="SQL Query generator",
    description=desc,
)

# `share` is a boolean flag. The string "True" only worked because any
# non-empty string is truthy ("False" would have enabled sharing as well).
demo.launch(share=True)