# Spaces: Running
import logging
import os

import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# --- Application setup ---
app = Flask(__name__)
CORS(app)

# --- Logging setup ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Model setup ---
# Checkpoint name passed to from_pretrained(); change it to use another model.
# NOTE(review): the checkpoint is loaded with AutoModelForSeq2SeqLM, so it
# must be an encoder-decoder model -- confirm this one is.
MODEL_NAME = "tscholak/sqlcoder"

# Filled in by initialize(); both stay None until the model has been loaded.
tokenizer = None
model = None
def initialize():
    """Load the tokenizer and model for MODEL_NAME into the module globals.

    The model is moved to "cuda" when torch reports an available CUDA
    device and to "cpu" otherwise. Rebinds the module-level ``tokenizer``
    and ``model`` names; returns nothing.
    """
    global tokenizer, model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info("تحميل النموذج على الجهاز: %s", device)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME).to(device)
    logger.info("تم تحميل النموذج بنجاح")


# Runs at import time, so the model is loaded before the server starts.
initialize()
# --- Database schema (example) ---
# Postgres DDL that is embedded verbatim in every prompt sent to the model.
DB_SCHEMA = """
CREATE TABLE public.profiles (
id uuid NOT NULL,
updated_at timestamp with time zone,
username text UNIQUE CHECK (char_length(username) >= 3),
full_name text,
avatar_url text,
website text,
cam_mac text UNIQUE,
fcm_token text,
notification_enabled boolean DEFAULT true,
CONSTRAINT profiles_pkey PRIMARY KEY (id),
CONSTRAINT profiles_id_fkey FOREIGN KEY (id) REFERENCES auth.users(id)
);
CREATE TABLE public.place (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp with time zone DEFAULT (now() AT TIME ZONE 'utc'::text),
name text,
CONSTRAINT place_pkey PRIMARY KEY (id)
);
CREATE TABLE public.user_place (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp with time zone NOT NULL DEFAULT now(),
place_id bigint,
user_cam_mac text,
CONSTRAINT user_place_pkey PRIMARY KEY (id),
CONSTRAINT user_place_place_id_fkey FOREIGN KEY (place_id) REFERENCES public.place(id),
CONSTRAINT user_place_user_cam_mac_fkey FOREIGN KEY (user_cam_mac) REFERENCES public.profiles(cam_mac)
);
CREATE TABLE public.data (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp without time zone,
caption text,
image_url text,
latitude double precision DEFAULT '36.1833854'::double precision,
longitude double precision DEFAULT '37.1309255'::double precision,
user_place_id bigint,
cam_mac text,
CONSTRAINT data_pkey PRIMARY KEY (id),
CONSTRAINT data_user_place_id_fkey FOREIGN KEY (user_place_id) REFERENCES public.user_place(id)
);
CREATE TABLE public.biodata (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp with time zone NOT NULL DEFAULT now(),
mac_address text,
acceleration_x double precision,
acceleration_y double precision,
acceleration_z double precision,
gyro_x double precision,
gyro_y double precision,
gyro_z double precision,
temperature double precision,
CONSTRAINT biodata_pkey PRIMARY KEY (id),
CONSTRAINT biodata_mac_address_fkey FOREIGN KEY (mac_address) REFERENCES public.profiles(cam_mac)
);
CREATE TABLE public.notification (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
created_at timestamp without time zone NOT NULL DEFAULT now(),
user_cam_mac text,
title text,
message text,
is_read boolean,
acceleration_x double precision,
acceleration_y double precision,
acceleration_z double precision,
gyro_x double precision,
gyro_y double precision,
gyro_z double precision,
CONSTRAINT notification_pkey PRIMARY KEY (id),
CONSTRAINT notification_user_cam_mac_fkey FOREIGN KEY (user_cam_mac) REFERENCES public.profiles(cam_mac)
);
CREATE TABLE public.flag (
id bigint GENERATED ALWAYS AS IDENTITY NOT NULL,
flag smallint,
user_mac_address text,
CONSTRAINT flag_pkey PRIMARY KEY (id),
CONSTRAINT flag_user_mac_address_fkey FOREIGN KEY (user_mac_address) REFERENCES public.profiles(cam_mac)
);
""".strip()
def _build_prompt(user_text, cam_mac):
    """Assemble the model prompt: schema, user question and a SELECT stub."""
    return "\n".join(
        (
            "### Postgres SQL table definitions",
            DB_SCHEMA,
            f"### User question: {user_text}",
            f"### SQL query to answer the question filtered by cam_mac = '{cam_mac}':",
            "SELECT",
        )
    )


def _normalize_sql(sql):
    """Make the generated text start with SELECT and end with a semicolon."""
    sql = sql.strip()
    if not sql.lower().startswith("select"):
        sql = "SELECT " + sql
    if not sql.endswith(";"):
        sql += ";"
    return sql


@app.route("/generate-sql", methods=["POST"])
def generate_sql():
    """Generate a SQL query from the natural-language question in the request.

    Expects a JSON object with two non-empty string fields, ``text`` (the
    question) and ``cam_mac`` (the value the query should filter on).

    Returns:
        ``{"sql": ...}`` on success; ``{"error": ...}`` with status 400 when
        the body is missing, is not a JSON object, or either field is absent,
        empty or not a string; ``{"error": ...}`` with status 500 when
        tokenization or generation raises.
    """
    bad_request = jsonify({"error": "يرجى إرسال 'text' و 'cam_mac'"}), 400

    # silent=True yields None instead of raising on a missing or invalid JSON
    # body, so malformed requests get a 400 rather than a 500.
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        return bad_request

    user_text = body.get("text", "")
    cam_mac = body.get("cam_mac", "")
    if not isinstance(user_text, str) or not isinstance(cam_mac, str):
        return bad_request

    user_text = user_text.strip()
    cam_mac = cam_mac.strip()
    if not user_text or not cam_mac:
        return bad_request

    # NOTE(security): both values are untrusted and are embedded verbatim in
    # the prompt. The returned SQL is model output; callers must validate it
    # before running it against a database.
    try:
        prompt = _build_prompt(user_text, cam_mac)
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        outputs = model.generate(**inputs, max_length=256)
        sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception as exc:
        # Top-level request boundary: log with traceback, reply generically.
        logger.exception("خطأ في التوليد: %s", exc)
        return jsonify({"error": "فشل في توليد الاستعلام"}), 500

    # Clean up the model output.
    return jsonify({"sql": _normalize_sql(sql)})
@app.route("/")
def home():
    """Return a short HTML page describing how to call the API."""
    return (
        "\n<h1>Text2SQL API</h1>\n"
        "<p>Send a POST request to <code>/generate-sql</code> with JSON: "
        '{"text": "سؤالك", "cam_mac": "عنوان MAC"}</p>\n'
    )
if __name__ == "__main__":
    # Listen on all interfaces. The PORT environment variable, when set,
    # overrides the default port 7860.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))