|
import datetime
import decimal
import json
import logging
import os
import re

import psycopg2
import torch
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
app = Flask(__name__)

# Enable CORS on every route with flask_cors defaults so the API can be
# called from a front end served from another origin.
CORS(app)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face model id loaded by initialize().
MODEL_NAME = "tscholak/1wnr2e8q"

# Postgres connection string for the Supabase database; None when unset.
SUPABASE_DB_URL = os.getenv('SUPABASE_DB_URL')
if not SUPABASE_DB_URL:
    # Surface the misconfiguration at startup: psycopg2.connect(None) would
    # fall back to libpq's default connection parameters instead of
    # reporting that the URL is missing.
    logger.warning("SUPABASE_DB_URL environment variable is not set")

# Local directory used as the download cache for the model and tokenizer.
os.makedirs("model_cache", exist_ok=True)
# NOTE(review): transformers is imported at the top of the file and appears
# to resolve TRANSFORMERS_CACHE at import time, so this assignment probably
# does not change its default cache. initialize() passes
# cache_dir="model_cache" explicitly, and that is what takes effect.
# TORCH_CACHE may not be read by torch at all (torch hub uses TORCH_HOME)
# -- confirm before relying on it.
os.environ["TRANSFORMERS_CACHE"] = "model_cache"
os.environ["TORCH_CACHE"] = "model_cache"

# Populated by initialize(); None until loading succeeds.
tokenizer = None
model = None
|
|
|
|
def initialize():
    """Load the tokenizer and seq2seq model for MODEL_NAME into module globals.

    Files are fetched (or reused) from the local "model_cache" directory.
    Half precision is used only when a CUDA device is available. On CPU the
    model is loaded in float32, because float16 inference is slow or
    unsupported for many CPU ops.

    Raises:
        Exception: any loading error is logged with its traceback and
            re-raised. Because initialize() runs at import time (below), a
            failed load aborts startup.
    """
    global tokenizer, model

    try:
        logger.info("جاري تحميل النموذج بتهيئة منخفضة الذاكرة...")

        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            cache_dir="model_cache",
            local_files_only=False,
        )

        # float16 only pays off (and is reliably supported) on a GPU.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        model = AutoModelForSeq2SeqLM.from_pretrained(
            MODEL_NAME,
            cache_dir="model_cache",
            device_map="auto",
            torch_dtype=dtype,
            low_cpu_mem_usage=True,
        )
        # Inference only: disable dropout and other training-mode behavior.
        model.eval()

        logger.info("تم تحميل النموذج بنجاح")
    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone loses.
        logger.exception("فشل في تحميل النموذج: %s", e)
        raise


# Load the model eagerly so the first request does not pay the loading cost.
initialize()
|
|
|
|
# Postgres DDL describing the tables the text-to-SQL model may reference.
# It is interpolated verbatim into the generation prompt in handle_query
# (the {DB_SCHEMA} placeholder) so the model sees table and column names; it
# is never executed by this module. The cam_mac-style columns in the other
# tables reference public.profiles(cam_mac).
# NOTE(review): this copy must be kept in sync with the real database schema
# by hand.
DB_SCHEMA = """
|
|
CREATE TABLE public.profiles (
|
|
id uuid NOT NULL,
|
|
updated_at timestamp with time zone,
|
|
username text UNIQUE CHECK (char_length(username) >= 3),
|
|
full_name text,
|
|
avatar_url text,
|
|
website text,
|
|
cam_mac text UNIQUE,
|
|
fcm_token text,
|
|
notification_enabled boolean DEFAULT true,
|
|
CONSTRAINT profiles_pkey PRIMARY KEY (id),
|
|
CONSTRAINT profiles_id_fkey FOREIGN KEY (id) REFERENCES auth.users(id)
|
|
);
|
|
|
|
CREATE TABLE public.place (
|
|
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
|
|
created_at timestamp with time zone DEFAULT (now() AT TIME ZONE 'utc'::text),
|
|
name text,
|
|
CONSTRAINT place_pkey PRIMARY KEY (id)
|
|
);
|
|
|
|
CREATE TABLE public.user_place (
|
|
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
|
|
created_at timestamp with time zone NOT NULL DEFAULT now(),
|
|
place_id bigint,
|
|
user_cam_mac text,
|
|
CONSTRAINT user_place_pkey PRIMARY KEY (id),
|
|
CONSTRAINT user_place_place_id_fkey FOREIGN KEY (place_id) REFERENCES public.place(id),
|
|
CONSTRAINT user_place_user_cam_mac_fkey FOREIGN KEY (user_cam_mac) REFERENCES public.profiles(cam_mac)
|
|
);
|
|
|
|
CREATE TABLE public.data (
|
|
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
|
|
created_at timestamp without time zone,
|
|
caption text,
|
|
image_url text,
|
|
latitude double precision DEFAULT '36.1833854'::double precision,
|
|
longitude double precision DEFAULT '37.1309255'::double precision,
|
|
user_place_id bigint,
|
|
cam_mac text,
|
|
CONSTRAINT data_pkey PRIMARY KEY (id),
|
|
CONSTRAINT data_user_place_id_fkey FOREIGN KEY (user_place_id) REFERENCES public.user_place(id)
|
|
);
|
|
|
|
CREATE TABLE public.biodata (
|
|
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
|
|
created_at timestamp with time zone NOT NULL DEFAULT now(),
|
|
mac_address text,
|
|
acceleration_x double precision,
|
|
acceleration_y double precision,
|
|
acceleration_z double precision,
|
|
gyro_x double precision,
|
|
gyro_y double precision,
|
|
gyro_z double precision,
|
|
temperature double precision,
|
|
CONSTRAINT biodata_pkey PRIMARY KEY (id),
|
|
CONSTRAINT biodata_mac_address_fkey FOREIGN KEY (mac_address) REFERENCES public.profiles(cam_mac)
|
|
);
|
|
|
|
CREATE TABLE public.notification (
|
|
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
|
|
created_at timestamp without time zone NOT NULL DEFAULT now(),
|
|
user_cam_mac text,
|
|
title text,
|
|
message text,
|
|
is_read boolean,
|
|
acceleration_x double precision,
|
|
acceleration_y double precision,
|
|
acceleration_z double precision,
|
|
gyro_x double precision,
|
|
gyro_y double precision,
|
|
gyro_z double precision,
|
|
CONSTRAINT notification_pkey PRIMARY KEY (id),
|
|
CONSTRAINT notification_user_cam_mac_fkey FOREIGN KEY (user_cam_mac) REFERENCES public.profiles(cam_mac)
|
|
);
|
|
|
|
CREATE TABLE public.flag (
|
|
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
|
|
flag smallint,
|
|
user_mac_address text,
|
|
CONSTRAINT flag_pkey PRIMARY KEY (id),
|
|
CONSTRAINT flag_user_mac_address_fkey FOREIGN KEY (user_mac_address) REFERENCES public.profiles(cam_mac)
|
|
);
|
|
""".strip()
|
|
|
|
|
|
def get_db_connection(connect_timeout=10):
    """Open a new psycopg2 connection to SUPABASE_DB_URL.

    Args:
        connect_timeout: seconds to wait for the server before giving up, so
            /health and API requests do not hang on an unreachable database.
            Overrides any connect_timeout already present in the URL.

    Returns:
        A new connection that the caller is responsible for closing, or None
        if SUPABASE_DB_URL is unset or the connection attempt failed. Errors
        are logged, never raised.
    """
    if not SUPABASE_DB_URL:
        # Without this guard psycopg2.connect(None) would silently try libpq's
        # default connection parameters instead of the intended database.
        logger.error("Database connection error: SUPABASE_DB_URL is not set")
        return None

    try:
        return psycopg2.connect(SUPABASE_DB_URL, connect_timeout=connect_timeout)
    except Exception as err:
        # Deliberately broad: callers rely on a None return, never an exception.
        logger.error(f"Database connection error: {err}")
        return None
|
|
|
|
|
|
def validate_cam_mac(cam_mac):
    """Return True if a row in ``profiles`` has exactly this ``cam_mac``.

    Non-string or blank values are rejected without opening a database
    connection. Connection or query failures are logged and reported as
    False, so callers always receive a bool.
    """
    # Cheap rejection of values that cannot be a registered MAC (e.g. a JSON
    # number/object from the request body) before a DB round trip.
    if not isinstance(cam_mac, str) or not cam_mac.strip():
        return False

    conn = get_db_connection()
    if not conn:
        return False

    try:
        with conn.cursor() as cursor:
            # Parameterized: cam_mac comes straight from the request body.
            cursor.execute("SELECT 1 FROM profiles WHERE cam_mac = %s;", (cam_mac,))
            return cursor.fetchone() is not None
    except Exception as e:
        logger.error(f"Validation error: {e}")
        return False
    finally:
        conn.close()
|
|
|
|
# First SELECT keyword in model output, matched case-insensitively because the
# model may emit lowercase SQL and/or a prefix before the query.
_SELECT_KEYWORD_RE = re.compile(r"\bselect\b", re.IGNORECASE)


def _build_prompt(natural_query, cam_mac):
    """Return the generation prompt: schema, rules, and the user's question."""
    return f"""
|
|
### Postgres SQL table definitions
|
|
{DB_SCHEMA}
|
|
|
|
### Rules:
|
|
- Always filter by cam_mac = '{cam_mac}'
|
|
- Use only SELECT statements
|
|
- Use proper JOINs
|
|
- Use table aliases when helpful
|
|
- The output must contain only the SQL query
|
|
|
|
### User question: {natural_query}
|
|
|
|
### SQL query:
|
|
SELECT
|
|
""".strip()


def _normalize_generated_sql(raw_sql):
    """Reduce raw model output to a single SELECT statement ending in ';'.

    Strips a markdown ```sql fence and an "SQL:" label, drops anything before
    the first SELECT keyword (case-insensitive), and trims trailing
    semicolons and whitespace.

    Returns:
        The normalized statement, or None when the output contains no SELECT
        or more than one statement.
    """
    sql = raw_sql.strip()
    sql = re.sub(r"^```sql\s*", "", sql, flags=re.IGNORECASE)
    sql = re.sub(r"\s*```$", "", sql)
    sql = re.sub(r"^SQL:\s*", "", sql, flags=re.IGNORECASE)

    match = _SELECT_KEYWORD_RE.search(sql)
    if match is None:
        return None
    # Keep from the FIRST select so subqueries (later SELECTs) stay intact.
    sql = re.sub(r"[\s;]+$", "", sql[match.start():])

    # psycopg2 runs every ';'-separated statement passed to one execute()
    # call, so anything beyond a single statement is refused. This also
    # refuses a ';' inside a string literal, which is an accepted trade-off.
    if ";" in sql:
        return None
    return sql + ";"


def _json_default(value):
    """json.dumps fallback for DB value types that json cannot encode natively."""
    if isinstance(value, (datetime.datetime, datetime.date, datetime.time)):
        return value.isoformat()
    if isinstance(value, datetime.timedelta):
        return value.total_seconds()
    if isinstance(value, decimal.Decimal):
        # e.g. SUM()/AVG() over bigint columns come back as numeric/Decimal.
        return float(value)
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


@app.route('/api/query', methods=['POST'])
def handle_query():
    """Translate a natural-language question to SQL, run it, and return the rows.

    Expects a JSON body ``{"text": <question>, "cam_mac": <MAC address>}``.

    Responses:
        200: ``{"data": [row objects], "generated_sql": sql}``
        400: body missing, not JSON, or ``text``/``cam_mac`` not strings
        403: unknown ``cam_mac``, or model output that is not a single SELECT
        500: database connection/execution failure, or generation failure
        503: model not loaded
    """
    if tokenizer is None or model is None:
        return jsonify({"error": "النموذج غير محمل، يرجى المحاولة لاحقاً"}), 503

    try:
        # silent=True: malformed JSON yields None (-> 400) instead of raising.
        payload = request.get_json(silent=True)
        if (
            not isinstance(payload, dict)
            or not isinstance(payload.get('text'), str)
            or not isinstance(payload.get('cam_mac'), str)
        ):
            return jsonify({"error": "يرجى إرسال 'text' و 'cam_mac'"}), 400

        natural_query = payload['text']
        cam_mac = payload['cam_mac']
        logger.info(f"استعلام من {cam_mac}: {natural_query}")

        if not validate_cam_mac(cam_mac):
            return jsonify({"error": "عنوان MAC غير صالح"}), 403

        prompt = _build_prompt(natural_query, cam_mac)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_length=256)
        raw_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

        sql = _normalize_generated_sql(raw_sql)
        if sql is None:
            logger.warning(f"Rejected model output: {raw_sql}")
            return jsonify({
                "error": "يُسمح فقط باستعلامات SELECT",
                "generated_sql": raw_sql,
            }), 403
        logger.info(f"استعلام SQL المولد: {sql}")

        conn = get_db_connection()
        if not conn:
            return jsonify({"error": "فشل الاتصال بقاعدة البيانات"}), 500

        try:
            # The SQL is model output derived from user text. As defense in
            # depth, run it in a read-only transaction (so writes, DDL and
            # SELECT ... INTO fail) with a timeout against runaway queries.
            # NOTE(security): nothing here enforces the cam_mac filter. The
            # model is only *asked* to apply it, so a request can read other
            # users' rows. Per-user scoping must be enforced in the database
            # (e.g. row-level security or restricted views).
            conn.set_session(readonly=True)
            with conn.cursor() as cursor:
                cursor.execute("SET LOCAL statement_timeout = '10s'")
                cursor.execute(sql)
                columns = [desc[0] for desc in cursor.description]
                rows = cursor.fetchall()

            records = [dict(zip(columns, row)) for row in rows]
            response_json = json.dumps(
                {"data": records, "generated_sql": sql},
                ensure_ascii=False,
                default=_json_default,
            )
            return Response(
                response_json,
                status=200,
                mimetype='application/json; charset=utf-8',
            )
        except Exception as e:
            logger.error(f"خطأ في تنفيذ SQL: {e}")
            return jsonify({"error": str(e), "generated_sql": sql}), 500
        finally:
            # Closing without commit rolls back the read-only transaction.
            conn.close()

    except Exception as e:
        logger.exception("خطأ في التوليد: %s", e)
        return jsonify({"error": "فشل في توليد الاستعلام"}), 500
|
|
|
|
@app.route('/health')
def health_check():
    """Report liveness, whether the model is loaded, and whether the DB is reachable.

    ``status`` is always "healthy": it only signals that the process answers.
    The probe connection is closed immediately instead of being left to the
    garbage collector.
    """
    conn = get_db_connection()
    db_ok = conn is not None
    if db_ok:
        conn.close()

    return jsonify({
        "status": "healthy",
        "model_loaded": model is not None,
        "db_connection": db_ok,
    })
|
|
|
|
@app.route('/')
def home():
    """Serve a static HTML landing page that points to /api/query and /health."""
    page = """
|
|
<h1>Text2SQL API</h1>
|
|
<p>استخدم <code>/api/query</code> مع POST {"text": "سؤالك", "cam_mac": "عنوان MAC"}</p>
|
|
<p>تحقق من حالة الخدمة: <a href="/health">/health</a></p>
|
|
"""
    return page
|
|
|
|
if __name__ == '__main__':
    # Flask's built-in server, listening on all interfaces. The port defaults
    # to 7860 and can be overridden through the PORT environment variable.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))