FlameF0X commited on
Commit
420490d
·
verified ·
1 Parent(s): 885e103

Create modeling_i3.py

Browse files
Files changed (1) hide show
  1. modeling_i3.py +85 -0
modeling_i3.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modeling_i3.py
2
+ import os
3
+ import json
4
+ import torch
5
+ from torch import nn
6
+ from transformers import PreTrainedModel, PretrainedConfig
7
+ from i3_model import i3Model, ChunkTokenizer
8
+
9
+ # ======================================================================
10
+ # I3 Configuration for Transformers
11
+ # ======================================================================
12
+ class I3Config(PretrainedConfig):
13
+ model_type = "i3"
14
+
15
+ def __init__(self, **kwargs):
16
+ super().__init__(**kwargs)
17
+
18
+ # ======================================================================
19
+ # I3 For Causal Language Modeling (HuggingFace Wrapper)
20
+ # ======================================================================
21
+ class I3ForCausalLM(PreTrainedModel):
22
+ config_class = I3Config
23
+ base_model_prefix = "i3"
24
+
25
+ def __init__(self, config):
26
+ super().__init__(config)
27
+ self.i3 = i3Model(
28
+ vocab_size=config.vocab_size,
29
+ d_model=getattr(config, "d_model", 512),
30
+ n_heads=getattr(config, "n_heads", 16),
31
+ max_seq_len=getattr(config, "max_seq_len", 256),
32
+ d_state=getattr(config, "d_state", 32)
33
+ )
34
+ # Tokenizer reference (optional, for convenience)
35
+ self.tokenizer = None
36
+ self.post_init()
37
+
38
+ def forward(self, input_ids, labels=None):
39
+ logits, loss = self.i3(input_ids, targets=labels)
40
+ output = {"logits": logits}
41
+ if loss is not None:
42
+ output["loss"] = loss
43
+ return output
44
+
45
+ @classmethod
46
+ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
47
+ """
48
+ Load model weights and config from HF repo or local folder.
49
+ Also loads chunk tokenizer if present.
50
+ """
51
+ # Load config.json
52
+ config_path = os.path.join(pretrained_model_name_or_path, "config.json")
53
+ if not os.path.exists(config_path):
54
+ raise FileNotFoundError(f"Cannot find config.json at {config_path}")
55
+ with open(config_path, "r") as f:
56
+ config_dict = json.load(f)
57
+
58
+ config = I3Config(**config_dict)
59
+ model = cls(config)
60
+
61
+ # Load model weights
62
+ bin_path = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
63
+ safe_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
64
+
65
+ if os.path.exists(safe_path):
66
+ try:
67
+ import safetensors.torch
68
+ state_dict = safetensors.torch.load_file(safe_path)
69
+ model.load_state_dict(state_dict, strict=True)
70
+ except ImportError:
71
+ raise ImportError("Please install safetensors to load .safetensors files")
72
+ elif os.path.exists(bin_path):
73
+ state_dict = torch.load(bin_path, map_location="cpu")
74
+ model.load_state_dict(state_dict, strict=True)
75
+ else:
76
+ raise FileNotFoundError("No model file found in the provided path")
77
+
78
+ # Load tokenizer if chunk_vocab_combined.json exists
79
+ vocab_path = os.path.join(pretrained_model_name_or_path, "chunk_vocab_combined.json")
80
+ if os.path.exists(vocab_path):
81
+ tokenizer = ChunkTokenizer()
82
+ tokenizer.load(vocab_path)
83
+ model.tokenizer = tokenizer
84
+
85
+ return model