import torch
from pathlib import Path
from typing import Any, Dict, Optional, Union

from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier


class ModelFactory:
    """
    Factory class for creating and loading models.

    Maps a string model type to its classifier class, instantiates it with
    keyword parameters, optionally loads saved weights, and returns the model
    in evaluation mode.
    """

    # Registry of supported ``model_type`` names -> classifier classes.
    # Defined once at class-creation time rather than rebuilt on every call.
    _MODEL_CLASSES: Dict[str, type] = {
        "Transformer": TransformerClassifier,
        "Mamba": MambaClassifier,
        "LSTM": LSTMClassifier,
    }

    @staticmethod
    def create_model(
        model_type: str,
        model_params: Dict[str, Any],
        state_dict_path: Optional[Union[str, Path]] = None,
    ) -> torch.nn.Module:
        """
        Create and load a model from configuration.

        Args:
            model_type: Type of model ('Transformer', 'Mamba', 'LSTM').
            model_params: Keyword arguments forwarded unchanged to the
                selected model class's constructor.
            state_dict_path: Path (``Path`` or ``str``) to a saved state
                dictionary. If falsy (``None`` or an empty string), no
                weights are loaded and the freshly constructed model is
                returned.

        Returns:
            The constructed PyTorch model, switched to evaluation mode via
            ``model.eval()`` (this happens whether or not weights were loaded).

        Raises:
            ValueError: If ``model_type`` is not one of the registered types.
            Errors raised by ``torch.load`` or ``load_state_dict`` (e.g. a
            missing file or mismatched keys) propagate unchanged.
        """
        model_cls = ModelFactory._MODEL_CLASSES.get(model_type)
        if model_cls is None:
            raise ValueError(f"Unknown model type: {model_type}")

        model = model_cls(**model_params)

        # Truthiness check kept deliberately: an empty-string path (e.g. from
        # a config file) means "no checkpoint", as in the original code.
        if state_dict_path:
            # map_location="cpu" remaps stored tensors onto the CPU, so
            # checkpoints saved from a GPU can be loaded on CPU-only hosts.
            # NOTE(review): torch.load unpickles the file. If checkpoints can
            # come from untrusted sources, pass weights_only=True (available
            # in torch >= 1.13). Confirm the project's minimum torch version
            # before enabling it.
            state_dict = torch.load(state_dict_path, map_location="cpu")
            model.load_state_dict(state_dict)

        # Switch to inference behaviour (e.g. dropout off) before returning.
        model.eval()
        return model