model-bot / app.py
MHD011's picture
Update app.py
37a2315 verified
raw
history blame
11.5 kB
import os
import re
import logging
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import psycopg2
from datetime import datetime
app = Flask(__name__)
CORS(app)
# إعدادات التسجيل
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# إعدادات النموذج
MODEL_NAME = "tscholak/3vnuv1vf" # نموذج متخصص لـ PostgreSQL
SUPABASE_DB_URL = os.getenv('SUPABASE_DB_URL')
# تهيئة مجلد الذاكرة المؤقتة
os.makedirs("model_cache", exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = "model_cache"
os.environ["TORCH_CACHE"] = "model_cache"
tokenizer = None
model = None
def initialize():
global tokenizer, model
try:
logger.info("🚀 جاري تحميل النموذج...")
tokenizer = AutoTokenizer.from_pretrained(
MODEL_NAME,
cache_dir="model_cache"
)
model = AutoModelForSeq2SeqLM.from_pretrained(
MODEL_NAME,
cache_dir="model_cache",
device_map="auto",
torch_dtype=torch.float32, # تغيير من float16 إلى float32 للإصلاح
low_cpu_mem_usage=True
)
if torch.cuda.is_available():
model.to('cuda')
logger.info("✅ تم تحميل النموذج على GPU")
else:
model.float() # تأكد من تحويل النموذج إلى float32 على CPU
logger.info("✅ تم تحميل النموذج على CPU (باستخدام float32)")
model.eval()
except Exception as e:
logger.error(f"❌ فشل في تحميل النموذج: {str(e)}")
raise
initialize()
# سكيما قاعدة البيانات (نفسها كما في الكود الأصلي)
DB_SCHEMA = """
CREATE TABLE public.profiles (
id uuid NOT NULL,
updated_at timestamp with time zone,
username text UNIQUE CHECK (char_length(username) >= 3),
full_name text,
avatar_url text,
website text,
cam_mac text UNIQUE,
fcm_token text,
notification_enabled boolean DEFAULT true,
CONSTRAINT profiles_pkey PRIMARY KEY (id),
CONSTRAINT profiles_id_fkey FOREIGN KEY (id) REFERENCES auth.users(id)
);
CREATE TABLE public.place (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp with time zone DEFAULT (now() AT TIME ZONE 'utc'::text),
name text,
CONSTRAINT place_pkey PRIMARY KEY (id)
);
CREATE TABLE public.user_place (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp with time zone NOT NULL DEFAULT now(),
place_id bigint,
user_cam_mac text,
CONSTRAINT user_place_pkey PRIMARY KEY (id),
CONSTRAINT user_place_place_id_fkey FOREIGN KEY (place_id) REFERENCES public.place(id),
CONSTRAINT user_place_user_cam_mac_fkey FOREIGN KEY (user_cam_mac) REFERENCES public.profiles(cam_mac)
);
CREATE TABLE public.data (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp without time zone,
caption text,
image_url text,
latitude double precision DEFAULT '36.1833854'::double precision,
longitude double precision DEFAULT '37.1309255'::double precision,
user_place_id bigint,
cam_mac text,
CONSTRAINT data_pkey PRIMARY KEY (id),
CONSTRAINT data_user_place_id_fkey FOREIGN KEY (user_place_id) REFERENCES public.user_place(id)
);
CREATE TABLE public.biodata (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp with time zone NOT NULL DEFAULT now(),
mac_address text,
acceleration_x double precision,
acceleration_y double precision,
acceleration_z double precision,
gyro_x double precision,
gyro_y double precision,
gyro_z double precision,
temperature double precision,
CONSTRAINT biodata_pkey PRIMARY KEY (id),
CONSTRAINT biodata_mac_address_fkey FOREIGN KEY (mac_address) REFERENCES public.profiles(cam_mac)
);
CREATE TABLE public.notification (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp without time zone NOT NULL DEFAULT now(),
user_cam_mac text,
title text,
message text,
is_read boolean,
acceleration_x double precision,
acceleration_y double precision,
acceleration_z double precision,
gyro_x double precision,
gyro_y double precision,
gyro_z double precision,
CONSTRAINT notification_pkey PRIMARY KEY (id),
CONSTRAINT notification_user_cam_mac_fkey FOREIGN KEY (user_cam_mac) REFERENCES public.profiles(cam_mac)
);
CREATE TABLE public.flag (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
flag smallint,
user_mac_address text,
CONSTRAINT flag_pkey PRIMARY KEY (id),
CONSTRAINT flag_user_mac_address_fkey FOREIGN KEY (user_mac_address) REFERENCES public.profiles(cam_mac)
);
""".strip()
def get_db_connection():
try:
conn = psycopg2.connect(SUPABASE_DB_URL)
logger.info("✅ تم الاتصال بقاعدة البيانات")
return conn
except Exception as err:
logger.error(f"❌ خطأ في الاتصال بقاعدة البيانات: {err}")
return None
def validate_cam_mac(cam_mac):
conn = get_db_connection()
if not conn:
return False
try:
cursor = conn.cursor()
cursor.execute("SELECT 1 FROM profiles WHERE cam_mac = %s;", (cam_mac,))
return cursor.fetchone() is not None
except Exception as e:
logger.error(f"❌ خطأ في التحقق من cam_mac: {e}")
return False
finally:
if conn:
conn.close()
def is_safe_sql(sql):
"""تحقق مما إذا كان الاستعلام SELECT فقط وآمن للتنفيذ"""
sql_upper = sql.upper()
# قائمة بالأوامر المحظورة
forbidden_commands = [
"INSERT", "UPDATE", "DELETE", "CREATE",
"DROP", "ALTER", "TRUNCATE", "GRANT",
"REVOKE", "COMMIT", "ROLLBACK"
]
# التأكد من أن الاستعلام يبدأ بـ SELECT
if not sql_upper.strip().startswith("SELECT"):
return False
# التأكد من عدم وجود أي أوامر محظورة
for cmd in forbidden_commands:
if cmd in sql_upper:
return False
return True
def clean_sql(sql):
"""تنظيف استعلام SQL لضمان أنه SELECT فقط"""
# إزالة أي شيء قبل SELECT
sql = re.sub(r'^[^S]*(SELECT)', 'SELECT', sql, flags=re.IGNORECASE)
# أخذ أول عبارة SQL فقط (تجاهل أي شيء بعد ;)
sql = sql.split(';')[0] + ';'
# إزالة المسافات الزائدة
sql = ' '.join(sql.split()).strip()
return sql
@app.route('/api/query', methods=['POST'])
def handle_query():
try:
data = request.get_json()
logger.info(f"📩 بيانات الطلب: {data}")
if not data or 'text' not in data or 'cam_mac' not in data:
return jsonify({"error": "يرجى إرسال 'text' و 'cam_mac'"}), 400
if not validate_cam_mac(data['cam_mac']):
return jsonify({"error": "عنوان MAC غير صالح"}), 403
prompt = f"""
### التعليمات الصارمة:
1. قم بتحويل السؤال إلى استعلام SELECT فقط لـ PostgreSQL.
2. يجب أن يتضمن الشرط: WHERE cam_mac = '{data['cam_mac']}'.
3. ممنوع تمامًا استخدام أي أوامر غير SELECT.
### هيكل قاعدة البيانات:
{DB_SCHEMA}
### السؤال:
{data['text']}
### استعلام SQL (SELECT فقط):
"""
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
# تحويل المدخلات إلى float32 إذا كنا على CPU
if not torch.cuda.is_available():
inputs = inputs.to(torch.float32)
if torch.cuda.is_available():
inputs = inputs.to('cuda')
with torch.no_grad():
outputs = model.generate(
**inputs,
max_length=256,
num_beams=4,
early_stopping=True
)
sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
sql = clean_sql(sql)
if not is_safe_sql(sql):
return jsonify({"error": "تم توليد استعلام غير آمن"}), 400
conn = get_db_connection()
if not conn:
return jsonify({"error": "فشل الاتصال بقاعدة البيانات"}), 500
try:
cursor = conn.cursor()
cursor.execute(sql)
columns = [desc[0] for desc in cursor.description] if cursor.description else []
rows = cursor.fetchall()
return jsonify({
"data": [dict(zip(columns, row)) for row in rows],
"sql": sql,
"timestamp": datetime.now().isoformat()
})
except Exception as e:
return jsonify({
"error": "خطأ في تنفيذ الاستعلام",
"sql": sql,
"details": str(e)
}), 500
finally:
if conn:
conn.close()
except Exception as e:
logger.error(f"❌ خطأ غير متوقع: {str(e)}", exc_info=True)
return jsonify({"error": "فشل في معالجة الطلب"}), 500
try:
cursor = conn.cursor()
cursor.execute(sql)
if cursor.description:
columns = [desc[0] for desc in cursor.description]
rows = cursor.fetchall()
data = [dict(zip(columns, row)) for row in rows]
response = {
"data": data,
"sql": sql,
"timestamp": datetime.now().isoformat()
}
else:
conn.commit()
response = {
"message": "تم تنفيذ الاستعلام بنجاح (لا توجد بيانات للإرجاع)",
"sql": sql
}
return jsonify(response)
except Exception as e:
logger.error(f"❌ خطأ في تنفيذ SQL: {e}\nالاستعلام: {sql}")
return jsonify({
"error": "خطأ في تنفيذ الاستعلام",
"sql": sql,
"details": str(e)
}), 500
finally:
if conn:
conn.close()
except Exception as e:
logger.error(f"❌ خطأ غير متوقع: {str(e)}", exc_info=True)
return jsonify({"error": "فشل في معالجة الطلب"}), 500
@app.route('/health')
def health_check():
return jsonify({
"status": "healthy",
"model_loaded": model is not None,
"device": str(model.device) if model else "none",
"db_connection": get_db_connection() is not None,
"timestamp": datetime.now().isoformat()
})
@app.route('/')
def home():
return """
<h1>Text2SQL API</h1>
<p>استخدم <code>/api/query</code> مع POST</p>
<p>تحقق من الحالة: <a href="/health">/health</a></p>
"""
if __name__ == '__main__':
app.run(host='0.0.0.0', port=7860)