Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
7 |
import torch
|
8 |
import psycopg2
|
9 |
from datetime import datetime
|
|
|
10 |
|
11 |
app = Flask(__name__)
|
12 |
CORS(app)
|
@@ -19,16 +20,33 @@ logger = logging.getLogger(__name__)
|
|
19 |
MODEL_NAME = "tscholak/3vnuv1vf" # نموذج متخصص لـ PostgreSQL
|
20 |
SUPABASE_DB_URL = os.getenv('SUPABASE_DB_URL')
|
21 |
|
22 |
-
#
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
#
|
27 |
-
|
28 |
-
|
29 |
-
model.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
-
# سكيما قاعدة البيانات
|
32 |
DB_SCHEMA = """
|
33 |
CREATE TABLE public.biodata (
|
34 |
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
|
@@ -110,30 +128,61 @@ CREATE TABLE public.user_place (
|
|
110 |
);
|
111 |
""".strip()
|
112 |
|
113 |
-
def
|
114 |
"""
|
115 |
-
|
116 |
"""
|
117 |
-
|
|
|
|
|
|
|
|
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
|
128 |
-
|
129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
@app.route('/api/query', methods=['POST'])
|
132 |
def handle_query():
|
|
|
|
|
|
|
133 |
try:
|
134 |
data = request.get_json()
|
135 |
|
136 |
-
# التحقق من البيانات المدخلة
|
137 |
if not all(k in data for k in ['text', 'cam_mac']):
|
138 |
return jsonify({"error": "المعطيات ناقصة"}), 400
|
139 |
|
@@ -146,31 +195,45 @@ def handle_query():
|
|
146 |
1. يجب تضمين شرط WHERE: cam_mac = '{data['cam_mac']}'
|
147 |
2. مسموح فقط باستخدام SELECT
|
148 |
3. الجداول المتاحة: profiles, data, place
|
149 |
-
|
150 |
-
المثال:
|
151 |
-
السؤال: "ما عدد زياراتي؟"
|
152 |
-
SQL: SELECT COUNT(*) FROM data WHERE cam_mac = '{data['cam_mac']}';
|
153 |
"""
|
154 |
# توليد الاستعلام
|
155 |
sql = generate_sql(prompt, data['cam_mac'])
|
156 |
|
157 |
-
#
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
# تنفيذ الاستعلام
|
162 |
-
with psycopg2.connect(SUPABASE_DB_URL) as conn:
|
163 |
with conn.cursor() as cursor:
|
164 |
cursor.execute(sql)
|
165 |
-
|
166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
return jsonify({
|
169 |
-
"data":
|
170 |
"sql": sql,
|
171 |
"timestamp": datetime.now().isoformat()
|
172 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
|
174 |
except Exception as e:
|
175 |
-
logger.error(f"
|
176 |
-
return jsonify({"error": "حدث خطأ في المعالجة"}), 500
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
import torch
|
8 |
import psycopg2
|
9 |
from datetime import datetime
|
10 |
+
from psycopg2 import pool
|
11 |
|
12 |
app = Flask(__name__)
|
13 |
CORS(app)
|
|
|
20 |
MODEL_NAME = "tscholak/3vnuv1vf" # نموذج متخصص لـ PostgreSQL
|
21 |
SUPABASE_DB_URL = os.getenv('SUPABASE_DB_URL')
|
22 |
|
23 |
+
# تهيئة connection pool
|
24 |
+
connection_pool = None
|
25 |
+
try:
|
26 |
+
connection_pool = psycopg2.pool.SimpleConnectionPool(
|
27 |
+
minconn=1,
|
28 |
+
maxconn=5,
|
29 |
+
dsn=SUPABASE_DB_URL
|
30 |
+
)
|
31 |
+
logger.info("تم إنشاء connection pool بنجاح")
|
32 |
+
except Exception as e:
|
33 |
+
logger.error(f"خطأ في إنشاء connection pool: {str(e)}")
|
34 |
|
35 |
+
# تحميل النموذج مرة واحدة عند بدء التشغيل
|
36 |
+
try:
|
37 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
|
38 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
|
39 |
+
|
40 |
+
# استخدام GPU إذا كان متاحًا، وإلا استخدام CPU
|
41 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
42 |
+
model.to(device)
|
43 |
+
model.eval()
|
44 |
+
logger.info("تم تحميل النموذج بنجاح على الجهاز: %s", device)
|
45 |
+
except Exception as e:
|
46 |
+
logger.error(f"خطأ في تحميل النموذج: {str(e)}")
|
47 |
+
raise
|
48 |
|
49 |
+
# سكيما قاعدة البيانات (مختصرة لتحسين الأداء)
|
50 |
DB_SCHEMA = """
|
51 |
CREATE TABLE public.biodata (
|
52 |
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
|
|
|
128 |
);
|
129 |
""".strip()
|
130 |
|
131 |
+
def clean_sql(sql: str, cam_mac: str) -> str:
|
132 |
"""
|
133 |
+
تنظيف استعلام SQL وإضافة شروط الأمان
|
134 |
"""
|
135 |
+
# إزالة أي أوامر غير مسموح بها
|
136 |
+
forbidden_keywords = ['insert', 'update', 'delete', 'drop', 'alter', 'create', 'truncate']
|
137 |
+
for keyword in forbidden_keywords:
|
138 |
+
if keyword in sql.lower():
|
139 |
+
raise ValueError(f"استعلام غير مسموح به يحتوي على {keyword}")
|
140 |
|
141 |
+
# التأكد من وجود شرط cam_mac
|
142 |
+
if 'where' in sql.lower():
|
143 |
+
sql = re.sub(r'where\s+', f"WHERE cam_mac = '{cam_mac}' AND ", sql, flags=re.IGNORECASE)
|
144 |
+
else:
|
145 |
+
if ';' in sql:
|
146 |
+
sql = sql.replace(';', f" WHERE cam_mac = '{cam_mac}';")
|
147 |
+
else:
|
148 |
+
sql += f" WHERE cam_mac = '{cam_mac}'"
|
149 |
|
150 |
+
# التأكد من أن الاستعلام يبدأ بـ SELECT فقط
|
151 |
+
if not sql.strip().lower().startswith('select'):
|
152 |
+
raise ValueError("يسمح فقط باستعلامات SELECT")
|
153 |
+
|
154 |
+
return sql
|
155 |
+
|
156 |
+
def generate_sql(prompt: str, cam_mac: str) -> str:
|
157 |
+
"""
|
158 |
+
توليد استعلام SQL من النص باستخدام النموذج
|
159 |
+
"""
|
160 |
+
try:
|
161 |
+
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
|
162 |
+
|
163 |
+
with torch.no_grad():
|
164 |
+
outputs = model.generate(
|
165 |
+
**inputs,
|
166 |
+
max_length=256,
|
167 |
+
num_beams=4,
|
168 |
+
early_stopping=True,
|
169 |
+
temperature=0.7
|
170 |
+
)
|
171 |
+
|
172 |
+
sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
173 |
+
return clean_sql(sql, cam_mac)
|
174 |
+
except Exception as e:
|
175 |
+
logger.error(f"خطأ في توليد SQL: {str(e)}")
|
176 |
+
return f"SELECT * FROM data WHERE cam_mac = '{cam_mac}' LIMIT 10;"
|
177 |
|
178 |
@app.route('/api/query', methods=['POST'])
|
179 |
def handle_query():
|
180 |
+
if not connection_pool:
|
181 |
+
return jsonify({"error": "لا يوجد اتصال بقاعدة البيانات"}), 500
|
182 |
+
|
183 |
try:
|
184 |
data = request.get_json()
|
185 |
|
|
|
186 |
if not all(k in data for k in ['text', 'cam_mac']):
|
187 |
return jsonify({"error": "المعطيات ناقصة"}), 400
|
188 |
|
|
|
195 |
1. يجب تضمين شرط WHERE: cam_mac = '{data['cam_mac']}'
|
196 |
2. مسموح فقط باستخدام SELECT
|
197 |
3. الجداول المتاحة: profiles, data, place
|
|
|
|
|
|
|
|
|
198 |
"""
|
199 |
# توليد الاستعلام
|
200 |
sql = generate_sql(prompt, data['cam_mac'])
|
201 |
|
202 |
+
# تنفيذ الاستعلام باستخدام connection pool
|
203 |
+
conn = None
|
204 |
+
try:
|
205 |
+
conn = connection_pool.getconn()
|
|
|
|
|
206 |
with conn.cursor() as cursor:
|
207 |
cursor.execute(sql)
|
208 |
+
|
209 |
+
# إذا كان الاستعلام لا يعيد بيانات (مثل COUNT)
|
210 |
+
if cursor.description:
|
211 |
+
columns = [desc[0] for desc in cursor.description]
|
212 |
+
rows = cursor.fetchall()
|
213 |
+
result = [dict(zip(columns, row)) for row in rows]
|
214 |
+
else:
|
215 |
+
result = {"message": "تم تنفيذ الاستعلام بنجاح"}
|
216 |
|
217 |
return jsonify({
|
218 |
+
"data": result,
|
219 |
"sql": sql,
|
220 |
"timestamp": datetime.now().isoformat()
|
221 |
})
|
222 |
+
except Exception as e:
|
223 |
+
logger.error(f"خطأ في قاعدة البيانات: {str(e)}")
|
224 |
+
return jsonify({"error": "حدث خطأ في معالجة الاستعلام"}), 500
|
225 |
+
finally:
|
226 |
+
if conn:
|
227 |
+
connection_pool.putconn(conn)
|
228 |
|
229 |
except Exception as e:
|
230 |
+
logger.error(f"خطأ عام: {str(e)}")
|
231 |
+
return jsonify({"error": "حدث خطأ في المعالجة"}), 500
|
232 |
+
|
233 |
+
@app.route('/health', methods=['GET'])
|
234 |
+
def health_check():
|
235 |
+
return jsonify({"status": "healthy", "timestamp": datetime.now().isoformat()})
|
236 |
+
|
237 |
+
if __name__ == '__main__':
|
238 |
+
port = int(os.environ.get('PORT', 8080))
|
239 |
+
app.run(host='0.0.0.0', port=port)
|