# esp32-chat-bot / app.py
# Uploaded by MHD011 ("Upload 2 files"), commit 2b295d5 (verified), 8.83 kB.
import datetime
import decimal
import json
import logging
import os
import re
import uuid

import psycopg2
import torch
from flask import Flask, request, jsonify, Response
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# --- Flask application ---
# CORS is enabled app-wide with flask_cors's default options (no restrictions are passed here).
app = Flask(__name__)
CORS(app)

# --- Logging setup: INFO-level root configuration plus a module-level logger ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Model and database configuration ---
# Hugging Face Hub id of the text-to-SQL seq2seq checkpoint loaded by initialize().
# Can be overridden with the MODEL_NAME environment variable.
MODEL_NAME = os.environ.get("MODEL_NAME", "tscholak/cxmefzzi")
# The connection URL embeds the database password, so it must come from the
# environment (e.g. a deployment secret), never from source control.
# NOTE(review): a password was previously committed in this file and has been
# published with it; it should be rotated.
SUPABASE_DB_URL = os.environ.get("SUPABASE_DB_URL")
# Set by initialize(); handle_query answers 503 while either is still None.
tokenizer = None
model = None
def initialize():
    """Load the tokenizer and seq2seq model for MODEL_NAME into the module globals.

    Downloads are cached under ``model_cache`` in the current working
    directory. The model is moved to CUDA when ``torch.cuda.is_available()``,
    otherwise to CPU. A loading failure is logged and then re-raised.
    """
    global tokenizer, model

    # Keep the download cache inside the working directory, a location the app is allowed to write to.
    model_cache = os.path.join(os.getcwd(), "model_cache")
    os.makedirs(model_cache, exist_ok=True)

    target_device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"تحميل النموذج على الجهاز: {target_device}")

    try:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=model_cache)
        loaded = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, cache_dir=model_cache)
        model = loaded.to(target_device)
        logger.info("تم تحميل النموذج بنجاح")
    except Exception as exc:
        logger.error(f"فشل في تحميل النموذج: {str(exc)}")
        raise


# Load eagerly at import time so the first request does not pay the load cost.
initialize()
# --- Database schema ---
# DDL of the public tables. In this file it is only interpolated into the
# text-to-SQL prompt (handle_query) so the model knows table and column names;
# it is not executed. Keep it in sync with the real database by hand.
DB_SCHEMA = """
CREATE TABLE public.profiles (
id uuid NOT NULL,
updated_at timestamp with time zone,
username text UNIQUE CHECK (char_length(username) >= 3),
full_name text,
avatar_url text,
website text,
cam_mac text UNIQUE,
fcm_token text,
notification_enabled boolean DEFAULT true,
CONSTRAINT profiles_pkey PRIMARY KEY (id),
CONSTRAINT profiles_id_fkey FOREIGN KEY (id) REFERENCES auth.users(id)
);
CREATE TABLE public.place (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp with time zone DEFAULT (now() AT TIME ZONE 'utc'::text),
name text,
CONSTRAINT place_pkey PRIMARY KEY (id)
);
CREATE TABLE public.user_place (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp with time zone NOT NULL DEFAULT now(),
place_id bigint,
user_cam_mac text,
CONSTRAINT user_place_pkey PRIMARY KEY (id),
CONSTRAINT user_place_place_id_fkey FOREIGN KEY (place_id) REFERENCES public.place(id),
CONSTRAINT user_place_user_cam_mac_fkey FOREIGN KEY (user_cam_mac) REFERENCES public.profiles(cam_mac)
);
CREATE TABLE public.data (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp without time zone,
caption text,
image_url text,
latitude double precision DEFAULT '36.1833854'::double precision,
longitude double precision DEFAULT '37.1309255'::double precision,
user_place_id bigint,
cam_mac text,
CONSTRAINT data_pkey PRIMARY KEY (id),
CONSTRAINT data_user_place_id_fkey FOREIGN KEY (user_place_id) REFERENCES public.user_place(id)
);
CREATE TABLE public.biodata (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp with time zone NOT NULL DEFAULT now(),
mac_address text,
acceleration_x double precision,
acceleration_y double precision,
acceleration_z double precision,
gyro_x double precision,
gyro_y double precision,
gyro_z double precision,
temperature double precision,
CONSTRAINT biodata_pkey PRIMARY KEY (id),
CONSTRAINT biodata_mac_address_fkey FOREIGN KEY (mac_address) REFERENCES public.profiles(cam_mac)
);
CREATE TABLE public.notification (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp without time zone NOT NULL DEFAULT now(),
user_cam_mac text,
title text,
message text,
is_read boolean,
acceleration_x double precision,
acceleration_y double precision,
acceleration_z double precision,
gyro_x double precision,
gyro_y double precision,
gyro_z double precision,
CONSTRAINT notification_pkey PRIMARY KEY (id),
CONSTRAINT notification_user_cam_mac_fkey FOREIGN KEY (user_cam_mac) REFERENCES public.profiles(cam_mac)
);
CREATE TABLE public.flag (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
flag smallint,
user_mac_address text,
CONSTRAINT flag_pkey PRIMARY KEY (id),
CONSTRAINT flag_user_mac_address_fkey FOREIGN KEY (user_mac_address) REFERENCES public.profiles(cam_mac)
);
""".strip()
# --- Database connection ---
def get_db_connection(connect_timeout=10):
    """Open a new PostgreSQL connection using SUPABASE_DB_URL.

    Args:
        connect_timeout: Seconds libpq may spend establishing the connection.
            Without it, an unreachable database can block the calling
            request indefinitely.

    Returns:
        A new psycopg2 connection, or None if SUPABASE_DB_URL is not set or
        the connection attempt fails (the reason is logged). The caller owns
        the connection and must close it.
    """
    if not SUPABASE_DB_URL:
        # An empty/None DSN would make libpq fall back to its local defaults
        # (PGHOST, Unix socket, ...), which is never the intended target.
        logger.error("Database connection error: SUPABASE_DB_URL is not configured")
        return None
    try:
        # Keyword arguments are merged into the DSN by psycopg2.
        return psycopg2.connect(SUPABASE_DB_URL, connect_timeout=connect_timeout)
    except Exception as err:
        logger.error("Database connection error: %s", err)
        return None
# --- cam_mac validation ---
def validate_cam_mac(cam_mac):
    """Return True when a ``profiles`` row has exactly this ``cam_mac``.

    Returns False if no connection can be opened, if no row matches, or if
    the lookup raises (the error is logged). The connection is always closed.
    """
    connection = get_db_connection()
    if not connection:
        return False
    try:
        cur = connection.cursor()
        # Parameterised: cam_mac is caller-supplied (handle_query takes it from the request JSON).
        cur.execute("SELECT 1 FROM profiles WHERE cam_mac = %s;", (cam_mac,))
        found = cur.fetchone()
    except Exception as exc:
        logger.error(f"Validation error: {exc}")
        return False
    else:
        return found is not None
    finally:
        connection.close()
# Server-side time limit (milliseconds) for each generated query, so an
# expensive or deliberately slow statement cannot hold the connection open.
_SQL_STATEMENT_TIMEOUT_MS = 10_000

# The SELECT keyword as a whole word, in any letter case.
_SELECT_KEYWORD_RE = re.compile(r"\bselect\b", re.IGNORECASE)


def _build_prompt(natural_query, cam_mac):
    """Return the text-to-SQL prompt: schema, rules, the question, then "SELECT".

    The prompt ends with a bare ``SELECT`` line (presumably to prime the
    model to answer with a SELECT statement).
    """
    return "\n".join(
        [
            "### Postgres SQL table definitions",
            DB_SCHEMA,
            "### Rules:",
            f"- Always filter by cam_mac = '{cam_mac}'",
            "- Use only SELECT statements",
            "- Use proper JOINs",
            "- Use table aliases when helpful",
            "- The output must contain only the SQL query",
            f"### User question: {natural_query}",
            "### SQL query:",
            "SELECT",
        ]
    )


def _clean_generated_sql(raw_sql):
    """Normalise raw model output into one SELECT statement ending in ';'.

    The cleanup proceeds in three steps:

    1. Strip surrounding whitespace, Markdown code fences and a leading
       ``SQL:`` label.
    2. Keep the text from the FIRST ``select`` keyword, matched
       case-insensitively. Starting at the first one preserves any nested
       subqueries that follow it.
    3. If there is no ``select`` at all, treat the text as the continuation
       of the ``SELECT`` that ends the prompt and prepend the keyword.

    Returns:
        The statement terminated by a single ``;``. Returns None when a ``;``
        remains inside it, i.e. the text may contain more than one statement.
        A ``;`` inside a string literal is also rejected; that false positive
        is accepted.
    """
    sql = raw_sql.strip()
    sql = re.sub(r"^```sql\s*", "", sql, flags=re.IGNORECASE)
    sql = re.sub(r"\s*```$", "", sql)
    sql = re.sub(r"^SQL:\s*", "", sql, flags=re.IGNORECASE)
    sql = sql.strip()

    keyword = _SELECT_KEYWORD_RE.search(sql)
    sql = sql[keyword.start():] if keyword else "SELECT " + sql

    # Drop trailing semicolons/whitespace. Any ';' still present separates a
    # further statement; psycopg2 sends the whole string to the server, which
    # would run every statement in it (including COMMIT or writes).
    sql = re.sub(r"[\s;]+$", "", sql)
    if ";" in sql:
        return None
    return sql + ";"


def _json_default(value):
    """``json.dumps`` fallback for database values JSON cannot encode natively.

    psycopg2 returns ``datetime`` objects for timestamp columns and
    ``Decimal`` for ``numeric`` results (e.g. ``SUM``/``AVG`` over a
    ``bigint`` column). Without this hook, ``json.dumps`` raises TypeError
    for such rows.

    - date / time / datetime -> ISO-8601 string
    - timedelta (interval)   -> total seconds as a float
    - Decimal                -> int when finite and integral, else float
    - UUID                   -> canonical string form

    Any other type raises TypeError, as ``json.dumps`` would by default.
    """
    if isinstance(value, (datetime.date, datetime.time)):
        return value.isoformat()
    if isinstance(value, datetime.timedelta):
        return value.total_seconds()
    if isinstance(value, decimal.Decimal):
        if value.is_finite() and value == value.to_integral_value():
            return int(value)
        return float(value)
    if isinstance(value, uuid.UUID):
        return str(value)
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


@app.route('/api/query', methods=['POST'])
def handle_query():
    """Answer a natural-language question by generating SQL and running it.

    Request JSON: ``{"text": <question>, "cam_mac": <device MAC address>}``.

    Responses:
        200 -- ``{"data": [...]}``, one column-name -> value object per row.
        400 -- the body is not a JSON object, or lacks ``text``/``cam_mac``.
        403 -- unknown ``cam_mac``, or the generated SQL is not a single SELECT.
        500 -- database connection/execution failure (execution errors also
               return ``generated_sql``), or an unexpected generation error.
        503 -- the model is not loaded.
    """
    if tokenizer is None or model is None:
        return jsonify({"error": "النموذج غير محمل، يرجى المحاولة لاحقاً"}), 503
    try:
        # silent=True yields None for a missing or malformed JSON body instead
        # of raising, so bad requests get the 400 below rather than a 500.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict) or 'text' not in payload or 'cam_mac' not in payload:
            return jsonify({"error": "يرجى إرسال 'text' و 'cam_mac'"}), 400
        natural_query = payload['text']
        cam_mac = payload['cam_mac']
        logger.info(f"استعلام من {cam_mac}: {natural_query}")

        # Only cam_mac values present in profiles may query.
        if not validate_cam_mac(cam_mac):
            return jsonify({"error": "عنوان MAC غير صالح"}), 403

        # NOTE(review): the per-device "filter by cam_mac" rule exists only in
        # the prompt; nothing below enforces it, so generated SQL can read
        # other devices' rows. Enforcing it needs database-side rules (e.g.
        # row-level security); confirm the intended access model.
        prompt = _build_prompt(natural_query, cam_mac)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(**inputs, max_length=256)
        raw_sql = tokenizer.decode(outputs[0], skip_special_tokens=True)

        sql = _clean_generated_sql(raw_sql)
        logger.info(f"استعلام SQL المولد: {sql if sql is not None else raw_sql}")
        if sql is None or not sql.upper().startswith("SELECT"):
            return jsonify({"error": "يُسمح فقط باستعلامات SELECT"}), 403

        conn = get_db_connection()
        if not conn:
            return jsonify({"error": "فشل الاتصال بقاعدة البيانات"}), 500
        try:
            # The SQL is model output steered by user text, so the prefix check
            # above is not a security boundary. The transaction is therefore
            # READ ONLY (a SELECT ... INTO or write function fails) and has a
            # statement timeout.
            conn.set_session(readonly=True)
            with conn.cursor() as cursor:
                cursor.execute(
                    "SET LOCAL statement_timeout = %s", (_SQL_STATEMENT_TIMEOUT_MS,)
                )
                cursor.execute(sql)
                columns = [desc[0] for desc in cursor.description]
                rows = [dict(zip(columns, row)) for row in cursor.fetchall()]
            response_json = json.dumps(
                {"data": rows}, ensure_ascii=False, default=_json_default
            )
            return Response(
                response_json,
                status=200,
                mimetype='application/json; charset=utf-8',
            )
        except Exception as e:
            logger.error(f"خطأ في تنفيذ SQL: {e}")
            return jsonify({"error": str(e), "generated_sql": sql}), 500
        finally:
            # Closing without commit discards the (read-only) transaction.
            conn.close()
    except Exception as e:
        # logger.exception keeps the traceback for unexpected failures.
        logger.exception(f"خطأ في التوليد: {str(e)}")
        return jsonify({"error": "فشل في توليد الاستعلام"}), 500
@app.route('/')
def home():
    """Serve a minimal HTML page explaining how to call /api/query."""
    page = """
<h1>Text2SQL API</h1>
<p>Use <code>/api/query</code> with POST {"text": "your question", "cam_mac": "device_mac_address"}.</p>
"""
    return page
if __name__ == '__main__':
    # Start Flask's built-in server on every interface, port 7860, when run as a script.
    listen_host, listen_port = '0.0.0.0', 7860
    app.run(host=listen_host, port=listen_port)