|
|
|
|
|
|
|
|
|
|
|
from fastapi import FastAPI, UploadFile, File, HTTPException |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import JSONResponse |
|
import uvicorn |
|
import numpy as np |
|
import tensorflow as tf |
|
from tensorflow.keras.models import load_model |
|
from PIL import Image |
|
import io |
|
import os |
|
from datetime import datetime |
|
import json |
|
from typing import List, Dict, Any |
|
import logging |
|
|
|
|
|
# Root logging at INFO; this module logs through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
|
|
# Local directory holding the ensemble's .keras files.
MODEL_PATH = "models/mega_ensemble_80"

# Hugging Face Hub repository used when USE_HUGGINGFACE is enabled.
HUGGINGFACE_REPO = "bihan3876/my_model"

# Class labels; index i names element i of the models' output vector.
# (furniture, household goods, electronics_books, hobby_games, fashion_beauty)
CLASS_NAMES = ["가구", "생활용품", "전자기기_도서", "취미_게임", "패션_뷰티"]

# Target size passed to PIL's Image.resize.
IMG_SIZE = (224, 224)

# Fetch model files from the Hugging Face Hub instead of the local paths.
USE_HUGGINGFACE = os.getenv("USE_HUGGINGFACE", "false").lower() == "true"

# Serve the single TensorFlow Lite model instead of the Keras ensemble.
USE_LIGHTWEIGHT = os.getenv("USE_LIGHTWEIGHT", "false").lower() == "true"

LIGHTWEIGHT_MODEL_PATH = "models/serving/model_optimized.tflite"
|
|
|
|
|
app = FastAPI(
    title="AI 상품 분류 API",
    description="70.61% 정확도 달성한 AI 모델로 중고거래 상품 자동 분류",
    version="1.0.0",
)

# Allowed CORS origins as a comma-separated list, e.g.
# "https://a.example,https://b.example". Unset (or empty) keeps the previous
# behaviour of allowing any origin ("*").
# NOTE(review): combined with allow_credentials=True, Starlette answers a
# wildcard configuration by echoing the caller's Origin, so any site can make
# credentialed requests. Set CORS_ALLOW_ORIGINS explicitly in production.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
# Loaded models keyed by name; populated by the model-loading functions.
models: Dict[str, Any] = {}
# Metadata about the loaded models, returned by GET /model-info.
model_info: Dict[str, Any] = {}
|
|
|
def download_from_huggingface(): |
|
"""Hugging Face Hub์์ ๋ชจ๋ธ ๋ค์ด๋ก๋""" |
|
try: |
|
from huggingface_hub import snapshot_download |
|
|
|
logger.info(f"Hugging Face์์ ๋ชจ๋ธ ๋ค์ด๋ก๋ ์ค: {HUGGINGFACE_REPO}") |
|
|
|
|
|
local_dir = snapshot_download( |
|
repo_id=HUGGINGFACE_REPO, |
|
cache_dir="./cache", |
|
local_dir="./models_hf" |
|
) |
|
|
|
logger.info(f"๋ชจ๋ธ ๋ค์ด๋ก๋ ์๋ฃ: {local_dir}") |
|
return local_dir |
|
|
|
except ImportError: |
|
logger.error("huggingface_hub ํจํค์ง๊ฐ ์ค์น๋์ง ์์์ต๋๋ค.") |
|
logger.error("์ค์น: pip install huggingface_hub") |
|
return None |
|
except Exception as e: |
|
logger.error(f"Hugging Face ๋ค์ด๋ก๋ ์คํจ: {e}") |
|
return None |
|
|
|
def load_lightweight_model():
    """Load the single TensorFlow Lite model used in lightweight mode.

    With USE_HUGGINGFACE the ``.tflite`` file alone is fetched from the Hub;
    otherwise LIGHTWEIGHT_MODEL_PATH is used. On success the allocated
    interpreter is stored as ``models["tflite"]`` and ``model_info`` is
    replaced.

    Returns:
        True on success, False on any failure (the error is logged).
    """
    global models, model_info

    try:
        if USE_HUGGINGFACE:
            # Local import: huggingface_hub is only required for this branch.
            from huggingface_hub import snapshot_download

            local_dir = snapshot_download(
                repo_id=HUGGINGFACE_REPO,
                cache_dir="./cache",
                local_dir="./models_hf",
                # Fetch only the TFLite file, not the whole ensemble.
                allow_patterns=["models/serving/model_optimized.tflite"],
            )
            tflite_path = os.path.join(local_dir, "models", "serving", "model_optimized.tflite")
        else:
            tflite_path = LIGHTWEIGHT_MODEL_PATH

        if not os.path.exists(tflite_path):
            raise FileNotFoundError(f"TensorFlow Lite 모델이 없습니다: {tflite_path}")

        interpreter = tf.lite.Interpreter(model_path=tflite_path)
        interpreter.allocate_tensors()
        models["tflite"] = interpreter

        model_info = {
            "total_models": 1,
            "model_names": ["TensorFlow_Lite"],
            "accuracy": 62.0,
            "classes": CLASS_NAMES,
            "input_shape": [224, 224, 3],
            "ensemble_method": "single_model",
        }

        logger.info("🚀 경량 모델 로딩 완료: TensorFlow Lite")
        return True

    except Exception as e:
        logger.error(f"❌ 경량 모델 로딩 실패: {e}")
        return False
|
|
|
def load_ensemble_models():
    """Load every ``.keras`` model of the ensemble into ``models``.

    Delegates to load_lightweight_model() when USE_LIGHTWEIGHT is set. With
    USE_HUGGINGFACE the models come from the downloaded snapshot's
    ``models/ensemble`` directory, falling back to MODEL_PATH when the
    download fails. Models that fail to load individually are skipped with a
    warning.

    Returns:
        True if at least one model was loaded, False otherwise (logged).
    """
    global models, model_info

    try:
        if USE_LIGHTWEIGHT:
            return load_lightweight_model()

        model_dir = MODEL_PATH
        if USE_HUGGINGFACE:
            hf_path = download_from_huggingface()
            if hf_path:
                model_dir = os.path.join(hf_path, "models", "ensemble")
            else:
                logger.warning("Hugging Face 다운로드 실패, 로컬 모델 사용")

        logger.info(f"모델 로딩 시작: {model_dir}")

        if not os.path.exists(model_dir):
            raise FileNotFoundError(f"모델 경로가 존재하지 않습니다: {model_dir}")

        # Sorted so load order (and model_info["model_names"]) is deterministic.
        model_files = sorted(f for f in os.listdir(model_dir) if f.endswith('.keras'))
        if not model_files:
            raise FileNotFoundError("모델 파일을 찾을 수 없습니다")

        for model_file in model_files:
            model_name = model_file.replace('.keras', '').replace('_best', '')
            # Keep the directory and the per-file path in separate variables so
            # every iteration joins onto the directory, not the previous file.
            file_path = os.path.join(model_dir, model_file)
            try:
                models[model_name] = load_model(file_path)
                logger.info(f"✅ {model_name} 로드 완료")
            except Exception as e:
                logger.warning(f"⚠️ {model_name} 로드 실패: {e}")

        if not models:
            raise RuntimeError("로드된 모델이 없습니다")

        model_info = {
            "total_models": len(models),
            "model_names": list(models.keys()),
            "accuracy": 70.61,
            "classes": CLASS_NAMES,
            "input_shape": [224, 224, 3],
            "ensemble_method": "soft_voting",
        }

        logger.info(f"🎯 앙상블 모델 로딩 완료: {len(models)}개 모델")
        return True

    except Exception as e:
        logger.error(f"❌ 모델 로딩 실패: {e}")
        return False
|
|
|
def preprocess_image(image_bytes: bytes) -> np.ndarray:
    """Decode raw image bytes into a normalized batch containing one image.

    The image is converted to RGB if necessary, resized to IMG_SIZE and its
    uint8 pixel values are divided by 255.

    Args:
        image_bytes: Encoded image file contents (any format PIL can open).

    Returns:
        float64 array of shape (1, height, width, 3) with values in [0, 1].

    Raises:
        HTTPException: 400 if the bytes cannot be decoded or processed.
    """
    try:
        image = Image.open(io.BytesIO(image_bytes))
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image = image.resize(IMG_SIZE)

        image_array = np.array(image) / 255.0
        # Add the leading batch axis the models expect.
        return np.expand_dims(image_array, axis=0)

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"이미지 전처리 실패: {str(e)}") from e
|
|
|
def _run_tflite(interpreter, image_array: np.ndarray) -> np.ndarray:
    """Run one inference on a TensorFlow Lite interpreter and return its first output tensor."""
    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]

    input_dtype = input_detail["dtype"]
    data = image_array
    # Integer-quantized inputs are mapped through (scale, zero_point); float
    # models report scale 0 and only need a dtype cast (e.g. float64 -> float32).
    in_scale, in_zero_point = input_detail["quantization"]
    if np.issubdtype(input_dtype, np.integer) and in_scale:
        limits = np.iinfo(input_dtype)
        data = np.clip(np.round(data / in_scale + in_zero_point), limits.min, limits.max)
    interpreter.set_tensor(input_detail["index"], data.astype(input_dtype))
    interpreter.invoke()

    output = interpreter.get_tensor(output_detail["index"])
    out_scale, out_zero_point = output_detail["quantization"]
    if np.issubdtype(output_detail["dtype"], np.integer) and out_scale:
        output = (output.astype(np.float32) - out_zero_point) * out_scale
    return output


def ensemble_predict(image_array: np.ndarray) -> Dict[str, Any]:
    """Soft-voting prediction over every loaded model.

    Each model's output vector for the first image in ``image_array`` is
    treated as class probabilities; the vectors are averaged and the class
    with the highest mean wins. Both Keras models and the TensorFlow Lite
    interpreter stored by load_lightweight_model() are supported.

    Returns:
        dict with "predicted_class", "confidence" (mean probability of that
        class), "probabilities" (class name -> mean probability) and
        "timestamp" (local-time ISO-8601 string).

    Raises:
        HTTPException: 500 if no model is loaded or any model fails.
    """
    try:
        if not models:
            raise RuntimeError("로드된 모델이 없습니다")

        per_model = []
        for model in models.values():
            # tf.lite.Interpreter has no .predict(); it needs set_tensor/invoke.
            if isinstance(model, tf.lite.Interpreter):
                output = _run_tflite(model, image_array)
            else:
                output = model.predict(image_array, verbose=0)
            per_model.append(output[0])

        # Soft voting: element-wise mean of the per-model vectors.
        ensemble_pred = np.mean(per_model, axis=0)

        predicted_class_idx = int(np.argmax(ensemble_pred))
        predicted_class = CLASS_NAMES[predicted_class_idx]
        confidence = float(ensemble_pred[predicted_class_idx])

        probabilities = {
            name: float(ensemble_pred[i]) for i, name in enumerate(CLASS_NAMES)
        }

        return {
            "predicted_class": predicted_class,
            "confidence": confidence,
            "probabilities": probabilities,
            "timestamp": datetime.now().isoformat(),
        }

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"예측 실패: {str(e)}") from e
|
|
|
@app.on_event("startup")
async def startup_event():
    """Load the models at server startup; abort startup if loading fails.

    NOTE(review): on_event is deprecated in recent FastAPI releases in favour
    of lifespan handlers; consider migrating.
    """
    logger.info("🚀 AI 분류 서버 시작 중...")

    if not load_ensemble_models():
        logger.error("❌ 모델 로딩 실패로 서버 시작 불가")
        raise RuntimeError("모델 로딩 실패")

    logger.info("✅ AI 분류 서버 준비 완료!")
|
|
|
@app.get("/")
async def root():
    """Root endpoint: service name, version, model accuracy and docs path."""
    # Report the accuracy recorded for the model set actually loaded (the
    # lightweight TFLite model records a different value than the ensemble);
    # 70.61 is the ensemble's figure and the fallback before loading.
    accuracy = model_info.get("accuracy", 70.61)
    return {
        "message": "AI 상품 분류 API",
        "version": "1.0.0",
        "accuracy": f"{accuracy}%",
        "docs": "/docs",
    }
|
|
|
@app.get("/health")
async def health_check():
    """Health probe reporting how many models are currently loaded."""
    return dict(
        status="healthy",
        models_loaded=len(models),
        timestamp=datetime.now().isoformat(),
    )
|
|
|
@app.get("/model-info")
async def get_model_info():
    """Return the metadata recorded by the model loader."""
    return dict(success=True, data=model_info)
|
|
|
@app.get("/classes")
async def get_classes():
    """List the class labels the classifier can predict."""
    class_count = len(CLASS_NAMES)
    return {"success": True, "classes": CLASS_NAMES, "total_classes": class_count}
|
|
|
@app.post("/predict")
async def predict_image(
    file: UploadFile = File(...),
    return_probabilities: bool = True
):
    """Classify a single uploaded image.

    Args:
        file: The uploaded file; its declared content type must start with
            ``image/``.
        return_probabilities: When true, include the per-class probability
            map in the response.

    Returns:
        ``{"success": True, "data": ..., "message": ...}`` on success; a 500
        JSONResponse with ``"success": False`` for unexpected errors.

    Raises:
        HTTPException: 400 for a non-image upload or an undecodable image,
            500 if prediction fails (re-raised unchanged from the helpers).
    """
    try:
        # content_type is None when the part has no Content-Type header; treat
        # that as "not an image" instead of raising AttributeError.
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="이미지 파일만 업로드 가능합니다")

        image_bytes = await file.read()
        image_array = preprocess_image(image_bytes)
        result = ensemble_predict(image_array)

        response_data = {
            "predicted_class": result["predicted_class"],
            "confidence": result["confidence"],
            "timestamp": result["timestamp"],
        }
        if return_probabilities:
            response_data["probabilities"] = result["probabilities"]

        return {
            "success": True,
            "data": response_data,
            "message": "분류 완료",
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"예측 오류: {e}")
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"예측 실패: {str(e)}",
            },
        )
|
|
|
@app.post("/batch-predict")
async def batch_predict_images(
    files: List[UploadFile] = File(...),
    return_probabilities: bool = True
):
    """Classify several uploaded images, reporting success per file.

    A failure on one file (non-image, undecodable, prediction error) is
    recorded in that file's result entry and does not abort the batch.

    Args:
        files: The uploaded files.
        return_probabilities: When true, include each file's per-class
            probability map.

    Returns:
        ``{"success": True, "total_files": ..., "results": [...],
        "timestamp": ...}``; a 500 JSONResponse with ``"success": False`` if
        the batch itself fails unexpectedly.
    """
    try:
        results = []

        for file in files:
            try:
                # content_type is None when the part has no Content-Type header.
                if not (file.content_type or "").startswith('image/'):
                    results.append({
                        "filename": file.filename,
                        "success": False,
                        "message": "이미지 파일이 아닙니다",
                    })
                    continue

                image_bytes = await file.read()
                image_array = preprocess_image(image_bytes)
                result = ensemble_predict(image_array)

                batch_result = {
                    "filename": file.filename,
                    "success": True,
                    "predicted_class": result["predicted_class"],
                    "confidence": result["confidence"],
                }
                if return_probabilities:
                    batch_result["probabilities"] = result["probabilities"]

                results.append(batch_result)

            except Exception as e:
                results.append({
                    "filename": file.filename,
                    "success": False,
                    "message": str(e),
                })

        return {
            "success": True,
            "total_files": len(files),
            "results": results,
            "timestamp": datetime.now().isoformat(),
        }

    except Exception as e:
        logger.error(f"배치 예측 오류: {e}")
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"배치 예측 실패: {str(e)}",
            },
        )
|
|
|
if __name__ == "__main__":
    # Startup banner, then serve on all interfaces at port 8000.
    print("🚀 AI 상품 분류 서버 시작!")
    print("📊 모델: 70.61% 메가 앙상블")
    print("📚 API 문서: http://localhost:8000/docs")
    print("💚 헬스 체크: http://localhost:8000/health")

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
        log_level="info",
    )
|
|