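# Inference-only LLaMA decoder: token embedding, rotary-embedding attention
# with a per-layer KV cache, gated (SwiGLU) feed-forward blocks, RMSNorm, and
# optionally bitsandbytes LLM.int8() linear layers.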
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import bitsandbytes as bnb

from models.norm import RMSNorm
from models.rope import precompute_freqs_cis, apply_rotary_emb

class NormalLinear(nn.Linear):
    def reset_parameters(self) -> None:
        # Weights are filled from a pretrained checkpoint, so skip random init.
        pass


class BnbInt8Linear(bnb.nn.Linear8bitLt):
    def __init__(self, *args, **kwargs):
        super().__init__(has_fp16_weights=False, threshold=6.0, *args, **kwargs)

    def reset_parameters(self) -> None:
        # Weights are filled from a pretrained checkpoint, so skip random init.
        pass


def get_linear_layer(use_int8):
    # Select the projection layer class: bitsandbytes LLM.int8() or plain nn.Linear.
    if use_int8:
        return BnbInt8Linear
    return NormalLinear

class WordEmbedding(nn.Module):
    def __init__(self, args):
        super(WordEmbedding, self).__init__()
        self.embedding = nn.Embedding(args.vocab_size, args.emb_size)

    def forward(self, src):
        emb = self.embedding(src)
        return emb

class MultiHeadedAttention(nn.Module):
    def __init__(self, args, hidden_size, heads_num, attention_head_size, has_bias=True, use_int8=True):
        super(MultiHeadedAttention, self).__init__()
        self.heads_num = heads_num
        self.per_head_size = attention_head_size
        self.inner_hidden_size = heads_num * attention_head_size

        Linear = get_linear_layer(use_int8)
        self.linear_layers = nn.ModuleList(
            [Linear(hidden_size, self.inner_hidden_size, bias=has_bias) for _ in range(3)]
        )
        self.final_linear = Linear(self.inner_hidden_size, hidden_size, bias=has_bias)

        # KV cache: keep keys/values of earlier positions so incremental
        # decoding does not recompute them.
        self.cache_k = torch.zeros(
            (args.batch_size, args.seq_length, self.heads_num, self.per_head_size)
        )
        self.cache_v = torch.zeros(
            (args.batch_size, args.seq_length, self.heads_num, self.per_head_size)
        )
    def forward(self, key, value, query, start_pos, continue_exsample, mask, freqs_cis):
        batch_size, seq_length, _ = query.size()
        heads_num = self.heads_num
        per_head_size = self.per_head_size

        # Project to multi-head Q/K/V and apply rotary position embeddings.
        query, key, value = [
            l(x).view(batch_size, -1, heads_num, per_head_size)
            for l, x in zip(self.linear_layers, (query, key, value))
        ]
        query, key = apply_rotary_emb(query, key, freqs_cis=freqs_cis)

        # Write the new positions into the cache and read back the full prefix
        # for the batch rows that are still being generated (continue_exsample).
        if self.cache_k.device != key.device:
            self.cache_k = self.cache_k.to(key)
        if self.cache_v.device != value.device:
            self.cache_v = self.cache_v.to(value)
        self.cache_k[continue_exsample, start_pos: start_pos + seq_length] = key
        self.cache_v[continue_exsample, start_pos: start_pos + seq_length] = value
        key = self.cache_k[continue_exsample, : start_pos + seq_length]
        value = self.cache_v[continue_exsample, : start_pos + seq_length]

        # Scaled dot-product attention with an additive causal mask.
        query, key, value = [x.transpose(1, 2) for x in (query, key, value)]
        scores = torch.matmul(query, key.transpose(-2, -1))
        scores = scores / math.sqrt(float(per_head_size))
        if mask is not None:
            scores += mask
        # probs = nn.Softmax(dim=-1)(scores)
        probs = F.softmax(scores.float(), dim=-1).type_as(query)
        output = (
            torch.matmul(probs, value).transpose(1, 2)
            .contiguous().view(batch_size, seq_length, -1)
        )
        return self.final_linear(output)

class GatedFeedForward(nn.Module):
    def __init__(self, hidden_size, feedforward_size, has_bias=True, use_int8=True):
        super(GatedFeedForward, self).__init__()
        Linear = get_linear_layer(use_int8)
        self.linear_gate = Linear(hidden_size, feedforward_size, bias=has_bias)
        self.linear_1 = Linear(hidden_size, feedforward_size, bias=has_bias)
        self.linear_2 = Linear(feedforward_size, hidden_size, bias=has_bias)
        self.act = F.silu

    def forward(self, x):
        # SwiGLU: silu(W_gate x) * (W_1 x), then project back down with W_2.
        # gate = self.act(self.linear_gate(x))
        gate = self.act(self.linear_gate(x)).type_as(x)
        inter_linear = self.linear_1(x)
        inter = gate * inter_linear
        output = self.linear_2(inter)
        return output

class TransformerLayer(nn.Module):
    def __init__(self, args):
        super(TransformerLayer, self).__init__()
        if hasattr(args, "attention_head_size"):
            attention_head_size = args.attention_head_size
        else:
            attention_head_size = args.hidden_size // args.heads_num
        has_bias = not args.remove_transformer_bias

        # Multi-head attention
        self.self_attn = MultiHeadedAttention(
            args, args.hidden_size, args.heads_num, attention_head_size,
            has_bias=has_bias, use_int8=args.use_int8
        )
        # Gated feed-forward network
        self.feed_forward = GatedFeedForward(
            args.hidden_size, args.feedforward_size, has_bias, use_int8=args.use_int8
        )
        self.layer_norm_1 = RMSNorm(args.hidden_size)
        self.layer_norm_2 = RMSNorm(args.hidden_size)

    def forward(self, hidden, start_pos, continue_exsample, mask, freqs_cis=None):
        # Pre-norm residual blocks: x + Attn(Norm(x)), then x + FFN(Norm(x)).
        inter = self.layer_norm_1(hidden)
        inter = self.self_attn(inter, inter, inter, start_pos, continue_exsample, mask, freqs_cis)
        hidden = hidden + inter
        output = self.layer_norm_2(hidden)
        output = self.feed_forward(output) + hidden
        return output

class TransformerEncoder(nn.Module):
    def __init__(self, args):
        super(TransformerEncoder, self).__init__()
        self.mask = args.mask
        self.layers_num = args.layers_num
        self.transformer = nn.ModuleList(
            [TransformerLayer(args) for _ in range(self.layers_num)]
        )
        self.layer_norm = RMSNorm(args.hidden_size)
        # Rotary frequencies are precomputed once, up to 2x the max sequence length.
        self.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2)

    def forward(self, emb, start_pos, continue_exsample):
        batch_size, seq_length, _ = emb.size()

        # Additive causal mask (0 on/below the diagonal, -10000 above); only
        # needed when more than one new position is scored at a time.
        mask = None
        if seq_length > 1:
            mask = torch.ones(seq_length, seq_length, device=emb.device)
            mask = torch.tril(mask)
            mask = (1.0 - mask) * -10000.0
            mask = mask.repeat(batch_size, 1, 1, 1)

        hidden = emb
        freqs_cis = self.freqs_cis[start_pos: start_pos + seq_length].to(hidden.device)
        for i in range(self.layers_num):
            hidden = self.transformer[i](hidden, start_pos, continue_exsample, mask, freqs_cis=freqs_cis)
        return self.layer_norm(hidden)

class LmOutput(nn.Module):
    def __init__(self, args):
        super(LmOutput, self).__init__()
        # The LM output head is kept in full precision (no int8).
        Linear = get_linear_layer(False)
        self.lm = Linear(args.hidden_size, args.vocab_size, bias=False)

    def forward(self, x):
        # Only the last position is needed to predict the next token.
        return self.lm(x[:, -1, :])

class LLaMa(nn.Module):
    def __init__(self, args):
        super(LLaMa, self).__init__()
        self.embedding = WordEmbedding(args)
        self.encoder = TransformerEncoder(args)
        self.target = LmOutput(args)

    # @torch.inference_mode()
    def forward(self, src, start_pos, continue_exsample):
        emb = self.embedding(src)
        output = self.encoder(emb, start_pos, continue_exsample)
        output = self.target(output)
        return output
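

# Minimal usage sketch (illustrative only). It assumes the sibling
# models.norm / models.rope modules provide the standard LLaMA RMSNorm and
# rotary-embedding helpers; the SimpleNamespace fields mirror the attributes
# this module reads from `args`, but the values below are made up, int8 is
# disabled so bitsandbytes/CUDA kernels are not exercised, and real use would
# load pretrained weights first (reset_parameters is a deliberate no-op).
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        vocab_size=32000, emb_size=512, hidden_size=512, heads_num=8,
        feedforward_size=1376, layers_num=2, max_seq_length=128,
        batch_size=1, seq_length=128, mask="causal",
        remove_transformer_bias=True, use_int8=False,
    )
    model = LLaMa(args).eval()
    prompt = torch.randint(0, args.vocab_size, (1, 8))
    with torch.no_grad():
        # Prefill the cache with the whole prompt for batch row 0 and get the
        # logits used to sample the next token.
        logits = model(prompt, start_pos=0, continue_exsample=[0])
    print(logits.shape)  # torch.Size([1, 32000])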