MHD011 commited on
Commit
7d15cae
·
verified ·
1 Parent(s): ca1d327

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +100 -37
app.py CHANGED
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
  import torch
8
  import psycopg2
9
  from datetime import datetime
 
10
 
11
  app = Flask(__name__)
12
  CORS(app)
@@ -19,16 +20,33 @@ logger = logging.getLogger(__name__)
19
  MODEL_NAME = "tscholak/3vnuv1vf" # نموذج متخصص لـ PostgreSQL
20
  SUPABASE_DB_URL = os.getenv('SUPABASE_DB_URL')
21
 
22
- # التهيئة المسبقة للنموذج
23
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
24
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
 
 
25
 
26
- # نقل النموذج إلى الجهاز المناسب
27
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
- model.to(device)
29
- model.eval()
 
 
 
 
 
 
 
 
 
30
 
31
- # سكيما قاعدة البيانات
32
  DB_SCHEMA = """
33
  CREATE TABLE public.biodata (
34
  id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
@@ -110,30 +128,61 @@ CREATE TABLE public.user_place (
110
  );
111
  """.strip()
112
 
113
- def generate_sql(prompt: str, cam_mac: str) -> str:
114
  """
115
- توليد استعلام SQL من النص باستخدام النموذج
116
  """
117
- inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
 
 
 
 
118
 
119
- with torch.no_grad():
120
- outputs = model.generate(
121
- **inputs,
122
- max_length=256,
123
- num_beams=4,
124
- early_stopping=True,
125
- temperature=0.7
126
- )
127
 
128
- sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
129
- return clean_sql(sql, cam_mac)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  @app.route('/api/query', methods=['POST'])
132
  def handle_query():
 
 
 
133
  try:
134
  data = request.get_json()
135
 
136
- # التحقق من البيانات المدخلة
137
  if not all(k in data for k in ['text', 'cam_mac']):
138
  return jsonify({"error": "المعطيات ناقصة"}), 400
139
 
@@ -146,31 +195,45 @@ def handle_query():
146
  1. يجب تضمين شرط WHERE: cam_mac = '{data['cam_mac']}'
147
  2. مسموح فقط باستخدام SELECT
148
  3. الجداول المتاحة: profiles, data, place
149
-
150
- المثال:
151
- السؤال: "ما عدد زياراتي؟"
152
- SQL: SELECT COUNT(*) FROM data WHERE cam_mac = '{data['cam_mac']}';
153
  """
154
  # توليد الاستعلام
155
  sql = generate_sql(prompt, data['cam_mac'])
156
 
157
- # التحقق من الأمان
158
- if not is_safe_sql(sql):
159
- sql = f"SELECT * FROM data WHERE cam_mac = '{data['cam_mac']}' LIMIT 10;"
160
-
161
- # تنفيذ الاستعلام
162
- with psycopg2.connect(SUPABASE_DB_URL) as conn:
163
  with conn.cursor() as cursor:
164
  cursor.execute(sql)
165
- columns = [desc[0] for desc in cursor.description]
166
- rows = cursor.fetchall()
 
 
 
 
 
 
167
 
168
  return jsonify({
169
- "data": [dict(zip(columns, row)) for row in rows],
170
  "sql": sql,
171
  "timestamp": datetime.now().isoformat()
172
  })
 
 
 
 
 
 
173
 
174
  except Exception as e:
175
- logger.error(f"Error: {str(e)}")
176
- return jsonify({"error": "حدث خطأ في المعالجة"}), 500
 
 
 
 
 
 
 
 
 
7
  import torch
8
  import psycopg2
9
  from datetime import datetime
10
+ from psycopg2 import pool
11
 
12
  app = Flask(__name__)
13
  CORS(app)
 
20
  MODEL_NAME = "tscholak/3vnuv1vf" # نموذج متخصص لـ PostgreSQL
21
  SUPABASE_DB_URL = os.getenv('SUPABASE_DB_URL')
22
 
23
+ # تهيئة connection pool
24
+ connection_pool = None
25
+ try:
26
+ connection_pool = psycopg2.pool.SimpleConnectionPool(
27
+ minconn=1,
28
+ maxconn=5,
29
+ dsn=SUPABASE_DB_URL
30
+ )
31
+ logger.info("تم إنشاء connection pool بنجاح")
32
+ except Exception as e:
33
+ logger.error(f"خطأ في إنشاء connection pool: {str(e)}")
34
 
35
+ # تحميل النموذج مرة واحدة عند بدء التشغيل
36
+ try:
37
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
38
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
39
+
40
+ # استخدام GPU إذا كان متاحًا، وإلا استخدام CPU
41
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
42
+ model.to(device)
43
+ model.eval()
44
+ logger.info("تم تحميل النموذج بنجاح على الجهاز: %s", device)
45
+ except Exception as e:
46
+ logger.error(f"خطأ في تحميل النموذج: {str(e)}")
47
+ raise
48
 
49
+ # سكيما قاعدة البيانات (مختصرة لتحسين الأداء)
50
  DB_SCHEMA = """
51
  CREATE TABLE public.biodata (
52
  id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
 
128
  );
129
  """.strip()
130
 
131
+ def clean_sql(sql: str, cam_mac: str) -> str:
132
  """
133
+ تنظيف استعلام SQL وإضافة شروط الأمان
134
  """
135
+ # إزالة أي أوامر غير مسموح بها
136
+ forbidden_keywords = ['insert', 'update', 'delete', 'drop', 'alter', 'create', 'truncate']
137
+ for keyword in forbidden_keywords:
138
+ if keyword in sql.lower():
139
+ raise ValueError(f"استعلام غير مسموح به يحتوي على {keyword}")
140
 
141
+ # التأكد من وجود شرط cam_mac
142
+ if 'where' in sql.lower():
143
+ sql = re.sub(r'where\s+', f"WHERE cam_mac = '{cam_mac}' AND ", sql, flags=re.IGNORECASE)
144
+ else:
145
+ if ';' in sql:
146
+ sql = sql.replace(';', f" WHERE cam_mac = '{cam_mac}';")
147
+ else:
148
+ sql += f" WHERE cam_mac = '{cam_mac}'"
149
 
150
+ # التأكد من أن الاستعلام يبدأ بـ SELECT فقط
151
+ if not sql.strip().lower().startswith('select'):
152
+ raise ValueError("يسمح فقط باستعلامات SELECT")
153
+
154
+ return sql
155
+
156
+ def generate_sql(prompt: str, cam_mac: str) -> str:
157
+ """
158
+ توليد استعلام SQL من النص باستخدام النموذج
159
+ """
160
+ try:
161
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
162
+
163
+ with torch.no_grad():
164
+ outputs = model.generate(
165
+ **inputs,
166
+ max_length=256,
167
+ num_beams=4,
168
+ early_stopping=True,
169
+ temperature=0.7
170
+ )
171
+
172
+ sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
173
+ return clean_sql(sql, cam_mac)
174
+ except Exception as e:
175
+ logger.error(f"خطأ في توليد SQL: {str(e)}")
176
+ return f"SELECT * FROM data WHERE cam_mac = '{cam_mac}' LIMIT 10;"
177
 
178
  @app.route('/api/query', methods=['POST'])
179
  def handle_query():
180
+ if not connection_pool:
181
+ return jsonify({"error": "لا يوجد اتصال بقاعدة البيانات"}), 500
182
+
183
  try:
184
  data = request.get_json()
185
 
 
186
  if not all(k in data for k in ['text', 'cam_mac']):
187
  return jsonify({"error": "المعطيات ناقصة"}), 400
188
 
 
195
  1. يجب تضمين شرط WHERE: cam_mac = '{data['cam_mac']}'
196
  2. مسموح فقط باستخدام SELECT
197
  3. الجداول المتاحة: profiles, data, place
 
 
 
 
198
  """
199
  # توليد الاستعلام
200
  sql = generate_sql(prompt, data['cam_mac'])
201
 
202
+ # تنفيذ الاستعلام باستخدام connection pool
203
+ conn = None
204
+ try:
205
+ conn = connection_pool.getconn()
 
 
206
  with conn.cursor() as cursor:
207
  cursor.execute(sql)
208
+
209
+ # إذا كان الاستعلام لا يعيد بيانات (مثل COUNT)
210
+ if cursor.description:
211
+ columns = [desc[0] for desc in cursor.description]
212
+ rows = cursor.fetchall()
213
+ result = [dict(zip(columns, row)) for row in rows]
214
+ else:
215
+ result = {"message": "تم تنفيذ الاستعلام بنجاح"}
216
 
217
  return jsonify({
218
+ "data": result,
219
  "sql": sql,
220
  "timestamp": datetime.now().isoformat()
221
  })
222
+ except Exception as e:
223
+ logger.error(f"خطأ في قاعدة البيانات: {str(e)}")
224
+ return jsonify({"error": "حدث خطأ في معالجة الاستعلام"}), 500
225
+ finally:
226
+ if conn:
227
+ connection_pool.putconn(conn)
228
 
229
  except Exception as e:
230
+ logger.error(f"خطأ عام: {str(e)}")
231
+ return jsonify({"error": "حدث خطأ في المعالجة"}), 500
232
+
233
+ @app.route('/health', methods=['GET'])
234
+ def health_check():
235
+ return jsonify({"status": "healthy", "timestamp": datetime.now().isoformat()})
236
+
237
+ if __name__ == '__main__':
238
+ port = int(os.environ.get('PORT', 8080))
239
+ app.run(host='0.0.0.0', port=port)