Roxanne-WANG committed on
Commit 13ee483 · 1 Parent(s): a545146

organize text2sql

Files changed (2)
  1. data/history/history.sqlite +0 -0
  2. text2sql.py +221 -147
data/history/history.sqlite CHANGED
Binary files a/data/history/history.sqlite and b/data/history/history.sqlite differ
 
text2sql.py CHANGED
@@ -1,180 +1,254 @@
  import os
  import json
- import torch
- import copy
  import re
- import sqlparse
  import sqlite3
-
  from tqdm import tqdm
- from utils.db_utils import get_db_schema
  from transformers import AutoModelForCausalLM, AutoTokenizer
  from whoosh import index
- from whoosh.index import create_in
- from whoosh.fields import Schema, TEXT
- from whoosh.qparser import QueryParser
- from utils.db_utils import check_sql_executability, get_matched_contents, get_db_schema_sequence, get_matched_content_sequence
  from schema_item_filter import SchemaItemClassifierInference, filter_schema

- def remove_similar_comments(names, comments):
-     '''
-     Remove table (or column) comments that have a high degree of similarity with their names
-     '''
-     new_comments = []
-     for name, comment in zip(names, comments):
-         if name.replace("_", "").replace(" ", "") == comment.replace("_", "").replace(" ", ""):
-             new_comments.append("")
-         else:
-             new_comments.append(comment)
-
-     return new_comments
-
- def load_db_comments(table_json_path):
-     additional_db_info = json.load(open(table_json_path))
-     db_comments = dict()
-     for db_info in additional_db_info:
-         comment_dict = dict()
-
-         column_names = [column_name.lower() for _, column_name in db_info["column_names_original"]]
-         table_idx_of_each_column = [t_idx for t_idx, _ in db_info["column_names_original"]]
-         column_comments = [column_comment.lower() for _, column_comment in db_info["column_names"]]
-
-         assert len(column_names) == len(column_comments)
-         column_comments = remove_similar_comments(column_names, column_comments)
-
-         table_names = [table_name.lower() for table_name in db_info["table_names_original"]]
-         table_comments = [table_comment.lower() for table_comment in db_info["table_names"]]
-
-         assert len(table_names) == len(table_comments)
-         table_comments = remove_similar_comments(table_names, table_comments)
-
-         for table_idx, (table_name, table_comment) in enumerate(zip(table_names, table_comments)):
-             comment_dict[table_name] = {
-                 "table_comment": table_comment,
-                 "column_comments": dict()
-             }
-             for t_idx, column_name, column_comment in zip(table_idx_of_each_column, column_names, column_comments):
-                 if t_idx == table_idx:
-                     comment_dict[table_name]["column_comments"][column_name] = column_comment
-
-         db_comments[db_info["db_id"]] = comment_dict
-
-     return db_comments
-
- def get_db_id2schema(db_path, tables_json):
-     db_comments = load_db_comments(tables_json)
-     db_id2schema = dict()
-
-     for db_id in tqdm(os.listdir(db_path)):
-         db_id2schema[db_id] = get_db_schema(os.path.join(db_path, db_id, db_id + ".sqlite"), db_comments, db_id)
-
-     return db_id2schema
-
- def get_db_id2ddl(db_path):
-     db_ids = os.listdir(db_path)
-     db_id2ddl = dict()
-
-     for db_id in db_ids:
-         conn = sqlite3.connect(os.path.join(db_path, db_id, db_id + ".sqlite"))
-         cursor = conn.cursor()
-         cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
-         tables = cursor.fetchall()
-         ddl = []
-
-         for table in tables:
-             table_name = table[0]
-             table_ddl = table[1]
-             table_ddl.replace("\t", " ")
-             while "  " in table_ddl:
-                 table_ddl = table_ddl.replace("  ", " ")
-
-             table_ddl = re.sub(r'--.*', '', table_ddl)
-             table_ddl = sqlparse.format(table_ddl, keyword_case = "upper", identifier_case = "lower", reindent_aligned = True)
-             table_ddl = table_ddl.replace(", ", ",\n ")
-
-             if table_ddl.endswith(";"):
-                 table_ddl = table_ddl[:-1]
-             table_ddl = table_ddl[:-1] + "\n);"
-             table_ddl = re.sub(r"(CREATE TABLE.*?)\(", r"\1(\n ", table_ddl)
-
-             ddl.append(table_ddl)
-         db_id2ddl[db_id] = "\n\n".join(ddl)
-
-     return db_id2ddl
-
- class ChatBot():
-     def __init__(self) -> None:
-         os.environ["CUDA_VISIBLE_DEVICES"] = "0"
-         model_name = "seeklhy/codes-1b"
          self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-         self.model = AutoModelForCausalLM.from_pretrained(model_name, device_map = "auto", torch_dtype = torch.float16)
          self.max_length = 4096
          self.max_new_tokens = 256
          self.max_prefix_length = self.max_length - self.max_new_tokens

-         # Directly loading the model from Hugging Face
-         self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")
-         self.db_id2content_searcher = dict()
-         for db_id in os.listdir("db_contents_index"):
-             index_dir = os.path.join("db_contents_index", db_id)
-
-             # Open existing Whoosh index directory
-             if index.exists_in(index_dir):
-                 ix = index.open_dir(index_dir)
-                 # keep a searcher around for querying
-                 self.db_id2content_searcher[db_id] = ix.searcher()
              else:
-                 raise ValueError(f"No Whoosh index found for '{db_id}' at '{index_dir}'")

          self.db_ids = sorted(os.listdir("databases"))
-         self.db_id2schema = get_db_id2schema("databases", "data/tables.json")
-         self.db_id2ddl = get_db_id2ddl("databases")
-
-     def get_response(self, question, db_id):
          data = {
              "text": question,
-             "schema": copy.deepcopy(self.db_id2schema[db_id]),
-             "matched_contents": get_matched_contents(question, self.db_id2content_searcher[db_id])
          }
-         data = filter_schema(data, self.sic, 6, 10)
          data["schema_sequence"] = get_db_schema_sequence(data["schema"])
          data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])
-
-         prefix_seq = data["schema_sequence"] + "\n" + data["content_sequence"] + "\n" + data["text"] + "\n"
-         print(prefix_seq)
-
-         input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix_seq, truncation = False)["input_ids"]
          if len(input_ids) > self.max_prefix_length:
-             print("the current input sequence exceeds the max_tokens, we will truncate it.")
-             input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length-1):]
          attention_mask = [1] * len(input_ids)
-
          inputs = {
-             "input_ids": torch.tensor([input_ids], dtype = torch.int64).to(self.model.device),
-             "attention_mask": torch.tensor([attention_mask], dtype = torch.int64).to(self.model.device)
          }
-         input_length = inputs["input_ids"].shape[1]

          with torch.no_grad():
-             generate_ids = self.model.generate(
                  **inputs,
-                 max_new_tokens = self.max_new_tokens,
-                 num_beams = 4,
-                 num_return_sequences = 4
              )

-         generated_sqls = self.tokenizer.batch_decode(generate_ids[:, input_length:], skip_special_tokens = True, clean_up_tokenization_spaces = False)
-         final_generated_sql = None
-         for generated_sql in generated_sqls:
-             execution_error = check_sql_executability(generated_sql, os.path.join("databases", db_id, db_id + ".sqlite"))
-             if execution_error is None:
-                 final_generated_sql = generated_sql
                  break

-         if final_generated_sql is None:
-             if generated_sqls[0].strip() != "":
-                 final_generated_sql = generated_sqls[0].strip()
-             else:
-                 final_generated_sql = "Sorry, I can not generate a suitable SQL query for your question."
-
-         return final_generated_sql.replace("\n", " ")
+ # Attribution: Original code by Ruoxin Wang
+ # Repository: <your-repo-url>
+
+ """
+ Module: refactored_chatbot
+ This module provides utilities for loading database schemas, extracting DDL,
+ indexing content, and a ChatBot class to generate SQL queries from natural language.
+ """
  import os
  import json
  import re
  import sqlite3
+ import copy
  from tqdm import tqdm
+
+ import torch
  from transformers import AutoModelForCausalLM, AutoTokenizer
  from whoosh import index
+
+ from utils.db_utils import (
+     get_db_schema,
+     check_sql_executability,
+     get_matched_contents,
+     get_db_schema_sequence,
+     get_matched_content_sequence
+ )
  from schema_item_filter import SchemaItemClassifierInference, filter_schema

+
+ class DatabaseUtils:
+     """
+     Utilities for loading database comments, schemas, and DDL statements.
+     """
+
+     @staticmethod
+     def _remove_similar_comments(names, comments):
+         """
+         Remove comments identical to table/column names (ignoring underscores and spaces).
+         """
+         filtered = []
+         for name, comment in zip(names, comments):
+             normalized_name = name.replace("_", "").replace(" ", "").lower()
+             normalized_comment = comment.replace("_", "").replace(" ", "").lower()
+             filtered.append("") if normalized_name == normalized_comment else filtered.append(comment)
+         return filtered
+
+     @staticmethod
+     def load_db_comments(table_json_path):
+         """
+         Load additional comments for tables and columns from a JSON file.
+
+         Args:
+             table_json_path (str): Path to JSON file containing table and column comments.
+
+         Returns:
+             dict: Mapping from database ID to comments structure.
+         """
+         additional_info = json.load(open(table_json_path))
+         db_comments = {}
+
+         for db_info in additional_info:
+             db_id = db_info["db_id"]
+             comment_dict = {}
+
+             # Process column comments
+             original_cols = db_info["column_names_original"]
+             col_names = [col.lower() for _, col in original_cols]
+             col_comments = [c.lower() for _, c in db_info["column_names"]]
+             col_comments = DatabaseUtils._remove_similar_comments(col_names, col_comments)
+             col_table_idxs = [t_idx for t_idx, _ in original_cols]
+
+             # Process table comments
+             original_tables = db_info["table_names_original"]
+             tbl_names = [tbl.lower() for tbl in original_tables]
+             tbl_comments = [c.lower() for c in db_info["table_names"]]
+             tbl_comments = DatabaseUtils._remove_similar_comments(tbl_names, tbl_comments)
+
+             for idx, name in enumerate(tbl_names):
+                 comment_dict[name] = {
+                     "table_comment": tbl_comments[idx],
+                     "column_comments": {}
+                 }
+                 # Associate columns
+                 for t_idx, col_name, col_comment in zip(col_table_idxs, col_names, col_comments):
+                     if t_idx == idx:
+                         comment_dict[name]["column_comments"][col_name] = col_comment
+
+             db_comments[db_id] = comment_dict
+
+         return db_comments
+
+     @staticmethod
+     def get_db_schemas(db_path, tables_json):
+         """
+         Build a mapping from database ID to its schema representation.
+
+         Args:
+             db_path (str): Directory containing database subdirectories.
+             tables_json (str): Path to JSON with table comments.
+
+         Returns:
+             dict: Mapping from db_id to schema object.
+         """
+         comments = DatabaseUtils.load_db_comments(tables_json)
+         schemas = {}
+         for db_id in tqdm(os.listdir(db_path), desc="Loading schemas"):
+             sqlite_path = os.path.join(db_path, db_id, f"{db_id}.sqlite")
+             schemas[db_id] = get_db_schema(sqlite_path, comments, db_id)
+         return schemas
+
+     @staticmethod
+     def get_db_ddls(db_path):
+         """
+         Extract formatted DDL statements for all tables in each database.
+
+         Args:
+             db_path (str): Directory containing database subdirectories.
+
+         Returns:
+             dict: Mapping from db_id to its DDL string.
+         """
+         ddls = {}
+         for db_id in os.listdir(db_path):
+             conn = sqlite3.connect(os.path.join(db_path, db_id, f"{db_id}.sqlite"))
+             cursor = conn.cursor()
+             cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
+             ddl_statements = []
+
+             for name, raw_sql in cursor.fetchall():
+                 sql = raw_sql or ""
+                 sql = re.sub(r'--.*', '', sql).replace("\t", " ")
+                 sql = re.sub(r" +", " ", sql)
+                 formatted = sqlparse.format(
+                     sql,
+                     keyword_case="upper",
+                     identifier_case="lower",
+                     reindent_aligned=True
+                 )
+                 # Adjust spacing for readability
+                 formatted = formatted.replace(", ", ",\n ")
+                 if formatted.rstrip().endswith(";"):
+                     formatted = formatted.rstrip()[:-1] + "\n);"
+                 formatted = re.sub(r"(CREATE TABLE.*?)\(", r"\1(\n ", formatted)
+                 ddl_statements.append(formatted)
+
+             ddls[db_id] = "\n\n".join(ddl_statements)
+         return ddls
+
+
+ class ChatBot:
+     """
+     ChatBot for generating and executing SQL queries using a causal language model.
+     """
+
+     def __init__(self, model_name: str = "seeklhy/codes-1b", device: str = "cuda:0") -> None:
+         """
+         Initialize the ChatBot with model and tokenizer.
+
+         Args:
+             model_name (str): HuggingFace model identifier.
+             device (str): CUDA device string or 'cpu'.
+         """
+         os.environ["CUDA_VISIBLE_DEVICES"] = device
          self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             device_map="auto",
+             torch_dtype=torch.float16
+         )
          self.max_length = 4096
          self.max_new_tokens = 256
          self.max_prefix_length = self.max_length - self.max_new_tokens

+         # Schema item classifier
+         self.schema_classifier = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")
+
+         # Initialize content searchers
+         self.content_searchers = {}
+         index_dir = "db_contents_index"
+         for db_id in os.listdir(index_dir):
+             path = os.path.join(index_dir, db_id)
+             if index.exists_in(path):
+                 self.content_searchers[db_id] = index.open_dir(path).searcher()
              else:
+                 raise FileNotFoundError(f"Whoosh index not found for '{db_id}' at '{path}'")

+         # Load schemas and DDLs
          self.db_ids = sorted(os.listdir("databases"))
+         self.schemas = DatabaseUtils.get_db_schemas("databases", "data/tables.json")
+         self.ddls = DatabaseUtils.get_db_ddls("databases")
+
+     def get_response(self, question: str, db_id: str) -> str:
+         """
+         Generate an executable SQL query for a natural language question.
+
+         Args:
+             question (str): User question in natural language.
+             db_id (str): Identifier of the target database.
+
+         Returns:
+             str: Executable SQL query or an error message.
+         """
+         # Prepare data
+         schema = copy.deepcopy(self.schemas[db_id])
+         contents = get_matched_contents(question, self.content_searchers[db_id])
          data = {
              "text": question,
+             "schema": schema,
+             "matched_contents": contents
          }
+         data = filter_schema(data, self.schema_classifier, top_k=6, top_m=10)
          data["schema_sequence"] = get_db_schema_sequence(data["schema"])
          data["content_sequence"] = get_matched_content_sequence(data["matched_contents"])
+
+         prefix = (
+             f"{data['schema_sequence']}\n"
+             f"{data['content_sequence']}\n"
+             f"{question}\n"
+         )
+
+         # Tokenize and ensure length limits
+         input_ids = [self.tokenizer.bos_token_id] + self.tokenizer(prefix)["input_ids"]
          if len(input_ids) > self.max_prefix_length:
+             input_ids = [self.tokenizer.bos_token_id] + input_ids[-(self.max_prefix_length - 1):]
          attention_mask = [1] * len(input_ids)
+
          inputs = {
+             "input_ids": torch.tensor([input_ids], dtype=torch.int64).to(self.model.device),
+             "attention_mask": torch.tensor([attention_mask], dtype=torch.int64).to(self.model.device)
          }

          with torch.no_grad():
+             outputs = self.model.generate(
                  **inputs,
+                 max_new_tokens=self.max_new_tokens,
+                 num_beams=4,
+                 num_return_sequences=4
              )

+         # Decode and choose executable SQL
+         decoded = self.tokenizer.batch_decode(
+             outputs[:, inputs['input_ids'].shape[1]:],
+             skip_special_tokens=True,
+             clean_up_tokenization_spaces=False
+         )
+         final_sql = None
+         for sql in decoded:
+             if check_sql_executability(sql, os.path.join("databases", db_id, f"{db_id}.sqlite")) is None:
+                 final_sql = sql.strip()
                  break
+         if not final_sql:
+             final_sql = decoded[0].strip() or "Sorry, I cannot generate a suitable SQL query."

+         return final_sql
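
For reference, a minimal usage sketch of the refactored class follows. It is illustrative only: the question and db_id are made-up examples, and it assumes the on-disk layout the script expects (databases/<db_id>/<db_id>.sqlite, pre-built Whoosh indexes under db_contents_index/<db_id>/, and data/tables.json).

# Minimal usage sketch; "concert_singer" is a hypothetical db_id.
from text2sql import ChatBot

bot = ChatBot()  # loads seeklhy/codes-1b, the schema item classifier, schemas, and DDLs
sql = bot.get_response("How many singers do we have?", db_id="concert_singer")
print(sql)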
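
Since __init__ only opens existing indexes and raises FileNotFoundError otherwise, each db_contents_index/<db_id>/ directory has to be built ahead of time. A sketch of creating one such index with Whoosh is shown below; the single stored content field is an assumption, and the real schema must match whatever get_matched_contents queries.

# Sketch: pre-building a Whoosh index directory for one database.
# The field name "content" is an assumption, not taken from this repo.
import os
from whoosh import index
from whoosh.fields import Schema, TEXT

index_dir = os.path.join("db_contents_index", "concert_singer")  # hypothetical db_id
os.makedirs(index_dir, exist_ok=True)

ix = index.create_in(index_dir, Schema(content=TEXT(stored=True)))
writer = ix.writer()
writer.add_document(content="some cell value from the database")  # one document per indexed value
writer.commit()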