MHD011 commited on
Commit
5f253ea
·
verified ·
1 Parent(s): 37a2315

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -232,14 +232,12 @@ def handle_query():
232
  ### استعلام SQL (SELECT فقط):
233
  """
234
 
 
235
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
236
 
237
- # تحويل المدخلات إلى float32 إذا كنا على CPU
238
- if not torch.cuda.is_available():
239
- inputs = inputs.to(torch.float32)
240
-
241
- if torch.cuda.is_available():
242
- inputs = inputs.to('cuda')
243
 
244
  with torch.no_grad():
245
  outputs = model.generate(
@@ -250,6 +248,8 @@ def handle_query():
250
  )
251
 
252
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
253
  sql = clean_sql(sql)
254
 
255
  if not is_safe_sql(sql):
 
232
  ### استعلام SQL (SELECT فقط):
233
  """
234
 
235
+ # تحضير المدخلات بدون تحويل نوع البيانات مباشرة
236
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
237
 
238
+ # نقل المدخلات إلى الجهاز المناسب (CPU/GPU)
239
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
240
+ inputs = {k: v.to(device) for k, v in inputs.items()}
 
 
 
241
 
242
  with torch.no_grad():
243
  outputs = model.generate(
 
248
  )
249
 
250
  sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
251
+ logger.info(f"⚡ الاستعلام المولد: {sql}")
252
+
253
  sql = clean_sql(sql)
254
 
255
  if not is_safe_sql(sql):