import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.functional import silu, softplus
from einops import rearrange, einsum
from transformers import AutoTokenizer, AutoModel

class Embedding:
    # Wraps a pretrained Hugging Face encoder and turns raw texts into embedding tensors.
    def __init__(self, model_name='jina', pooling=None):
        self.model_name = model_name
        self.pooling = pooling
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if model_name == 'jina':
            self.tokenizer = AutoTokenizer.from_pretrained("jinaai/jina-embeddings-v3", code_revision='da863dd04a4e5dce6814c6625adfba87b83838aa', trust_remote_code=True)
            self.model = AutoModel.from_pretrained("jinaai/jina-embeddings-v3", code_revision='da863dd04a4e5dce6814c6625adfba87b83838aa', trust_remote_code=True).to(self.device)
        elif model_name == 'xlm-roberta-base':
            self.tokenizer = AutoTokenizer.from_pretrained('xlm-roberta-base')
            self.model = AutoModel.from_pretrained('xlm-roberta-base').to(self.device)
        elif model_name == 'canine-c':
            self.tokenizer = AutoTokenizer.from_pretrained('google/canine-c')
            self.model = AutoModel.from_pretrained('google/canine-c').to(self.device)
        else:
            raise ValueError(f'Unknown embedding model name: {model_name!r}')

    def _mean_pooling(self, X):
        # Mask-aware mean over token embeddings, followed by L2 normalization.
        def mean_pooling(model_output, attention_mask):
            token_embeddings = model_output[0]
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        encoded_input = self.tokenizer(X, padding=True, truncation=True, return_tensors='pt').to(self.device)
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings.unsqueeze(1)
    def get_embeddings(self, X):
        if self.pooling is None:
            # CANINE is character-level, so its sequences are longer. The original
            # compared against 'canine-c_emb', a name __init__ never accepts, which
            # made this branch dead code; 'canine-c' is the registered name.
            max_len = 329 if self.model_name == 'canine-c' else 95
            encoded_input = self.tokenizer(X, padding=True, truncation=True, return_tensors='pt').to(self.device)
            with torch.no_grad():
                features = self.model(**encoded_input)[0].detach().cpu().float().numpy()
            # Truncate to max_len, then zero-pad shorter batches up to max_len.
            res = np.pad(features[:, :max_len, :], ((0, 0), (0, max(0, max_len - features.shape[1])), (0, 0)), "constant")
            return torch.tensor(res)
        elif self.pooling == 'mean':
            return self._mean_pooling(X)
        else:
            raise ValueError(f'Unknown type of pooling: {self.pooling!r}')
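
# Usage sketch (illustrative, not part of the original file): with pooling=None
# the result is a padded (batch, max_len, hidden) tensor; with pooling='mean'
# it is one L2-normalized vector per text, shaped (batch, 1, hidden). For
# xlm-roberta-base the hidden size is 768.
#
#   emb = Embedding(model_name='xlm-roberta-base', pooling='mean')
#   vectors = emb.get_embeddings(["an example sentence"])  # -> (1, 1, 768)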

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-8) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: Tensor) -> Tensor:
        # y = x / sqrt(mean(x^2) + eps) * weight, normalized over the last dim.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
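
# Shape sanity check (a hedged sketch, not in the original file): RMSNorm
# normalizes over the feature dimension and preserves the input shape.
#   >>> RMSNorm(8)(torch.randn(2, 5, 8)).shape
#   torch.Size([2, 5, 8])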

class Mamba(nn.Module):
    def __init__(self, num_layers, d_input, d_model, d_state=16, d_discr=None, ker_size=4, num_classes=7, model_name='jina', pooling=None):
        super().__init__()
        mamba_par = {
            'd_input' : d_input,
            'd_model' : d_model,
            'd_state' : d_state,
            'd_discr' : d_discr,
            'ker_size': ker_size,
        }
        self.model_name = model_name
        embed = Embedding(model_name, pooling)
        self.embedding = embed.get_embeddings
        # Each layer is a (MambaBlock, RMSNorm) pair applied pre-norm with a residual.
        self.layers = nn.ModuleList([nn.ModuleList([MambaBlock(**mamba_par), RMSNorm(d_input)]) for _ in range(num_layers)])
        self.fc_out = nn.Linear(d_input, num_classes)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.softmax = nn.Softmax(dim=1)
    def forward(self, seq, cache=None):
        # seq is a list of raw strings; get_embeddings already returns a tensor,
        # so move it to the device instead of re-wrapping it with torch.tensor.
        seq = self.embedding(seq).to(self.device)
        for mamba, norm in self.layers:
            out, cache = mamba(norm(seq), cache)
            seq = out + seq
        # Mean-pool over the sequence dimension, then classify.
        return self.fc_out(seq.mean(dim=1))
    def predict(self, x):
        label_to_emotion = {
            0: 'anger',
            1: 'disgust',
            2: 'fear',
            3: 'joy/happiness',
            4: 'neutral',
            5: 'sadness',
            6: 'surprise/enthusiasm',
        }
        with torch.no_grad():
            output = self.forward(x)
            _, predictions = torch.max(output, dim=1)
        return [label_to_emotion[int(i)] for i in predictions]
    def predict_proba(self, x):
        with torch.no_grad():
            output = self.forward(x)
        return self.softmax(output)

class MambaBlock(nn.Module):
    def __init__(self, d_input, d_model, d_state=16, d_discr=None, ker_size=4):
        super().__init__()
        d_discr = d_discr if d_discr is not None else d_model // 16
        self.in_proj = nn.Linear(d_input, 2 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_input, bias=False)
        # Input-dependent projections for the state-space parameters B, C and the
        # step size Δ (a low-rank bottleneck of width d_discr for the latter).
        self.s_B = nn.Linear(d_model, d_state, bias=False)
        self.s_C = nn.Linear(d_model, d_state, bias=False)
        self.s_D = nn.Sequential(
            nn.Linear(d_model, d_discr, bias=False),
            nn.Linear(d_discr, d_model, bias=False),
        )
        # Depthwise causal convolution; the extra left padding is trimmed in forward.
        self.conv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=ker_size,
            padding=ker_size - 1,
            groups=d_model,
            bias=True,
        )
        # S4D-style initialization: every channel row holds [1, 2, ..., d_state].
        self.A = nn.Parameter(torch.arange(1, d_state + 1, dtype=torch.float).repeat(d_model, 1))
        self.D = nn.Parameter(torch.ones(d_model, dtype=torch.float))
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    def forward(self, seq, cache=None):
        _, l, _ = seq.shape
        prev_hid, prev_inp = cache if cache is not None else (None, None)
        # Split the up-projection into the SSM path (a) and the gate (z); the
        # original reused the name `b` here, shadowing the batch size above.
        a, z = self.in_proj(seq).chunk(2, dim=-1)
        x = rearrange(a, 'b l d -> b d l')
        x = x if prev_inp is None else torch.cat((prev_inp, x), dim=-1)
        a = self.conv(x)[..., :l]  # keep only the first l steps (causal trim)
        a = rearrange(a, 'b d l -> b l d')
        a = silu(a)
        a, hid = self.ssm(a, prev_hid=prev_hid)
        z = silu(z)
        out = a * z
        out = self.out_proj(out)
        if cache is not None:
            cache = (hid.squeeze(), x[..., 1:])
        return out, cache
    def ssm(self, seq, prev_hid):
        A = -self.A  # negated so that exp(A) lies in (0, 1)
        D = self.D
        B = self.s_B(seq)
        C = self.s_C(seq)
        # Input-dependent step size Δ (softplus keeps it positive); note that the
        # skip parameter D doubles as the bias of the Δ projection here.
        s = softplus(D + self.s_D(seq))
        A_bar = einsum(torch.exp(A), s, 'd s, b l d -> b l d s')
        B_bar = einsum(B, s, 'b l s, b l d -> b l d s')
        X_bar = einsum(B_bar, seq, 'b l d s, b l d -> b l d s')
        hid = self._hid_states(A_bar, X_bar, prev_hid=prev_hid)
        out = einsum(hid, C, 'b l d s, b l s -> b l d')
        out = out + D * seq  # D also acts as a learned skip connection
        return out, hid
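
    # Recurrence implemented by ssm/_hid_states (descriptive note, not part of
    # the original file):
    #   h_t = A_bar_t * h_{t-1} + B_bar_t * x_t
    #   y_t = sum_s C_t[s] * h_t[:, s] + D * x_t
    # with A_bar_t = Δ_t * exp(-A) and B_bar_t = Δ_t * B_t, broadcast per
    # channel d and state s.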
    def _hid_states(self, A, X, prev_hid=None):
        b, l, d, s = A.shape
        A = rearrange(A, 'b l d s -> l b d s')
        X = rearrange(X, 'b l d s -> l b d s')
        if prev_hid is not None:
            # Single-step update when a cached hidden state is provided.
            return rearrange(A * prev_hid + X, 'l b d s -> b l d s')
        # Sequential scan over time: h_t = A_t * h_{t-1} + X_t; the walrus
        # operator keeps the running state, and stacking on dim=1 restores
        # the (b, l, d, s) layout.
        h = torch.zeros(b, d, s, device=self.device)
        return torch.stack([h := A_t * h + X_t for A_t, X_t in zip(A, X)], dim=1)
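

if __name__ == "__main__":
    # Minimal smoke-test sketch (an assumption, not shipped with this file):
    # build a small classifier on top of xlm-roberta-base (hidden size 768)
    # and run inference on raw strings. With untrained fc_out weights the
    # labels are meaningless; this only checks that shapes and devices line up.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Mamba(
        num_layers=2,
        d_input=768,   # must match the embedding model's hidden size
        d_model=128,
        model_name='xlm-roberta-base',
    ).to(device)
    texts = ["what a wonderful day", "this is terrifying"]
    print(model.predict(texts))        # e.g. ['joy/happiness', 'fear'] once trained
    print(model.predict_proba(texts))  # (2, 7) tensor of class probabilities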