|
import os |
|
import re |
|
import logging |
|
from flask import Flask, request, jsonify |
|
from flask_cors import CORS |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
import torch |
|
import psycopg2 |
|
from datetime import datetime |
|
|
|
# --- Flask application -------------------------------------------------------
app = Flask(__name__)
# Cross-origin requests are enabled with flask_cors defaults (all origins).
CORS(app)

# --- Logging -----------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration -----------------------------------------------------------
# Hugging Face checkpoint used for text-to-SQL generation.
MODEL_NAME = "tscholak/3vnuv1vf"
# Postgres connection string; None when the environment variable is unset.
SUPABASE_DB_URL = os.environ.get('SUPABASE_DB_URL')

# --- Model cache -------------------------------------------------------------
os.makedirs("model_cache", exist_ok=True)
# NOTE(review): `transformers` is imported before these variables are set, so
# it presumably does not pick up TRANSFORMERS_CACHE (read at import time), and
# TORCH_CACHE does not appear to be a variable torch reads (it uses
# TORCH_HOME). The explicit cache_dir="model_cache" passed to from_pretrained()
# in initialize() is what actually selects the cache -- confirm before relying
# on these variables.
os.environ.update({"TRANSFORMERS_CACHE": "model_cache", "TORCH_CACHE": "model_cache"})

# Populated by initialize(); both stay None until loading succeeds.
tokenizer = model = None
|
|
|
def initialize():
    """Load the tokenizer and seq2seq model for MODEL_NAME into module globals.

    Both are loaded with ``cache_dir="model_cache"``. After loading, the model
    is moved to CUDA when a GPU is available, otherwise cast to float32 on the
    CPU, and finally switched to eval mode.

    Raises:
        Exception: any loading failure is logged and then re-raised.
    """
    global tokenizer, model

    try:
        logger.info("🚀 جاري تحميل النموذج...")

        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir="model_cache")

        load_options = dict(
            cache_dir="model_cache",
            device_map="auto",
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True,
        )
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, **load_options)

        if not torch.cuda.is_available():
            model.float()
            logger.info("✅ تم تحميل النموذج على CPU (باستخدام float32)")
        else:
            model.to('cuda')
            logger.info("✅ تم تحميل النموذج على GPU")

        model.eval()
    except Exception as exc:
        logger.error(f"❌ فشل في تحميل النموذج: {str(exc)}")
        raise


# Runs at import time: importing this module loads (and may download) the model.
initialize()
|
|
|
|
|
# Schema description embedded verbatim into the text-to-SQL prompt built in
# handle_query(); in this file it is only used as prompt text, never executed.
# NOTE(review): it is maintained by hand -- nothing here checks it against the
# live database. Only `profiles` and `data` have a column literally named
# `cam_mac`; the other per-device tables use `user_cam_mac`, `mac_address` or
# `user_mac_address`, which matters for the prompt's "WHERE cam_mac = ..." rule.
DB_SCHEMA = """
CREATE TABLE public.profiles (
id uuid NOT NULL,
updated_at timestamp with time zone,
username text UNIQUE CHECK (char_length(username) >= 3),
full_name text,
avatar_url text,
website text,
cam_mac text UNIQUE,
fcm_token text,
notification_enabled boolean DEFAULT true,
CONSTRAINT profiles_pkey PRIMARY KEY (id),
CONSTRAINT profiles_id_fkey FOREIGN KEY (id) REFERENCES auth.users(id)
);

CREATE TABLE public.place (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp with time zone DEFAULT (now() AT TIME ZONE 'utc'::text),
name text,
CONSTRAINT place_pkey PRIMARY KEY (id)
);

CREATE TABLE public.user_place (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp with time zone NOT NULL DEFAULT now(),
place_id bigint,
user_cam_mac text,
CONSTRAINT user_place_pkey PRIMARY KEY (id),
CONSTRAINT user_place_place_id_fkey FOREIGN KEY (place_id) REFERENCES public.place(id),
CONSTRAINT user_place_user_cam_mac_fkey FOREIGN KEY (user_cam_mac) REFERENCES public.profiles(cam_mac)
);

CREATE TABLE public.data (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp without time zone,
caption text,
image_url text,
latitude double precision DEFAULT '36.1833854'::double precision,
longitude double precision DEFAULT '37.1309255'::double precision,
user_place_id bigint,
cam_mac text,
CONSTRAINT data_pkey PRIMARY KEY (id),
CONSTRAINT data_user_place_id_fkey FOREIGN KEY (user_place_id) REFERENCES public.user_place(id)
);

CREATE TABLE public.biodata (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp with time zone NOT NULL DEFAULT now(),
mac_address text,
acceleration_x double precision,
acceleration_y double precision,
acceleration_z double precision,
gyro_x double precision,
gyro_y double precision,
gyro_z double precision,
temperature double precision,
CONSTRAINT biodata_pkey PRIMARY KEY (id),
CONSTRAINT biodata_mac_address_fkey FOREIGN KEY (mac_address) REFERENCES public.profiles(cam_mac)
);

CREATE TABLE public.notification (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp without time zone NOT NULL DEFAULT now(),
user_cam_mac text,
title text,
message text,
is_read boolean,
acceleration_x double precision,
acceleration_y double precision,
acceleration_z double precision,
gyro_x double precision,
gyro_y double precision,
gyro_z double precision,
CONSTRAINT notification_pkey PRIMARY KEY (id),
CONSTRAINT notification_user_cam_mac_fkey FOREIGN KEY (user_cam_mac) REFERENCES public.profiles(cam_mac)
);

CREATE TABLE public.flag (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
flag smallint,
user_mac_address text,
CONSTRAINT flag_pkey PRIMARY KEY (id),
CONSTRAINT flag_user_mac_address_fkey FOREIGN KEY (user_mac_address) REFERENCES public.profiles(cam_mac)
);
""".strip()
|
|
|
def get_db_connection():
    """Open a new psycopg2 connection using SUPABASE_DB_URL.

    Returns:
        A fresh connection that the caller is responsible for closing, or
        None when the URL is not configured or the connection attempt fails.
        Connection errors are logged rather than raised.
    """
    # psycopg2.connect(None) turns into an empty DSN, and libpq would then
    # silently fall back to its defaults (local socket, PG* env vars) instead
    # of failing -- so treat a missing URL as a configuration error up front.
    if not SUPABASE_DB_URL:
        logger.error("❌ متغير البيئة SUPABASE_DB_URL غير معرّف")
        return None

    try:
        conn = psycopg2.connect(SUPABASE_DB_URL)
    except Exception as err:
        logger.error(f"❌ خطأ في الاتصال بقاعدة البيانات: {err}")
        return None

    logger.info("✅ تم الاتصال بقاعدة البيانات")
    return conn
|
|
|
def validate_cam_mac(cam_mac):
    """Return True when some row of ``profiles`` has ``cam_mac`` equal to *cam_mac*.

    The value is passed as a bound query parameter. If no database connection
    can be opened, or the lookup raises, the error is logged and False is
    returned. The connection is always closed before returning.
    """
    connection = get_db_connection()
    if not connection:
        return False

    found = False
    try:
        with connection.cursor() as cur:
            cur.execute("SELECT 1 FROM profiles WHERE cam_mac = %s;", (cam_mac,))
            found = cur.fetchone() is not None
    except Exception as e:
        logger.error(f"❌ خطأ في التحقق من cam_mac: {e}")
    finally:
        connection.close()
    return found
|
|
|
# Whole-word keywords/functions (matched against the upper-cased statement)
# that make a generated query unsafe. Word boundaries keep harmless identifiers
# such as `created_at` / `updated_at` from matching CREATE / UPDATE.
_FORBIDDEN_SQL_RE = re.compile(
    r"\b(?:"
    # Data/schema modification and transaction control.
    r"INSERT|UPDATE|DELETE|CREATE|DROP|ALTER|TRUNCATE|GRANT|REVOKE"
    r"|COMMIT|ROLLBACK|COPY|MERGE"
    # SELECT ... INTO creates a new table in PostgreSQL.
    r"|INTO"
    # Server functions that can stall the server or reach files, the network
    # or other sessions.
    r"|PG_SLEEP\w*|PG_READ_\w+|PG_LS_DIR|LO_IMPORT|LO_EXPORT|DBLINK\w*"
    r"|SET_CONFIG|PG_TERMINATE_BACKEND|PG_CANCEL_BACKEND"
    r")\b"
)


def is_safe_sql(sql):
    """Return True if *sql* looks like a single, read-only SELECT statement.

    This is a conservative keyword deny-list, not a SQL parser. A statement is
    rejected when it is empty, does not start with SELECT, contains a
    semicolon anywhere except at the very end (statement chaining), or
    mentions a forbidden keyword/function as a whole word.

    Args:
        sql: Candidate SQL text; None and "" are rejected.

    Returns:
        bool: True when no rule rejects the statement.
    """
    if not sql:
        return False

    statement = sql.strip().upper()

    if not statement.startswith("SELECT"):
        return False

    # Only trailing semicolons are tolerated; any other one could start a
    # second statement after the SELECT.
    if ";" in statement.rstrip(";"):
        return False

    return _FORBIDDEN_SQL_RE.search(statement) is None
|
|
|
def clean_sql(sql):
    """Reduce raw model output to a single, whitespace-normalised statement.

    Steps:
      1. Drop everything before the first whole-word ``SELECT``
         (case-insensitive) and upper-case that keyword. Text without a
         SELECT is left unchanged at this step.
      2. Keep only the first ``;``-separated statement and terminate it
         with ``;``.
      3. Collapse every run of whitespace into a single space.

    Returns:
        str: the cleaned statement, always ending in ``;``.
    """
    # A search (rather than an anchored ``^[^S]*SELECT`` pattern) also skips
    # prefixes that themselves contain an "s", e.g. "concert_singer | select".
    match = re.search(r"\bSELECT\b", sql, flags=re.IGNORECASE)
    if match:
        sql = "SELECT" + sql[match.end():]

    first_statement = sql.split(";", 1)[0]
    return " ".join((first_statement + ";").split())
|
|
|
@app.route('/api/query', methods=['POST'])
def handle_query():
    """Translate a natural-language question to SQL, run it, and return the rows.

    Expects a JSON object body ``{"text": <question>, "cam_mac": <MAC>}``.

    Returns:
        200 ``{"data": [...], "sql": ..., "timestamp": ...}`` on success;
        400 for a missing/malformed body or a generated query rejected by
        is_safe_sql(); 403 when validate_cam_mac() rejects ``cam_mac``;
        500 on database failures or any unexpected error.

    NOTE(review): the ``cam_mac`` restriction is only *requested* in the
    prompt and is never enforced on the generated SQL, so a question can read
    other users' rows (including ``profiles.fcm_token``). Consider a
    restricted database role or row-level security instead of relying on the
    model.
    """
    try:
        # silent=True: a missing or non-JSON body becomes None (-> 400) instead
        # of raising and being reported as a 500 by the handler below.
        data = request.get_json(silent=True)
        logger.info(f"📩 بيانات الطلب: {data}")

        # The dict check matters: for a JSON string body, `'text' in data` is a
        # substring test and the subscripts below would then fail.
        if not isinstance(data, dict) or 'text' not in data or 'cam_mac' not in data:
            return jsonify({"error": "يرجى إرسال 'text' و 'cam_mac'"}), 400

        if not validate_cam_mac(data['cam_mac']):
            return jsonify({"error": "عنوان MAC غير صالح"}), 403

        prompt = _build_query_prompt(data['text'], data['cam_mac'])
        sql = clean_sql(_generate_sql(prompt))

        if not is_safe_sql(sql):
            return jsonify({"error": "تم توليد استعلام غير آمن"}), 400

        return _execute_readonly_query(sql)

    except Exception as e:
        logger.error(f"❌ خطأ غير متوقع: {str(e)}", exc_info=True)
        return jsonify({"error": "فشل في معالجة الطلب"}), 500


def _build_query_prompt(question, cam_mac):
    """Build the generation prompt, placing the question before DB_SCHEMA.

    The tokenizer truncates the prompt to 512 tokens, and DB_SCHEMA is long
    enough to consume that whole budget; with the question after the schema
    it could be cut off entirely, so it goes first and only the schema's tail
    is at risk of truncation.

    NOTE(review): this checkpoint appears to be a T5 model fine-tuned on
    Spider-style "question | db_id | table : columns" inputs; a prompt in that
    format may work better than free-form instructions -- confirm.
    """
    return f"""
### التعليمات الصارمة:
1. قم بتحويل السؤال إلى استعلام SELECT فقط لـ PostgreSQL.
2. يجب أن يتضمن الشرط: WHERE cam_mac = '{cam_mac}'.
3. ممنوع تمامًا استخدام أي أوامر غير SELECT.

### السؤال:
{question}

### هيكل قاعدة البيانات:
{DB_SCHEMA}

### استعلام SQL (SELECT فقط):
"""


def _generate_sql(prompt):
    """Run beam-search generation on *prompt* and return the decoded text."""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
    # Token ids are integer tensors and must stay that way; only the device
    # changes. model.device covers both the CUDA and CPU placements.
    # (BatchEncoding.to() ignores dtype arguments apart from logging a
    # warning, so casting the inputs to float32 would be a no-op.)
    inputs = inputs.to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=256,
            num_beams=4,
            early_stopping=True,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Upper bound, in milliseconds, for executing a single generated query.
_QUERY_TIMEOUT_MS = 10000


def _execute_readonly_query(sql):
    """Execute *sql* in a read-only transaction and build the Flask response."""
    conn = get_db_connection()
    if not conn:
        return jsonify({"error": "فشل الاتصال بقاعدة البيانات"}), 500

    try:
        # Defence in depth behind is_safe_sql()'s keyword deny-list: the
        # transaction is opened read-only, so PostgreSQL itself refuses writes.
        conn.set_session(readonly=True)
        with conn.cursor() as cursor:
            # SET LOCAL lasts only for this transaction; it stops runaway
            # generated queries from occupying the server indefinitely.
            cursor.execute("SET LOCAL statement_timeout = %s", (_QUERY_TIMEOUT_MS,))
            cursor.execute(sql)
            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            rows = cursor.fetchall()

        return jsonify({
            "data": [dict(zip(columns, row)) for row in rows],
            "sql": sql,
            "timestamp": datetime.now().isoformat(),
        })

    except Exception as e:
        logger.error(f"❌ خطأ في تنفيذ SQL: {e}\nالاستعلام: {sql}")
        return jsonify({
            "error": "خطأ في تنفيذ الاستعلام",
            "sql": sql,
            "details": str(e),
        }), 500

    finally:
        # Closing without a commit rolls the read-only transaction back.
        conn.close()
|
|
|
@app.route('/health')
def health_check():
    """Report whether the model is loaded and the database is reachable.

    Returns a JSON object with ``status``, ``model_loaded``, ``device``,
    ``db_connection`` and ``timestamp`` keys.
    """
    # Probe the database with a throwaway connection and close it explicitly
    # instead of leaving the unreferenced connection to garbage collection.
    probe = get_db_connection()
    db_ok = probe is not None
    if db_ok:
        probe.close()

    return jsonify({
        "status": "healthy",
        "model_loaded": model is not None,
        "device": str(model.device) if model is not None else "none",
        "db_connection": db_ok,
        "timestamp": datetime.now().isoformat(),
    })
|
|
|
@app.route('/')
def home():
    """Serve a minimal HTML landing page pointing at the API and health routes."""
    page = """
<h1>Text2SQL API</h1>
<p>استخدم <code>/api/query</code> مع POST</p>
<p>تحقق من الحالة: <a href="/health">/health</a></p>
"""
    return page
|
|
|
if __name__ == '__main__':
    # Flask's built-in server, listening on every interface at port 7860.
    app.run(port=7860, host='0.0.0.0')