Create model.py
model.py
ADDED
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import einsum, rearrange
from torch import Tensor
from torch.nn.functional import silu, softplus
from transformers import AutoModel, AutoTokenizer

class Embedding:
    """Turns raw strings into tensors using a pretrained Hugging Face encoder."""

    def __init__(self, model_name='jina', pooling=None):
        self.model_name = model_name
        self.pooling = pooling
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if model_name == 'jina':
            self.tokenizer = AutoTokenizer.from_pretrained(
                "jinaai/jina-embeddings-v3",
                code_revision='da863dd04a4e5dce6814c6625adfba87b83838aa',
                trust_remote_code=True,
            )
            self.model = AutoModel.from_pretrained(
                "jinaai/jina-embeddings-v3",
                code_revision='da863dd04a4e5dce6814c6625adfba87b83838aa',
                trust_remote_code=True,
            ).to(self.device)
        elif model_name == 'xlm-roberta-base':
            self.tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
            self.model = AutoModel.from_pretrained('xlm-roberta-base').to(self.device)
        elif model_name == 'canine-c':
            self.tokenizer = AutoTokenizer.from_pretrained('google/canine-c')
            self.model = AutoModel.from_pretrained('google/canine-c').to(self.device)
        else:
            raise ValueError('Unknown name of Embedding')

    def _mean_pooling(self, X):
        def mean_pooling(model_output, attention_mask):
            # Average token embeddings, ignoring padding via the attention mask.
            token_embeddings = model_output[0]
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        encoded_input = self.tokenizer(X, padding=True, truncation=True, return_tensors='pt').to(self.device)
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.unsqueeze(1)  # (batch, 1, hidden)

    def get_embeddings(self, X):
        if self.pooling is None:
            # No pooling: return token-level features, truncated/padded to a fixed length.
            # Character-level CANINE yields much longer sequences than subword models.
            max_len = 329 if self.model_name == 'canine-c' else 95
            encoded_input = self.tokenizer(X, padding=True, truncation=True, return_tensors='pt').to(self.device)
            with torch.no_grad():
                features = self.model(**encoded_input)[0].detach().cpu().float().numpy()
            res = np.pad(
                features[:, :max_len, :],
                ((0, 0), (0, max(0, max_len - features.shape[1])), (0, 0)),
                "constant",
            )
            return torch.tensor(res)  # (batch, max_len, hidden)
        elif self.pooling == 'mean':
            return self._mean_pooling(X)
        else:
            raise ValueError('Unknown type of pooling')

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: Tensor) -> Tensor:
        # x / RMS(x), scaled by a learnable gain.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

class Mamba(nn.Module):
    def __init__(self, num_layers, d_input, d_model, d_state=16, d_discr=None, ker_size=4,
                 num_classes=7, model_name='jina', pooling=None):
        super().__init__()
        mamba_par = {
            'd_input': d_input,
            'd_model': d_model,
            'd_state': d_state,
            'd_discr': d_discr,
            'ker_size': ker_size,
        }
        self.model_name = model_name
        embed = Embedding(model_name, pooling)
        self.embedding = embed.get_embeddings
        # Stack of pre-norm residual Mamba blocks followed by a linear classification head.
        self.layers = nn.ModuleList(
            [nn.ModuleList([MambaBlock(**mamba_par), RMSNorm(d_input)]) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(d_input, num_classes)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.softmax = nn.Softmax(dim=1)

    def forward(self, seq, cache=None):
        # seq is a list of raw strings; embed them into (batch, length, d_input).
        seq = self.embedding(seq).to(self.device)
        for mamba, norm in self.layers:
            out, cache = mamba(norm(seq), cache)
            seq = out + seq
        # Mean-pool over the sequence dimension and classify.
        return self.fc_out(seq.mean(dim=1))

    def predict(self, x):
        label_to_emotion = {
            0: 'anger',
            1: 'disgust',
            2: 'fear',
            3: 'joy/happiness',
            4: 'neutral',
            5: 'sadness',
            6: 'surprise/enthusiasm',
        }
        with torch.no_grad():
            output = self.forward(x)
            _, predictions = torch.max(output, dim=1)
        return [label_to_emotion[int(i)] for i in predictions]

    def predict_proba(self, x):
        with torch.no_grad():
            output = self.forward(x)
        return self.softmax(output)

class MambaBlock(nn.Module):
    def __init__(self, d_input, d_model, d_state=16, d_discr=None, ker_size=4):
        super().__init__()
        d_discr = d_discr if d_discr is not None else d_model // 16
        self.in_proj = nn.Linear(d_input, 2 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_input, bias=False)
        # Input-dependent projections for the SSM parameters B, C and the step size.
        self.s_B = nn.Linear(d_model, d_state, bias=False)
        self.s_C = nn.Linear(d_model, d_state, bias=False)
        self.s_D = nn.Sequential(
            nn.Linear(d_model, d_discr, bias=False),
            nn.Linear(d_discr, d_model, bias=False),
        )
        # Depthwise causal convolution over the sequence dimension.
        self.conv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=ker_size,
            padding=ker_size - 1,
            groups=d_model,
            bias=True,
        )
        self.A = nn.Parameter(torch.arange(1, d_state + 1, dtype=torch.float).repeat(d_model, 1))
        self.D = nn.Parameter(torch.ones(d_model, dtype=torch.float))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, seq, cache=None):
        _, l, _ = seq.shape
        prev_hid, prev_inp = cache if cache is not None else (None, None)
        # Split the input projection into the SSM branch (a) and the gate branch (b).
        a, b = self.in_proj(seq).chunk(2, dim=-1)
        x = rearrange(a, 'b l d -> b d l')
        x = x if prev_inp is None else torch.cat((prev_inp, x), dim=-1)
        a = self.conv(x)[..., :l]  # crop the causal padding back to length l
        a = rearrange(a, 'b d l -> b l d')
        a = silu(a)
        a, hid = self.ssm(a, prev_hid=prev_hid)
        b = silu(b)
        out = a * b  # gated output
        out = self.out_proj(out)
        if cache is not None:
            cache = (hid.squeeze(), x[..., 1:])
        return out, cache

    def ssm(self, seq, prev_hid):
        A = -self.A
        D = self.D
        B = self.s_B(seq)
        C = self.s_C(seq)
        # Input-dependent step size, then discretise A and B with it.
        s = softplus(D + self.s_D(seq))
        A_bar = einsum(torch.exp(A), s, 'd s, b l d -> b l d s')
        B_bar = einsum(B, s, 'b l s, b l d -> b l d s')
        X_bar = einsum(B_bar, seq, 'b l d s, b l d -> b l d s')
        hid = self._hid_states(A_bar, X_bar, prev_hid=prev_hid)
        out = einsum(hid, C, 'b l d s, b l s -> b l d')
        out = out + D * seq  # skip connection scaled by D
        return out, hid

    def _hid_states(self, A, X, prev_hid=None):
        b, l, d, s = A.shape
        A = rearrange(A, 'b l d s -> l b d s')
        X = rearrange(X, 'b l d s -> l b d s')
        if prev_hid is not None:
            return rearrange(A * prev_hid + X, 'l b d s -> b l d s')
        # Sequential scan: h_t = A_t * h_{t-1} + X_t, stacked along the length dimension.
        h = torch.zeros(b, d, s, device=self.device)
        return torch.stack([h := A_t * h + X_t for A_t, X_t in zip(A, X)], dim=1)
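
For reference, a minimal usage sketch (not part of model.py): it builds the classifier on top of the jina embeddings and runs the two inference helpers. The hyperparameters num_layers=2 and d_model=128 are illustrative assumptions, and d_input is assumed to equal the hidden width of the chosen encoder (1024 for jinaai/jina-embeddings-v3).

import torch
from model import Mamba

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative hyperparameters; d_input must match the embedding width of the
# encoder selected via model_name (assumed 1024 for jina-embeddings-v3).
model = Mamba(
    num_layers=2,
    d_input=1024,
    d_model=128,
    model_name='jina',
    pooling=None,  # None -> token-level sequence, 'mean' -> one pooled vector per text
).to(device)
model.eval()

texts = ["I can't believe you did that!", "Everything is fine."]
print(model.predict(texts))        # e.g. ['anger', 'neutral'] once the classifier is trained
print(model.predict_proba(texts))  # (batch, num_classes) tensor of class probabilities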