Spaces:
Sleeping
Sleeping
from transformers import AutoTokenizer, AutoModel # Entfernt: FlaxAutoModelForSeq2SeqLM | |
import torch | |
import numpy as np | |
import random | |
import json | |
from fastapi import FastAPI | |
from fastapi.responses import JSONResponse | |
from pydantic import BaseModel | |
# Lade NUR RecipeBERT Modell | |
bert_model_name = "alexdseo/RecipeBERT" | |
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name) | |
bert_model = AutoModel.from_pretrained(bert_model_name) | |
bert_model.eval() # Setze das Modell in den Evaluationsmodus | |
# T5-Modell und -Logik KOMPLETT ENTFERNT für diesen Schritt | |
# special_tokens und tokens_map sind nicht mehr relevant, bleiben aber als Kommentar | |
# --- RecipeBERT-spezifische Funktionen (die jetzt die Kernlogik sind) --- | |
def get_embedding(text): | |
"""Berechnet das Embedding für einen Text mit Mean Pooling über alle Tokens.""" | |
inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True) | |
with torch.no_grad(): | |
outputs = bert_model(**inputs) | |
attention_mask = inputs['attention_mask'] | |
token_embeddings = outputs.last_hidden_state | |
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() | |
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) | |
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) | |
return (sum_embeddings / sum_mask).squeeze(0) | |
def get_cosine_similarity(vec1, vec2): | |
"""Berechnet die Cosinus-Ähnlichkeit zwischen zwei Vektoren.""" | |
if torch.is_tensor(vec1): vec1 = vec1.detach().numpy() | |
if torch.is_tensor(vec2): vec2 = vec2.detach().numpy() | |
vec1 = vec1.flatten() | |
vec2 = vec2.flatten() | |
dot_product = np.dot(vec1, vec2) | |
norm_a = np.linalg.norm(vec1) | |
norm_b = np.linalg.norm(vec2) | |
if norm_a == 0 or norm_b == 0: return 0 | |
return dot_product / (norm_a * norm_b) | |
# find_best_ingredients (modifiziert, um KEINE Embeddings für T5-ähnliche Auswahl zu nutzen, | |
# sondern nur grundlegende Zutatenbearbeitung und Optionalen Test für RecipeBERT-Laden) | |
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6): | |
""" | |
Für diesen Test: Gibt einfach die benötigten Zutaten plus ein paar zufällige verfügbare Zutaten zurück. | |
Die semantische Auswahl von RecipeBERT ist hier nicht aktiv (da T5-Generierung fehlt). | |
""" | |
required_ingredients = list(set(required_ingredients)) | |
available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients])) | |
final_ingredients = required_ingredients.copy() | |
num_to_add = min(max_ingredients - len(final_ingredients), len(available_ingredients)) | |
if num_to_add > 0: | |
final_ingredients.extend(random.sample(available_ingredients, num_to_add)) | |
# Optional: Ein kleiner Test-Print, ob RecipeBERT erfolgreich geladen wurde | |
try: | |
if final_ingredients: | |
# Versuche ein Embedding für die erste Zutat zu generieren | |
test_embedding = get_embedding(final_ingredients[0]) | |
print(f"INFO: Successfully generated embedding for '{final_ingredients[0]}'. RecipeBERT is loaded.") | |
else: | |
print("INFO: No ingredients to test embedding with.") | |
except Exception as e: | |
print(f"ERROR: RecipeBERT embedding test failed: {e}") | |
return final_ingredients | |
# mock_generate_recipe (ersetzt generate_recipe_with_t5) | |
def mock_generate_recipe(ingredients_list): | |
"""Generiert ein Mock-Rezept, da T5-Modell entfernt ist.""" | |
title = f"Einfaches Rezept mit {', '.join(ingredients_list[:3])}" if ingredients_list else "Einfaches Testrezept" | |
return { | |
"title": title, | |
"ingredients": ingredients_list, # Die "generierten" Zutaten sind einfach die Eingabe | |
"directions": [ | |
"Dies ist ein generierter Text von RecipeBERT (ohne T5).", | |
"Fügen Sie Ihre Zutaten zusammen und kochen Sie es nach Belieben.", | |
"Das Laden des RecipeBERT-Modells war erfolgreich!" | |
] | |
} | |
def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries): | |
""" | |
Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage. | |
Für diesen Test wird nur RecipeBERT zum Laden getestet und ein Mock-Rezept zurückgegeben. | |
""" | |
if not required_ingredients and not available_ingredients: | |
return {"error": "Keine Zutaten angegeben"} | |
try: | |
optimized_ingredients = find_best_ingredients( | |
required_ingredients, available_ingredients, max_ingredients | |
) | |
# Rufe die Mock-Generierungsfunktion auf | |
recipe = mock_generate_recipe(optimized_ingredients) | |
result = { | |
'title': recipe['title'], | |
'ingredients': recipe['ingredients'], | |
'directions': recipe['directions'], | |
'used_ingredients': optimized_ingredients | |
} | |
return result | |
except Exception as e: | |
return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"} | |
# --- FastAPI-Implementierung --- | |
app = FastAPI(title="AI Recipe Generator API (RecipeBERT Only Test)") | |
class RecipeRequest(BaseModel): | |
required_ingredients: list[str] = [] | |
available_ingredients: list[str] = [] | |
max_ingredients: int = 7 | |
max_retries: int = 5 | |
ingredients: list[str] = [] # Für Abwärtskompatibilität | |
# Der API-Endpunkt für Flutter | |
async def generate_recipe_api(request_data: RecipeRequest): | |
final_required_ingredients = request_data.required_ingredients | |
if not final_required_ingredients and request_data.ingredients: | |
final_required_ingredients = request_data.ingredients | |
result_dict = process_recipe_request_logic( | |
final_required_ingredients, | |
request_data.available_ingredients, | |
request_data.max_ingredients, | |
request_data.max_retries | |
) | |
return JSONResponse(content=result_dict) | |
async def read_root(): | |
return {"message": "AI Recipe Generator API is running (RecipeBERT only)!"} # Angepasste Nachricht | |
print("INFO: FastAPI application script finished execution and defined 'app' variable.") | |