DockerRecipe / app.py
TimInf's picture
Update app.py
908127c verified
raw
history blame
6.17 kB
from transformers import AutoTokenizer, AutoModel # Entfernt: FlaxAutoModelForSeq2SeqLM
import torch
import numpy as np
import random
import json
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
# Lade NUR RecipeBERT Modell
bert_model_name = "alexdseo/RecipeBERT"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
bert_model = AutoModel.from_pretrained(bert_model_name)
bert_model.eval() # Setze das Modell in den Evaluationsmodus
# T5-Modell und -Logik KOMPLETT ENTFERNT für diesen Schritt
# special_tokens und tokens_map sind nicht mehr relevant, bleiben aber als Kommentar
# --- RecipeBERT-spezifische Funktionen (die jetzt die Kernlogik sind) ---
def get_embedding(text):
"""Berechnet das Embedding für einen Text mit Mean Pooling über alle Tokens."""
inputs = bert_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
outputs = bert_model(**inputs)
attention_mask = inputs['attention_mask']
token_embeddings = outputs.last_hidden_state
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
return (sum_embeddings / sum_mask).squeeze(0)
def get_cosine_similarity(vec1, vec2):
"""Berechnet die Cosinus-Ähnlichkeit zwischen zwei Vektoren."""
if torch.is_tensor(vec1): vec1 = vec1.detach().numpy()
if torch.is_tensor(vec2): vec2 = vec2.detach().numpy()
vec1 = vec1.flatten()
vec2 = vec2.flatten()
dot_product = np.dot(vec1, vec2)
norm_a = np.linalg.norm(vec1)
norm_b = np.linalg.norm(vec2)
if norm_a == 0 or norm_b == 0: return 0
return dot_product / (norm_a * norm_b)
# find_best_ingredients (modifiziert, um KEINE Embeddings für T5-ähnliche Auswahl zu nutzen,
# sondern nur grundlegende Zutatenbearbeitung und Optionalen Test für RecipeBERT-Laden)
def find_best_ingredients(required_ingredients, available_ingredients, max_ingredients=6, avg_weight=0.6):
"""
Für diesen Test: Gibt einfach die benötigten Zutaten plus ein paar zufällige verfügbare Zutaten zurück.
Die semantische Auswahl von RecipeBERT ist hier nicht aktiv (da T5-Generierung fehlt).
"""
required_ingredients = list(set(required_ingredients))
available_ingredients = list(set([i for i in available_ingredients if i not in required_ingredients]))
final_ingredients = required_ingredients.copy()
num_to_add = min(max_ingredients - len(final_ingredients), len(available_ingredients))
if num_to_add > 0:
final_ingredients.extend(random.sample(available_ingredients, num_to_add))
# Optional: Ein kleiner Test-Print, ob RecipeBERT erfolgreich geladen wurde
try:
if final_ingredients:
# Versuche ein Embedding für die erste Zutat zu generieren
test_embedding = get_embedding(final_ingredients[0])
print(f"INFO: Successfully generated embedding for '{final_ingredients[0]}'. RecipeBERT is loaded.")
else:
print("INFO: No ingredients to test embedding with.")
except Exception as e:
print(f"ERROR: RecipeBERT embedding test failed: {e}")
return final_ingredients
# mock_generate_recipe (ersetzt generate_recipe_with_t5)
def mock_generate_recipe(ingredients_list):
"""Generiert ein Mock-Rezept, da T5-Modell entfernt ist."""
title = f"Einfaches Rezept mit {', '.join(ingredients_list[:3])}" if ingredients_list else "Einfaches Testrezept"
return {
"title": title,
"ingredients": ingredients_list, # Die "generierten" Zutaten sind einfach die Eingabe
"directions": [
"Dies ist ein generierter Text von RecipeBERT (ohne T5).",
"Fügen Sie Ihre Zutaten zusammen und kochen Sie es nach Belieben.",
"Das Laden des RecipeBERT-Modells war erfolgreich!"
]
}
def process_recipe_request_logic(required_ingredients, available_ingredients, max_ingredients, max_retries):
"""
Kernlogik zur Verarbeitung einer Rezeptgenerierungsanfrage.
Für diesen Test wird nur RecipeBERT zum Laden getestet und ein Mock-Rezept zurückgegeben.
"""
if not required_ingredients and not available_ingredients:
return {"error": "Keine Zutaten angegeben"}
try:
optimized_ingredients = find_best_ingredients(
required_ingredients, available_ingredients, max_ingredients
)
# Rufe die Mock-Generierungsfunktion auf
recipe = mock_generate_recipe(optimized_ingredients)
result = {
'title': recipe['title'],
'ingredients': recipe['ingredients'],
'directions': recipe['directions'],
'used_ingredients': optimized_ingredients
}
return result
except Exception as e:
return {"error": f"Fehler bei der Rezeptgenerierung: {str(e)}"}
# --- FastAPI-Implementierung ---
app = FastAPI(title="AI Recipe Generator API (RecipeBERT Only Test)")
class RecipeRequest(BaseModel):
required_ingredients: list[str] = []
available_ingredients: list[str] = []
max_ingredients: int = 7
max_retries: int = 5
ingredients: list[str] = [] # Für Abwärtskompatibilität
@app.post("/generate_recipe") # Der API-Endpunkt für Flutter
async def generate_recipe_api(request_data: RecipeRequest):
final_required_ingredients = request_data.required_ingredients
if not final_required_ingredients and request_data.ingredients:
final_required_ingredients = request_data.ingredients
result_dict = process_recipe_request_logic(
final_required_ingredients,
request_data.available_ingredients,
request_data.max_ingredients,
request_data.max_retries
)
return JSONResponse(content=result_dict)
@app.get("/")
async def read_root():
return {"message": "AI Recipe Generator API is running (RecipeBERT only)!"} # Angepasste Nachricht
print("INFO: FastAPI application script finished execution and defined 'app' variable.")