import logging
import os

import nltk
import pandas as pd
import requests
import torch
from PIL import Image, ImageEnhance, ImageFilter
from transformers import AutoModel, AutoTokenizer, BertTokenizer

logger = logging.getLogger(__name__)


class OCRModel:
    """Singleton wrapper around the GOT-OCR2.0 model; the weights load once."""

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(OCRModel, cls).__new__(cls)
            cls._instance.initialize()
        return cls._instance

    def initialize(self):
        try:
            logger.info("Initializing OCR model...")
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(
                    'stepfun-ai/GOT-OCR2_0',
                    trust_remote_code=True,
                    use_fast=False
                )
            except Exception as e:
                logger.warning(f"Standard tokenizer failed, trying BertTokenizer: {str(e)}")
                self.tokenizer = BertTokenizer.from_pretrained(
                    'stepfun-ai/GOT-OCR2_0',
                    trust_remote_code=True
                )
            self.model = AutoModel.from_pretrained(
                'stepfun-ai/GOT-OCR2_0',
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                device_map='auto',
                use_safetensors=True
            )
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            # device_map='auto' already places the weights, so skip a second
            # .to(self.device) call, which can fail on a dispatched model.
            self.model = self.model.eval()
            logger.info("Model initialization completed successfully")
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}", exc_info=True)
            raise

    def preprocess_image(self, image):
        """Enhance image quality to improve text extraction."""
        try:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            # Boost contrast, sharpness, and brightness, then smooth noise.
            enhancer = ImageEnhance.Contrast(image)
            image = enhancer.enhance(1.5)
            enhancer = ImageEnhance.Sharpness(image)
            image = enhancer.enhance(1.5)
            enhancer = ImageEnhance.Brightness(image)
            image = enhancer.enhance(1.2)
            image = image.filter(ImageFilter.SMOOTH)
            return image
        except Exception as e:
            logger.error(f"Error in image preprocessing: {str(e)}", exc_info=True)
            raise

    def process_image(self, image):
        temp_image_path = "temp_ocr_image.jpg"
        try:
            logger.info("Starting image processing")
            processed_image = self.preprocess_image(image)
            # GOT-OCR2.0's chat() interface expects a file path, so write a
            # temporary copy of the preprocessed image.
            processed_image.save(temp_image_path)
            result = self.model.chat(self.tokenizer, temp_image_path, ocr_type='format')
            logger.info(f"Successfully extracted text: {result[:100]}...")
            return result.strip()
        except Exception as e:
            logger.error(f"Error in OCR processing: {str(e)}", exc_info=True)
            return f"Error processing image: {str(e)}"
        finally:
            # Clean up the temporary file on both the success and error paths.
            if os.path.exists(temp_image_path):
                os.remove(temp_image_path)


class AllergyAnalyzer:
    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        # Make sure the NLTK tokenizer resources are available.
        for resource in ('punkt', 'punkt_tab'):
            try:
                nltk.data.find(f'tokenizers/{resource}')
            except LookupError:
                nltk.download(resource)
        self.allergy_data = self.load_allergy_data()
        if self.allergy_data is None:
            raise ValueError("Failed to load allergy data from dataset")
        self.ocr_model = OCRModel()

    def load_allergy_data(self):
        """Load allergy data from the Excel dataset."""
        try:
            # Read the Excel file, treating the first row as headers.
            df = pd.read_excel(self.dataset_path, header=0)
            allergy_dict = {}
            for index, row in df.iterrows():
                # The allergy name comes from the first column; skip empty
                # cells, which pandas stringifies as 'nan'.
                allergy_name = str(row.iloc[0]).strip().lower()
                if not allergy_name or allergy_name == 'nan':
                    continue
                # The ingredients come from the remaining columns.
                ingredients = []
                for col in range(1, len(row)):
                    ingredient = str(row.iloc[col]).strip().lower()
                    if ingredient and ingredient != 'nan':
                        ingredients.append(ingredient)
                allergy_dict[allergy_name] = ingredients
            logger.info(f"Successfully loaded allergy data with {len(allergy_dict)} categories")
            return allergy_dict
        except Exception as e:
            logger.error(f"Error loading allergy data: {str(e)}", exc_info=True)
            return None

    def tokenize_text(self, text):
        """Split the extracted text into lowercase alphabetic word tokens."""
        try:
            tokens = nltk.word_tokenize(text)
            return [w.lower() for w in tokens if w.isalpha()]
        except Exception as e:
            logger.error(f"Error tokenizing text: {str(e)}")
            return []

    def check_allergen_in_excel(self, token, user_allergies):
        """Check whether a token appears in the Excel data, restricted to the user's allergies."""
        try:
            if not self.allergy_data:
                return None
            for allergy_name, ingredients in self.allergy_data.items():
                # Only consider the allergies the user has selected.
                if allergy_name.lower() in user_allergies and token in ingredients:
                    return allergy_name
            return None
        except Exception as e:
            logger.error(f"Error checking allergen in Excel: {str(e)}")
            return None

    def check_allergy_risk(self, ingredient, api_key, user_allergies):
        """Query the Claude API about an ingredient, restricted to the user's allergies."""
        try:
            # Ask Claude to check only the allergen categories the user selected.
            prompt = f"""
            You are a professional food safety expert. Analyze the ingredient '{ingredient}' and determine if it belongs to any of these allergen categories:
            {', '.join(user_allergies)}.
            Respond only with the category name if found or 'None' if not found.
            """
            url = "https://api.anthropic.com/v1/messages"
            headers = {
                "x-api-key": api_key,
                "content-type": "application/json",
                "anthropic-version": "2023-06-01"
            }
            data = {
                "model": "claude-3-opus-20240229",
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": 10
            }
            response = requests.post(url, json=data, headers=headers)
            response.raise_for_status()
            response_json = response.json()
            if "content" in response_json and isinstance(response_json["content"], list):
                result = response_json["content"][0]["text"].strip().lower()
                # Only accept a category the user actually cares about.
                if result in user_allergies:
                    return result
            return None
        except Exception as e:
            logger.error(f"Error querying Claude API: {str(e)}")
            return None

    def analyze_image(self, image, claude_api_key=None, user_allergies=None):
        """Analyze an image and detect allergens, restricted to the user's allergies."""
        try:
            if not self.allergy_data:
                raise ValueError("Allergy data not loaded")
            if not user_allergies:
                raise ValueError("User allergies not provided")
            # Normalize once so the case-insensitive checks below behave.
            user_allergies = [a.lower() for a in user_allergies]
            # Extract the text from the image.
            extracted_text = self.ocr_model.process_image(image)
            if extracted_text.startswith("Error processing image"):
                raise ValueError(extracted_text)
            logger.info(f"Extracted text: {extracted_text[:200]}...")
            # Tokenize the extracted text.
            tokens = self.tokenize_text(extracted_text)
            if not tokens:
                raise ValueError("No tokens extracted from text")
            database_matches = {}
            claude_matches = {}
            for token in tokens:
                # Check the local database first, for the selected allergies only.
                allergy = self.check_allergen_in_excel(token, user_allergies)
                if allergy:
                    # Use a set to avoid duplicate words per allergen.
                    database_matches.setdefault(allergy, set()).add(token)
                elif claude_api_key:
                    # Not in the Excel file: fall back to the Claude API,
                    # again for the selected allergies only.
                    allergy = self.check_allergy_risk(token, claude_api_key, user_allergies)
                    if allergy:
                        claude_matches.setdefault(allergy, set()).add(token)
            # Build the list of detected allergens with their related words:
            # database matches first, then Claude matches not already seen.
            detected_allergens = []
            seen_allergens = set()
            for allergy, words in list(database_matches.items()) + list(claude_matches.items()):
                if allergy not in seen_allergens:
                    detected_allergens.append({
                        "allergen": allergy,
                        "related_words": list(words)  # convert set to list
                    })
                    seen_allergens.add(allergy)
            return {
                "extracted_text": extracted_text,
                "detected_allergens": detected_allergens,
                "database_matches": {k: list(v) for k, v in database_matches.items()},
                "claude_matches": {k: list(v) for k, v in claude_matches.items()},
                "analyzed_tokens": tokens,
                "success": True
            }
        except Exception as e:
            logger.error(f"Error analyzing image: {str(e)}", exc_info=True)
            return {
                "error": str(e),
                "success": False
            }