MHD011 commited on
Commit
2b295d5
·
verified ·
1 Parent(s): eeb21cb

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +97 -15
  2. requirements.txt +5 -7
app.py CHANGED
@@ -1,9 +1,12 @@
1
  import os
 
2
  import logging
3
- from flask import Flask, request, jsonify
4
  from flask_cors import CORS
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  import torch
 
 
7
 
8
  app = Flask(__name__)
9
  CORS(app)
@@ -14,6 +17,7 @@ logger = logging.getLogger(__name__)
14
 
15
  # --- إعداد النموذج ---
16
  MODEL_NAME = "tscholak/cxmefzzi" # نموذج Text-to-SQL بديل
 
17
 
18
  tokenizer = None
19
  model = None
@@ -125,26 +129,63 @@ CREATE TABLE public.flag (
125
  );
126
  """.strip()
127
 
128
- @app.route('/generate-sql', methods=['POST'])
129
- def generate_sql():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  if tokenizer is None or model is None:
131
  return jsonify({"error": "النموذج غير محمل، يرجى المحاولة لاحقاً"}), 503
132
 
133
  try:
134
- body = request.get_json()
135
- user_text = body.get("text", "").strip()
136
- cam_mac = body.get("cam_mac", "").strip()
137
-
138
- if not user_text or not cam_mac:
139
  return jsonify({"error": "يرجى إرسال 'text' و 'cam_mac'"}), 400
140
 
 
 
 
 
 
 
 
 
141
  prompt = f"""
142
  ### Postgres SQL table definitions
143
  {DB_SCHEMA}
144
 
145
- ### User question: {user_text}
 
 
 
 
 
 
 
146
 
147
- ### SQL query to answer the question filtered by cam_mac = '{cam_mac}':
148
  SELECT
149
  """.strip()
150
 
@@ -153,12 +194,53 @@ def generate_sql():
153
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
154
 
155
  # تنظيف الناتج
156
- if not sql.lower().startswith("select"):
157
- sql = "SELECT " + sql
 
 
 
 
 
158
  if not sql.endswith(";"):
159
  sql += ";"
160
 
161
- return jsonify({"sql": sql})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
 
163
  except Exception as e:
164
  logger.error(f"خطأ في التوليد: {str(e)}")
@@ -168,8 +250,8 @@ def generate_sql():
168
  def home():
169
  return """
170
  <h1>Text2SQL API</h1>
171
- <p>Send a POST request to <code>/generate-sql</code> with JSON: {"text": "سؤالك", "cam_mac": "عنوان MAC"}</p>
172
  """
173
 
174
  if __name__ == '__main__':
175
- app.run()
 
1
  import os
2
+ import re
3
  import logging
4
+ from flask import Flask, request, jsonify, Response
5
  from flask_cors import CORS
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
  import torch
8
+ import psycopg2
9
+ import json
10
 
11
  app = Flask(__name__)
12
  CORS(app)
 
17
 
18
  # --- إعداد النموذج ---
19
  MODEL_NAME = "tscholak/cxmefzzi" # نموذج Text-to-SQL بديل
20
+ SUPABASE_DB_URL = "postgresql://postgres.mougnkvoyyhcuxeeqvmh:Xf5E0DhUvKEHEAqq@aws-0-eu-central-1.pooler.supabase.com:6543/postgres"
21
 
22
  tokenizer = None
23
  model = None
 
129
  );
130
  """.strip()
131
 
132
+ # --- الاتصال بقاعدة البيانات ---
133
+ def get_db_connection():
134
+ try:
135
+ return psycopg2.connect(SUPABASE_DB_URL)
136
+ except Exception as err:
137
+ logger.error(f"Database connection error: {err}")
138
+ return None
139
+
140
+ # --- التحقق من صحة cam_mac ---
141
+ def validate_cam_mac(cam_mac):
142
+ conn = get_db_connection()
143
+ if not conn:
144
+ return False
145
+
146
+ try:
147
+ cursor = conn.cursor()
148
+ cursor.execute("SELECT 1 FROM profiles WHERE cam_mac = %s;", (cam_mac,))
149
+ return cursor.fetchone() is not None
150
+ except Exception as e:
151
+ logger.error(f"Validation error: {e}")
152
+ return False
153
+ finally:
154
+ if conn:
155
+ conn.close()
156
+
157
+ @app.route('/api/query', methods=['POST'])
158
+ def handle_query():
159
  if tokenizer is None or model is None:
160
  return jsonify({"error": "النموذج غير محمل، يرجى المحاولة لاحقاً"}), 503
161
 
162
  try:
163
+ data = request.get_json()
164
+ if not data or 'text' not in data or 'cam_mac' not in data:
 
 
 
165
  return jsonify({"error": "يرجى إرسال 'text' و 'cam_mac'"}), 400
166
 
167
+ natural_query = data['text']
168
+ cam_mac = data['cam_mac']
169
+ logger.info(f"استعلام من {cam_mac}: {natural_query}")
170
+
171
+ # التحقق من صحة cam_mac
172
+ if not validate_cam_mac(cam_mac):
173
+ return jsonify({"error": "عنوان MAC غير صالح"}), 403
174
+
175
  prompt = f"""
176
  ### Postgres SQL table definitions
177
  {DB_SCHEMA}
178
 
179
+ ### Rules:
180
+ - Always filter by cam_mac = '{cam_mac}'
181
+ - Use only SELECT statements
182
+ - Use proper JOINs
183
+ - Use table aliases when helpful
184
+ - The output must contain only the SQL query
185
+
186
+ ### User question: {natural_query}
187
 
188
+ ### SQL query:
189
  SELECT
190
  """.strip()
191
 
 
194
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
195
 
196
  # تنظيف الناتج
197
+ sql = re.sub(r"^```sql\s*", "", sql, flags=re.IGNORECASE)
198
+ sql = re.sub(r"\s*```$", "", sql)
199
+ sql = re.sub(r"^SQL:\s*", "", sql, flags=re.IGNORECASE)
200
+
201
+ if not sql.upper().startswith("SELECT"):
202
+ sql = "SELECT " + sql.split("SELECT")[-1] if "SELECT" in sql else f"SELECT * FROM ({sql}) AS subquery"
203
+
204
  if not sql.endswith(";"):
205
  sql += ";"
206
 
207
+ logger.info(f"استعلام SQL المولد: {sql}")
208
+
209
+ if not sql.upper().strip().startswith("SELECT"):
210
+ return jsonify({"error": "يُسمح فقط باستعلامات SELECT"}), 403
211
+
212
+ conn = get_db_connection()
213
+ if not conn:
214
+ return jsonify({"error": "فشل الاتصال بقاعدة البيانات"}), 500
215
+
216
+ cursor = None
217
+ try:
218
+ cursor = conn.cursor()
219
+ cursor.execute(sql)
220
+ columns = [desc[0] for desc in cursor.description]
221
+ rows = cursor.fetchall()
222
+ data = [dict(zip(columns, row)) for row in rows]
223
+
224
+ response_data = {
225
+ "data": data,
226
+ }
227
+
228
+ response_json = json.dumps(response_data, ensure_ascii=False)
229
+
230
+ return Response(
231
+ response_json,
232
+ status=200,
233
+ mimetype='application/json; charset=utf-8'
234
+ )
235
+
236
+ except Exception as e:
237
+ logger.error(f"خطأ في تنفيذ SQL: {e}")
238
+ return jsonify({"error": str(e), "generated_sql": sql}), 500
239
+ finally:
240
+ if cursor:
241
+ cursor.close()
242
+ if conn:
243
+ conn.close()
244
 
245
  except Exception as e:
246
  logger.error(f"خطأ في التوليد: {str(e)}")
 
250
  def home():
251
  return """
252
  <h1>Text2SQL API</h1>
253
+ <p>Use <code>/api/query</code> with POST {"text": "your question", "cam_mac": "device_mac_address"}.</p>
254
  """
255
 
256
  if __name__ == '__main__':
257
+ app.run(host='0.0.0.0', port=7860)
requirements.txt CHANGED
@@ -1,7 +1,5 @@
1
- flask>=2.0.0
2
- flask-cors>=3.0.0
3
- torch>=2.0.0
4
- transformers>=4.30.0
5
- accelerate
6
- huggingface-hub>=0.15.0
7
- gunicorn>=20.1.0
 
1
+ flask==2.3.2
2
+ flask-cors==3.0.10
3
+ torch==2.0.1
4
+ transformers==4.30.2
5
+ psycopg2-binary==2.9.7