import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sqlparse
import psutil

# Check available memory
def get_available_memory():
    return psutil.virtual_memory().available

model_name = "defog/llama-3-sqlcoder-8b"

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# CPU-compatible model loading
def load_model():
    try:
        available_memory = get_available_memory()
        print(f"Available memory: {available_memory / 1e9:.1f} GB")
        
        # For CPU deployment, we'll use float32 or float16 without quantization
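        # Rough sizing assumption: an 8B-parameter model needs about
        # 8e9 params * 2 bytes ≈ 16 GB for the float16 weights alone, and
        # roughly twice that in float32, so smaller machines may swap heavily.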
        if available_memory > 16e9:  # 16GB+ RAM
            print("Loading model in float16...")
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                torch_dtype=torch.float16,
                device_map="cpu",
                use_cache=True,
                low_cpu_mem_usage=True
            )
        else:
            print("Loading model in float32 with low memory usage...")
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                device_map="cpu",
                use_cache=True,
                low_cpu_mem_usage=True,
                torch_dtype=torch.float32
            )
        
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

# Load model (this will take some time on first run)
print("Loading model... This may take a few minutes on CPU.")
model = load_model()

prompt_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

DDL statements:

CREATE TABLE expenses (
  id INTEGER PRIMARY KEY, -- Unique ID for each expense
  date DATE NOT NULL, -- Date when the expense occurred
  amount DECIMAL(10,2) NOT NULL, -- Amount spent
  category VARCHAR(50) NOT NULL, -- Category of expense (food, transport, utilities, etc.)
  description TEXT, -- Optional description of the expense
  payment_method VARCHAR(20), -- How the payment was made (cash, credit_card, debit_card, bank_transfer)
  user_id INTEGER -- ID of the user who made the expense
);

CREATE TABLE categories (
  id INTEGER PRIMARY KEY, -- Unique ID for each category
  name VARCHAR(50) UNIQUE NOT NULL, -- Category name (food, transport, utilities, entertainment, etc.)
  description TEXT -- Optional description of the category
);

CREATE TABLE users (
  id INTEGER PRIMARY KEY, -- Unique ID for each user
  username VARCHAR(50) UNIQUE NOT NULL, -- Username
  email VARCHAR(100) UNIQUE NOT NULL, -- Email address
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP -- When the user account was created
);

CREATE TABLE budgets (
  id INTEGER PRIMARY KEY, -- Unique ID for each budget
  user_id INTEGER, -- ID of the user who set the budget
  category VARCHAR(50), -- Category for which budget is set
  amount DECIMAL(10,2) NOT NULL, -- Budget amount
  period VARCHAR(20) DEFAULT 'monthly', -- Budget period (daily, weekly, monthly, yearly)
  start_date DATE, -- Budget start date
  end_date DATE -- Budget end date
);

-- expenses.user_id can be joined with users.id
-- expenses.category can be joined with categories.name
-- budgets.user_id can be joined with users.id
-- budgets.category can be joined with categories.name<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""

def generate_query(question):
    if model is None:
        return "Error: Model not loaded properly"
    
    try:
        updated_prompt = prompt_template.format(question=question)
        inputs = tokenizer(updated_prompt, return_tensors="pt")
        
        # Generate on CPU
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                max_new_tokens=400,
                do_sample=False,  # greedy decoding for deterministic SQL output
                num_beams=1,
                # temperature/top_p are omitted: they are ignored (and warned
                # about) by transformers when do_sample=False
            )
        
        outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
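        # The decoded text contains the prompt followed by the model's
        # continuation, so the query is cut out of the combined string below.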
        
        # Extract SQL from output
        if "```sql" in outputs[0]:
            sql_part = outputs[0].split("```sql")[1].split("```")[0].strip()
        else:
            # Fallback extraction
            sql_part = outputs[0].split("The following SQL query best answers the question")[1].strip()
            if sql_part.startswith("`"):
                sql_part = sql_part[1:]
            if "```" in sql_part:
                sql_part = sql_part.split("```")[0].strip()
        
        # Clean up the SQL
        if sql_part.endswith(";"):
            sql_part = sql_part[:-1]
        
        # Format the SQL
        formatted_sql = sqlparse.format(sql_part, reindent=True, keyword_case='upper')
        return formatted_sql
        
    except Exception as e:
        return f"Error generating query: {str(e)}"

def gradio_interface(question):
    if not question.strip():
        return "Please enter a question."
    
    return generate_query(question)

# Create Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(
        label="Question", 
        placeholder="Enter your question (e.g., 'Show me all expenses for food category')",
        lines=3
    ),
    outputs=gr.Code(label="Generated SQL Query", language="sql"),
    title="SQL Query Generator",
    description="Generate SQL queries from natural language questions about expense tracking database.",
    examples=[
        ["Show me all expenses for food category"],
        ["What's the total amount spent on transport this month?"],
        ["Insert a new expense of 50 dollars for groceries on 2024-01-15"],
        ["Find users who spent more than 1000 dollars total"],
        ["Show me the budget vs actual spending for each category"]
    ],
    cache_examples=False
)

if __name__ == "__main__":
    iface.launch()