# 역할: 훈련된 모델을 실제 서비스로 제공하는 API 서버 # POST /predict - 이미지 분류 # GET /health - 서버 상태 확인 # 70.61% 성능을 위해서는 앙상블이 필수 # Python API 서버로만 가능 from fastapi import FastAPI, UploadFile, File, HTTPException from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse import uvicorn import numpy as np import tensorflow as tf from tensorflow.keras.models import load_model from PIL import Image import io import os from datetime import datetime import json from typing import List, Dict, Any import logging # 로깅 설정 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # 설정 MODEL_PATH = "models/mega_ensemble_80" # 로컬 모델 경로 HUGGINGFACE_REPO = "bihan3876/my_model" # Hugging Face 저장소 CLASS_NAMES = ["가구", "생활용품", "전자기기_도서", "취미_게임", "패션_뷰티"] IMG_SIZE = (224, 224) # Hugging Face Hub 사용 여부 USE_HUGGINGFACE = os.getenv("USE_HUGGINGFACE", "false").lower() == "true" # 경량 모드 사용 여부 (TensorFlow Lite 모델 사용) USE_LIGHTWEIGHT = os.getenv("USE_LIGHTWEIGHT", "false").lower() == "true" LIGHTWEIGHT_MODEL_PATH = "models/serving/model_optimized.tflite" # FastAPI 앱 생성 app = FastAPI( title="AI 상품 분류 API", description="70.61% 정확도 달성한 AI 모델로 중고거래 상품 자동 분류", version="1.0.0" ) # CORS 설정 app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # 전역 변수 models = {} model_info = {} def download_from_huggingface(): """Hugging Face Hub에서 모델 다운로드""" try: from huggingface_hub import snapshot_download logger.info(f"Hugging Face에서 모델 다운로드 중: {HUGGINGFACE_REPO}") # 앙상블 모델만 다운로드 (349MB) local_dir = snapshot_download( repo_id=HUGGINGFACE_REPO, cache_dir="./cache", local_dir="./models_hf" ) logger.info(f"모델 다운로드 완료: {local_dir}") return local_dir except ImportError: logger.error("huggingface_hub 패키지가 설치되지 않았습니다.") logger.error("설치: pip install huggingface_hub") return None except Exception as e: logger.error(f"Hugging Face 다운로드 실패: {e}") return None def load_lightweight_model(): """경량 TensorFlow Lite 모델 로드""" global models, model_info try: import tensorflow as tf # TensorFlow Lite 모델 경로 if USE_HUGGINGFACE: # Hugging Face에서 경량 모델만 다운로드 (24MB) local_dir = snapshot_download( repo_id=HUGGINGFACE_REPO, cache_dir="./cache", local_dir="./models_hf", allow_patterns=["models/serving/model_optimized.tflite"] ) tflite_path = os.path.join(local_dir, "models", "serving", "model_optimized.tflite") else: tflite_path = LIGHTWEIGHT_MODEL_PATH if not os.path.exists(tflite_path): raise FileNotFoundError(f"TensorFlow Lite 모델이 없습니다: {tflite_path}") # TensorFlow Lite 인터프리터 로드 interpreter = tf.lite.Interpreter(model_path=tflite_path) interpreter.allocate_tensors() models["tflite"] = interpreter model_info = { "total_models": 1, "model_names": ["TensorFlow_Lite"], "accuracy": 62.0, # 추정 성능 "classes": CLASS_NAMES, "input_shape": [224, 224, 3], "ensemble_method": "single_model" } logger.info(f"🚀 경량 모델 로딩 완료: TensorFlow Lite") return True except Exception as e: logger.error(f"❌ 경량 모델 로딩 실패: {e}") return False def load_ensemble_models(): """앙상블 모델들 로드""" global models, model_info try: # 경량 모드 사용 시 if USE_LIGHTWEIGHT: return load_lightweight_model() # Hugging Face 사용 시 모델 다운로드 if USE_HUGGINGFACE: hf_path = download_from_huggingface() if hf_path: model_path = os.path.join(hf_path, "models", "ensemble") else: logger.warning("Hugging Face 다운로드 실패, 로컬 모델 사용") model_path = MODEL_PATH else: model_path = MODEL_PATH logger.info(f"모델 로딩 시작: {model_path}") if not os.path.exists(model_path): raise FileNotFoundError(f"모델 경로가 존재하지 않습니다: {model_path}") # 모델 파일들 찾기 model_files = [f for f in os.listdir(model_path) if f.endswith('.keras')] if not model_files: raise FileNotFoundError("모델 파일을 찾을 수 없습니다") # 각 모델 로드 for model_file in model_files: model_name = model_file.replace('.keras', '').replace('_best', '') model_path = os.path.join(model_path, model_file) try: model = load_model(model_path) models[model_name] = model logger.info(f"✅ {model_name} 로드 완료") except Exception as e: logger.warning(f"⚠️ {model_name} 로드 실패: {e}") if not models: raise RuntimeError("로드된 모델이 없습니다") # 모델 정보 설정 model_info = { "total_models": len(models), "model_names": list(models.keys()), "accuracy": 70.61, "classes": CLASS_NAMES, "input_shape": [224, 224, 3], "ensemble_method": "soft_voting" } logger.info(f"🎯 앙상블 모델 로딩 완료: {len(models)}개 모델") return True except Exception as e: logger.error(f"❌ 모델 로딩 실패: {e}") return False def preprocess_image(image_bytes: bytes) -> np.ndarray: """이미지 전처리""" try: # PIL로 이미지 열기 image = Image.open(io.BytesIO(image_bytes)) # RGB 변환 if image.mode != 'RGB': image = image.convert('RGB') # 크기 조정 image = image.resize(IMG_SIZE) # numpy 배열로 변환 및 정규화 image_array = np.array(image) / 255.0 # 배치 차원 추가 image_array = np.expand_dims(image_array, axis=0) return image_array except Exception as e: raise HTTPException(status_code=400, detail=f"이미지 전처리 실패: {str(e)}") def ensemble_predict(image_array: np.ndarray) -> Dict[str, Any]: """앙상블 예측""" try: predictions = [] # 각 모델로 예측 for model_name, model in models.items(): pred = model.predict(image_array, verbose=0) predictions.append(pred[0]) # 소프트 보팅 (평균) ensemble_pred = np.mean(predictions, axis=0) # 결과 처리 predicted_class_idx = np.argmax(ensemble_pred) predicted_class = CLASS_NAMES[predicted_class_idx] confidence = float(ensemble_pred[predicted_class_idx]) # 각 클래스별 확률 probabilities = { CLASS_NAMES[i]: float(ensemble_pred[i]) for i in range(len(CLASS_NAMES)) } return { "predicted_class": predicted_class, "confidence": confidence, "probabilities": probabilities, "timestamp": datetime.now().isoformat() } except Exception as e: raise HTTPException(status_code=500, detail=f"예측 실패: {str(e)}") @app.on_event("startup") async def startup_event(): """서버 시작 시 모델 로드""" logger.info("🚀 AI 분류 서버 시작 중...") if not load_ensemble_models(): logger.error("❌ 모델 로딩 실패로 서버 시작 불가") raise RuntimeError("모델 로딩 실패") logger.info("✅ AI 분류 서버 준비 완료!") @app.get("/") async def root(): """루트 엔드포인트""" return { "message": "AI 상품 분류 API", "version": "1.0.0", "accuracy": "70.61%", "docs": "/docs" } @app.get("/health") async def health_check(): """헬스 체크""" return { "status": "healthy", "models_loaded": len(models), "timestamp": datetime.now().isoformat() } @app.get("/model-info") async def get_model_info(): """모델 정보 조회""" return { "success": True, "data": model_info } @app.get("/classes") async def get_classes(): """지원 클래스 목록""" return { "success": True, "classes": CLASS_NAMES, "total_classes": len(CLASS_NAMES) } @app.post("/predict") async def predict_image( file: UploadFile = File(...), return_probabilities: bool = True ): """단일 이미지 분류""" try: # 파일 검증 if not file.content_type.startswith('image/'): raise HTTPException(status_code=400, detail="이미지 파일만 업로드 가능합니다") # 이미지 읽기 image_bytes = await file.read() # 전처리 image_array = preprocess_image(image_bytes) # 예측 result = ensemble_predict(image_array) # 응답 구성 response_data = { "predicted_class": result["predicted_class"], "confidence": result["confidence"], "timestamp": result["timestamp"] } if return_probabilities: response_data["probabilities"] = result["probabilities"] return { "success": True, "data": response_data, "message": "분류 완료" } except HTTPException: raise except Exception as e: logger.error(f"예측 오류: {e}") return JSONResponse( status_code=500, content={ "success": False, "message": f"예측 실패: {str(e)}" } ) @app.post("/batch-predict") async def batch_predict_images( files: List[UploadFile] = File(...), return_probabilities: bool = True ): """배치 이미지 분류""" try: results = [] for i, file in enumerate(files): try: if not file.content_type.startswith('image/'): results.append({ "filename": file.filename, "success": False, "message": "이미지 파일이 아닙니다" }) continue # 이미지 처리 image_bytes = await file.read() image_array = preprocess_image(image_bytes) result = ensemble_predict(image_array) # 결과 추가 batch_result = { "filename": file.filename, "success": True, "predicted_class": result["predicted_class"], "confidence": result["confidence"] } if return_probabilities: batch_result["probabilities"] = result["probabilities"] results.append(batch_result) except Exception as e: results.append({ "filename": file.filename, "success": False, "message": str(e) }) return { "success": True, "total_files": len(files), "results": results, "timestamp": datetime.now().isoformat() } except Exception as e: logger.error(f"배치 예측 오류: {e}") return JSONResponse( status_code=500, content={ "success": False, "message": f"배치 예측 실패: {str(e)}" } ) if __name__ == "__main__": print("🚀 AI 상품 분류 서버 시작!") print("📊 모델: 70.61% 메가 앙상블") print("🌐 API 문서: http://localhost:8000/docs") print("🔍 헬스 체크: http://localhost:8000/health") uvicorn.run( app, host="0.0.0.0", port=8000, log_level="info" )