File size: 1,418 Bytes
cd123bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch

from pathlib import Path
from typing import Dict, Any, Optional

from src.models.models import TransformerClassifier, MambaClassifier, LSTMClassifier


class ModelFactory:
    """
    Factory class for creating and loading models.

    Maps a model-type name to one of the classifier classes imported from
    ``src.models.models`` and optionally restores saved weights into it.
    """

    @staticmethod
    def create_model(
        model_type: str,
        model_params: Dict[str, Any],
        state_dict_path: Optional[Path] = None
    ) -> torch.nn.Module:
        """
        Create a model, optionally load saved weights, and put it in eval mode.

        Args:
            model_type: Type of model ('Transformer', 'Mamba', 'LSTM')
            model_params: Keyword arguments forwarded unchanged to the
                selected model class's constructor
            state_dict_path: Path to a saved state dictionary. When None,
                the model keeps its freshly initialized weights.

        Returns:
            The constructed model, with ``eval()`` already called on it.

        Raises:
            ValueError: If model_type is unknown
            RuntimeError: Propagated from ``load_state_dict`` when the
                checkpoint's keys do not match the model (strict loading).
        """
        # Registry of supported model-type names -> model classes.
        model_classes = {
            "Transformer": TransformerClassifier,
            "Mamba": MambaClassifier,
            "LSTM": LSTMClassifier,
        }

        model_cls = model_classes.get(model_type)
        if model_cls is None:
            raise ValueError(f"Unknown model type: {model_type}")

        model = model_cls(**model_params)

        # Explicit None check: only skip loading when no path was supplied.
        if state_dict_path is not None:
            # map_location="cpu" lets checkpoints saved on a GPU load on
            # CPU-only hosts. weights_only=True restricts unpickling to
            # tensors and plain containers, so a tampered checkpoint cannot
            # execute arbitrary code; a state dict needs nothing more.
            state_dict = torch.load(
                state_dict_path, map_location="cpu", weights_only=True
            )
            model.load_state_dict(state_dict)

        # Switch layers such as dropout/batch-norm to inference behavior.
        model.eval()
        return model