litvinovmitch11's picture
Synced repo using 'sync_with_huggingface' Github Action
cd123bf verified
import torch
from pathlib import Path
from typing import Dict, Any, Optional
from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier
class ModelFactory:
    """
    Factory class for creating and loading models.

    Maps a model-type name to one of the classifier classes imported from
    ``src.models.models`` and optionally restores saved weights.
    """

    @staticmethod
    def create_model(
        model_type: str,
        model_params: Dict[str, Any],
        state_dict_path: Optional[Path] = None
    ) -> torch.nn.Module:
        """
        Create a model from configuration and optionally load saved weights.

        Args:
            model_type: Type of model; one of 'Transformer', 'Mamba', 'LSTM'.
            model_params: Keyword arguments passed unchanged to the selected
                model class constructor.
            state_dict_path: Path to a saved state dictionary. When falsy
                (``None`` or an empty string), no weights are loaded.

        Returns:
            The constructed model, switched to evaluation mode via
            ``model.eval()``. This happens whether or not weights were loaded.

        Raises:
            ValueError: If ``model_type`` is not a supported type. The message
                lists the supported type names.
        """
        model_classes = {
            "Transformer": TransformerClassifier,
            "Mamba": MambaClassifier,
            "LSTM": LSTMClassifier,
        }
        model_cls = model_classes.get(model_type)
        if model_cls is None:
            supported = ", ".join(model_classes)
            # Keep the original message prefix; append the valid choices so
            # a misconfigured caller can see what is accepted.
            raise ValueError(
                f"Unknown model type: {model_type}. "
                f"Supported types: {supported}"
            )

        model = model_cls(**model_params)

        # A truthiness check (rather than `is not None`) is preserved so that
        # an empty string also means "no weights to load".
        if state_dict_path:
            # map_location="cpu" remaps all stored tensors to CPU, so a
            # checkpoint saved from a GPU can be loaded without one.
            # NOTE(review): torch.load unpickles the file. Before PyTorch 2.6
            # it defaults to weights_only=False, which can execute arbitrary
            # code from a malicious checkpoint. Only load trusted files, or
            # pass weights_only=True where the installed PyTorch supports it.
            state_dict = torch.load(state_dict_path, map_location="cpu")
            # Default strict=True: mismatched keys raise a RuntimeError.
            model.load_state_dict(state_dict)

        model.eval()
        return model