dashakoryakovskaya committed on
Commit f526539 · verified · 1 Parent(s): 2dc2b87

Create model.py

Files changed (1)
1. model.py +171 -0
model.py ADDED
@@ -0,0 +1,171 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch.nn.functional import silu
from torch.nn.functional import softplus
from einops import rearrange, repeat, einsum
from transformers import AutoTokenizer, AutoModel
from torch import Tensor

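# Embedding wraps a pretrained Hugging Face text encoder (jina-embeddings-v3,
# xlm-roberta-base, or google/canine-c) and exposes get_embeddings(), which returns
# either fixed-length per-token features (pooling=None) or a mean-pooled sentence vector.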
class Embedding:
    def __init__(self, model_name='jina', pooling=None):
        self.model_name = model_name
        self.pooling = pooling
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if model_name == 'jina':
            self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3", code_revision='da863dd04a4e5dce6814c6625adfba87b83838aa', trust_remote_code=True)
            self.model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", code_revision='da863dd04a4e5dce6814c6625adfba87b83838aa', trust_remote_code=True).to(self.device)
        elif model_name == 'xlm-roberta-base':
            self.tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
            self.model = AutoModel.from_pretrained('xlm-roberta-base').to(self.device)
        elif model_name == 'canine-c':
            self.tokenizer = AutoTokenizer.from_pretrained('google/canine-c')
            self.model = AutoModel.from_pretrained('google/canine-c').to(self.device)
        else:
            raise ValueError('Unknown name of Embedding')
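
    # Attention-mask-weighted mean pooling over token embeddings, followed by
    # L2 normalisation; returns a (batch, 1, hidden) tensor.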
    def _mean_pooling(self, X):
        def mean_pooling(model_output, attention_mask):
            token_embeddings = model_output[0]
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        encoded_input = self.tokenizer(X, padding=True, truncation=True, return_tensors='pt').to(self.device)
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.unsqueeze(1)

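    # With pooling=None, returns per-token hidden states truncated/zero-padded to a fixed
    # length (329 for the character-level CANINE model, 95 otherwise); with pooling='mean',
    # returns a single pooled vector per input text.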
    def get_embeddings(self, X):
        if self.pooling is None:
            if self.model_name == 'canine-c':  # character-level tokenisation yields longer sequences
                max_len = 329
            else:
                max_len = 95
            encoded_input = self.tokenizer(X, padding=True, truncation=True, return_tensors='pt').to(self.device)
            with torch.no_grad():
                features = self.model(**encoded_input)[0].detach().cpu().float().numpy()
            res = np.pad(features[:, :max_len, :], ((0, 0), (0, max(0, max_len - features.shape[1])), (0, 0)), "constant")
            return torch.tensor(res)
        elif self.pooling == 'mean':
            return self._mean_pooling(X)
        else:
            raise ValueError('Unknown type of pooling')
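
# RMSNorm: root-mean-square layer normalisation (scale-only, no mean subtraction and no bias).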
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: Tensor) -> Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

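# Mamba: a stack of residual (RMSNorm -> MambaBlock) layers applied to frozen text
# embeddings, mean-pooled over the sequence and fed to a linear head over 7 emotion classes.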
class Mamba(nn.Module):
    def __init__(self, num_layers, d_input, d_model, d_state=16, d_discr=None, ker_size=4, num_classes=7, model_name='jina', pooling=None):
        super().__init__()
        mamba_par = {
            'd_input': d_input,
            'd_model': d_model,
            'd_state': d_state,
            'd_discr': d_discr,
            'ker_size': ker_size
        }
        self.model_name = model_name
        embed = Embedding(model_name, pooling)
        self.embedding = embed.get_embeddings
        self.layers = nn.ModuleList([nn.ModuleList([MambaBlock(**mamba_par), RMSNorm(d_input)]) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_input, num_classes)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.softmax = nn.Softmax(dim=1)

    def forward(self, seq, cache=None):
        # seq is a batch of raw texts; embed them, then apply the residual Mamba layers.
        seq = self.embedding(seq).to(self.device)
        for mamba, norm in self.layers:
            out, cache = mamba(norm(seq), cache)
            seq = out + seq
        return self.fc_out(seq.mean(dim=1))

    def predict(self, x):
        label_to_emotion = {
            0: 'anger',
            1: 'disgust',
            2: 'fear',
            3: 'joy/happiness',
            4: 'neutral',
            5: 'sadness',
            6: 'surprise/enthusiasm'
        }
        with torch.no_grad():
            output = self.forward(x)
            _, predictions = torch.max(output, dim=1)
        result = [label_to_emotion[i] for i in map(int, predictions)]
        return result

    def predict_proba(self, x):
        with torch.no_grad():
            output = self.forward(x)
        return self.softmax(output)

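# MambaBlock: a selective state-space block. The input is projected into two branches:
# one goes through a depthwise causal Conv1d, SiLU, and the selective SSM scan; the other
# acts as a SiLU gate. Their product is projected back to d_input.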
class MambaBlock(nn.Module):
    def __init__(self, d_input, d_model, d_state=16, d_discr=None, ker_size=4):
        super().__init__()
        d_discr = d_discr if d_discr is not None else d_model // 16
        self.in_proj = nn.Linear(d_input, 2 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_input, bias=False)
        # Input-dependent projections for the SSM parameters B, C and the step size.
        self.s_B = nn.Linear(d_model, d_state, bias=False)
        self.s_C = nn.Linear(d_model, d_state, bias=False)
        self.s_D = nn.Sequential(nn.Linear(d_model, d_discr, bias=False), nn.Linear(d_discr, d_model, bias=False))
        self.conv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=ker_size,
            padding=ker_size - 1,
            groups=d_model,
            bias=True,
        )
        self.A = nn.Parameter(torch.arange(1, d_state + 1, dtype=torch.float).repeat(d_model, 1))
        self.D = nn.Parameter(torch.ones(d_model, dtype=torch.float))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, seq, cache=None):
        _, l, _ = seq.shape
        (prev_hid, prev_inp) = cache if cache is not None else (None, None)
        a, b = self.in_proj(seq).chunk(2, dim=-1)
        x = rearrange(a, 'b l d -> b d l')
        x = x if prev_inp is None else torch.cat((prev_inp, x), dim=-1)
        a = self.conv(x)[..., :l]  # trim to the original length so the convolution stays causal
        a = rearrange(a, 'b d l -> b l d')
        a = silu(a)
        a, hid = self.ssm(a, prev_hid=prev_hid)
        b = silu(b)
        out = a * b  # gate the SSM branch with the SiLU branch
        out = self.out_proj(out)
        if cache:
            cache = (hid.squeeze(), x[..., 1:])
        return out, cache

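    # ssm: build input-dependent (selective) SSM parameters from the sequence, discretise
    # them, and run the linear recurrence h_t = A_bar_t * h_{t-1} + X_bar_t via _hid_states;
    # the output is C·h_t plus a learned skip D * x_t.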
    def ssm(self, seq, prev_hid):
        A = -self.A
        D = +self.D
        B = self.s_B(seq)
        C = self.s_C(seq)
        s = softplus(D + self.s_D(seq))  # input-dependent step size
        A_bar = einsum(torch.exp(A), s, 'd s, b l d -> b l d s')
        B_bar = einsum(B, s, 'b l s, b l d -> b l d s')
        X_bar = einsum(B_bar, seq, 'b l d s, b l d -> b l d s')
        hid = self._hid_states(A_bar, X_bar, prev_hid=prev_hid)
        out = einsum(hid, C, 'b l d s, b l s -> b l d')
        out = out + D * seq
        return out, hid

    def _hid_states(self, A, X, prev_hid=None):
        b, l, d, s = A.shape
        A = rearrange(A, 'b l d s -> l b d s')
        X = rearrange(X, 'b l d s -> l b d s')
        if prev_hid is not None:
            return rearrange(A * prev_hid + X, 'l b d s -> b l d s')
        h = torch.zeros(b, d, s, device=self.device)
        # Sequential scan over the length dimension: h <- A_t * h + X_t at each step.
        return torch.stack([h := A_t * h + X_t for A_t, X_t in zip(A, X)], dim=1)
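
For reference, a minimal usage sketch (not part of the commit). The hyperparameter values below are illustrative assumptions, and d_input has to match the hidden size of the chosen encoder (1024 for jina-embeddings-v3, 768 for xlm-roberta-base and canine-c):

# Illustrative only - depth and widths are assumptions, not values from this repository.
model = Mamba(
    num_layers=2,       # assumed depth
    d_input=1024,       # token-feature size of jina-embeddings-v3
    d_model=64,         # assumed inner SSM width
    model_name='jina',
    pooling=None,       # keep per-token features for the sequence model
)
model = model.to(model.device)  # put the Mamba layers on the same device as the embeddings
# (load trained weights here, e.g. with model.load_state_dict(...), before predicting)
model.eval()

texts = ["I can't believe this happened!", "Everything is fine."]
labels = model.predict(texts)        # list of emotion labels, one per text
probs = model.predict_proba(texts)   # (batch, 7) tensor of class probabilities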