|
import json |
|
import torch |
|
|
|
from pathlib import Path |
|
from typing import Dict, Any |
|
|
|
from src.app.model_utils.factory import ModelFactory |
|
|
|
|
|
class ModelManager:
    """
    Manages model loading and inference operations.

    Loads the model configuration and vocabulary from ``model_dir`` at
    construction time and lazily builds/caches model instances on first use.

    Args:
        model_dir: Directory containing model artifacts
            (``config.json``, ``vocab.json``, ``best_model.pth``).
    """

    def __init__(self, model_dir: str = "../pretrained") -> None:
        self.model_dir = Path(model_dir)
        # Cache of instantiated models, keyed by model_type string.
        self.loaded_models: Dict[str, Any] = {}
        self._load_model_artifacts()

    def _load_model_artifacts(self) -> None:
        """
        Load model configuration and vocabulary from ``model_dir``.

        Raises:
            FileNotFoundError: If ``config.json`` or ``vocab.json`` is
                missing, with the model directory named in the message.
        """
        try:
            # Explicit encoding: JSON artifacts are UTF-8 by spec; without
            # this, open() falls back to the platform locale encoding.
            with open(self.model_dir / "config.json", "r", encoding="utf-8") as f:
                self.config = json.load(f)

            with open(self.model_dir / "vocab.json", "r", encoding="utf-8") as f:
                self.vocab = json.load(f)
        except FileNotFoundError as err:
            # Surface which directory was searched; a bare FileNotFoundError
            # from a relative default path is hard to diagnose.
            raise FileNotFoundError(
                f"Missing model artifact in '{self.model_dir}': {err}"
            ) from err

        # Index -> human-readable label for binary sentiment output.
        self.idx_to_label = {0: "Negative", 1: "Positive"}

    def get_model(self) -> torch.nn.Module:
        """
        Get the loaded model (cached for performance).

        The model is built once per ``model_type`` via ``ModelFactory`` and
        reused on subsequent calls.

        Returns:
            Loaded PyTorch model in evaluation mode.
        """
        model_type = self.config["model_type"]

        if model_type not in self.loaded_models:
            model = ModelFactory.create_model(
                model_type=model_type,
                model_params=self.config["model_params"],
                state_dict_path=self.model_dir / "best_model.pth"
            )
            self.loaded_models[model_type] = model

        return self.loaded_models[model_type]

    def get_vocab(self) -> Dict[str, int]:
        """
        Get the token -> index vocabulary mapping.
        """
        return self.vocab

    def get_config(self) -> Dict[str, Any]:
        """
        Get the model configuration dictionary loaded from ``config.json``.
        """
        return self.config
|
|