#!/usr/bin/env python
# coding=utf-8
# @Author: jiangpeijie.jpj
# @Date: Mon 4 Dec 2023 05:21:28 PM CST
import logging
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
import torch.distributed as dist
from numpy import random
from torch import nn
from torch.nn import CrossEntropyLoss
from whisper.model import AudioEncoder
from transformers.activations import ACT2CLS, ClassInstantier
try:
from atorch.distributed.distributed import parallel_group, parallel_group_size
except Exception:
parallel_group = None
parallel_group_size = None
# ## Activations
class SwiGLUActivation(nn.Module):
    def forward(self, input):
        # Split the last dimension into two halves and gate one with SiLU of the other.
        input = torch.chunk(input, 2, dim=-1)
        return F.silu(input[0]) * input[1]
ACT2CLS["swiglu"] = SwiGLUActivation
ACT2FN = ClassInstantier(ACT2CLS)
def get_activation(activation_string):
if activation_string in ACT2FN:
return ACT2FN[activation_string]
else:
raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")
# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
swiglu = get_activation("swiglu")
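# A minimal usage sketch (illustrative, not part of the original module): SwiGLU
# chunks the last dimension in half and gates one half with the other, so the
# output width is half the input width. A feed-forward layer using it should
# therefore project to twice the intended hidden size first.
def _demo_swiglu_shapes():
    x = torch.randn(2, 10, 8)
    out = swiglu(x)
    assert out.shape == (2, 10, 4)  # last dim halved by the chunk-and-gate
    return out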
# Rotary Position Embedding Utils
def find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base))
def find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
"""Find dim range bounds based on rotations"""
low = math.floor(find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(find_correction_dim(high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def linear_ramp_mask(min_val, max_val, dim):
    if min_val == max_val:
        max_val += 0.001  # Prevent singularity
    linear_func = (torch.arange(dim, dtype=torch.float32) - min_val) / (max_val - min_val)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func
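# Illustrative sketch of the YaRN helpers above: for a 128-dim rotary head,
# find_correction_range() picks the band of frequency dims to blend over
# (from beta_fast=32 rotations down to beta_slow=1), and linear_ramp_mask()
# ramps smoothly between interpolation and extrapolation across that band.
def _demo_yarn_ramp():
    low, high = find_correction_range(32, 1, dim=128)
    ramp = linear_ramp_mask(low, high, 128 // 2)
    assert 0 <= low <= high <= 127
    assert ramp.min() >= 0.0 and ramp.max() <= 1.0
    return ramp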
def rotate_half(x):
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
# dim=-1 triggers a bug in earlier torch versions
return torch.cat((-x2, x1), dim=x1.ndim - 1)
# The @torch.jit.script decorator below is left commented out for numerically accurate computation.
# @torch.jit.script
def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
# position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), F.embedding(
position_id, sin.squeeze(1)
).unsqueeze(2)
q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
return q, k
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
super().__init__()
self.dim = dim
self.base = base
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Keep inv_freq in float precision to avoid bf16 precision loss
        # inv_freq = inv_freq.to(precision)
self.learnable = learnable
if learnable:
self.inv_freq = torch.nn.Parameter(inv_freq)
self.max_seq_len_cached = None
else:
self.register_buffer('inv_freq', inv_freq)
self.max_seq_len_cached = None
self.cos_cached = None
self.sin_cached = None
self.precision = precision
    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # Intentionally a no-op: inv_freq (and the cos/sin caches) are recomputed at
        # runtime rather than loaded from checkpoints.
        pass
def forward(self, x, seq_dim=1, seq_len=None):
if seq_len is None:
seq_len = x.shape[seq_dim]
if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
self.max_seq_len_cached = None if self.learnable else seq_len
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
# freqs = torch.einsum('i,j->ij', t, inv_freq.to(x.device))
freqs = torch.outer(t, inv_freq.to(x.device))
assert freqs.dtype == torch.float32
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
if self.precision == torch.bfloat16:
emb = emb.float()
# [sx, 1 (b * np), hn]
cos_cached = emb.cos()[:, None, :]
sin_cached = emb.sin()[:, None, :]
if self.precision == torch.bfloat16:
cos_cached = cos_cached.bfloat16()
sin_cached = sin_cached.bfloat16()
if self.learnable:
return cos_cached, sin_cached
self.cos_cached, self.sin_cached = cos_cached, sin_cached
return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
def _apply(self, fn, *args, **kwargs):
if self.cos_cached is not None:
self.cos_cached = fn(self.cos_cached)
if self.sin_cached is not None:
self.sin_cached = fn(self.sin_cached)
return super()._apply(fn, *args, **kwargs)
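# Illustrative sketch of the rotary embedding path: build the cos/sin cache with
# RotaryEmbedding and apply it to query/key tensors laid out as
# [seq, batch, heads, head_dim], with per-token position ids of shape [seq, batch].
def _demo_rotary_embedding():
    sq, b, heads, hn = 6, 2, 4, 8
    rope = RotaryEmbedding(hn, precision=torch.float32)
    q = torch.randn(sq, b, heads, hn)
    k = torch.randn(sq, b, heads, hn)
    cos, sin = rope(q, seq_len=sq)  # each of shape [sq, 1, hn]
    position_id = torch.arange(sq).unsqueeze(1).expand(sq, b)
    q_rot, k_rot = apply_rotary_pos_emb_index(q, k, cos, sin, position_id)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    return q_rot, k_rot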
class LinearScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with linear scaling."""
def __init__(
self, dim, base=10000, precision=torch.half, learnable=False, max_embedding_length=2048, scaling_factor=1.0
):
self.scaling_factor = scaling_factor
self.max_embedding_length = max_embedding_length
super().__init__(dim, base, precision, learnable)
def forward(self, x, seq_dim=1, seq_len=None):
if seq_len is None:
seq_len = x.shape[seq_dim]
if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
self.max_seq_len_cached = None if self.learnable else seq_len
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
t = t / self.scaling_factor
# freqs = torch.einsum('i,j->ij', t, inv_freq.to(x.device))
freqs = torch.outer(t, inv_freq.to(x.device))
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
if self.precision == torch.bfloat16:
emb = emb.float()
# [sx, 1 (b * np), hn]
cos_cached = emb.cos()[:, None, :]
sin_cached = emb.sin()[:, None, :]
if self.precision == torch.bfloat16:
cos_cached = cos_cached.bfloat16()
sin_cached = sin_cached.bfloat16()
if self.learnable:
return cos_cached, sin_cached
self.cos_cached, self.sin_cached = cos_cached, sin_cached
return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
class NTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling."""
def __init__(
self, dim, base=10000, precision=torch.half, learnable=False, max_embedding_length=2048, scaling_factor=1.0
):
self.scaling_factor = scaling_factor
self.max_embedding_length = max_embedding_length
super().__init__(dim, base, precision, learnable)
def forward(self, x, seq_dim=1, seq_len=None):
if seq_len is None:
seq_len = x.shape[seq_dim]
if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
self.max_seq_len_cached = None if self.learnable else seq_len
base = self.base
if seq_len > self.max_embedding_length:
base = self.base * (
(self.scaling_factor * seq_len / self.max_embedding_length) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
# freqs = torch.einsum('i,j->ij', t, inv_freq.to(x.device))
freqs = torch.outer(t, inv_freq.to(x.device))
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
if self.precision == torch.bfloat16:
emb = emb.float()
# [sx, 1 (b * np), hn]
cos_cached = emb.cos()[:, None, :]
sin_cached = emb.sin()[:, None, :]
if self.precision == torch.bfloat16:
cos_cached = cos_cached.bfloat16()
sin_cached = sin_cached.bfloat16()
if self.learnable:
return cos_cached, sin_cached
self.cos_cached, self.sin_cached = cos_cached, sin_cached
return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
class DynamicYaRNScaledRotaryEmbedding(RotaryEmbedding):
def __init__(
self,
dim,
base=10000,
precision=torch.half,
learnable=False,
max_embedding_length=2048,
extrapolation_factor=1,
attn_factor=1,
beta_fast=32,
beta_slow=1,
):
self.max_embedding_length = max_embedding_length
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
super().__init__(dim, base, precision, learnable)
def forward(self, x, seq_dim=1, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Rebuild the cos/sin cache (re-deriving inv_freq via YaRN) whenever a longer
        # sequence than previously cached is seen.
if seq_len is None:
seq_len = x.shape[seq_dim]
if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
self.max_seq_len_cached = seq_len
if seq_len > self.max_embedding_length:
self.yarn(seq_len / self.max_embedding_length, x.device)
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
# freqs = torch.einsum('i,j->ij', t, inv_freq.to(x.device))
freqs = torch.outer(t, self.inv_freq.to(x.device))
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
if self.precision == torch.bfloat16:
emb = emb.float()
cos_cached = emb.cos()[:, None, :]
sin_cached = emb.sin()[:, None, :]
if self.precision == torch.bfloat16:
cos_cached = cos_cached.bfloat16()
sin_cached = sin_cached.bfloat16()
if self.learnable:
return cos_cached, sin_cached
self.cos_cached, self.sin_cached = cos_cached, sin_cached
return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
def yarn(self, scale, device):
pos_freqs = self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scale * pos_freqs)
low, high = find_correction_range(
self.beta_fast, self.beta_slow, self.dim, self.base, self.max_embedding_length
)
inv_freq_mask = (
1 - linear_ramp_mask(low, high, self.dim // 2).float().to(device)
) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
self.register_buffer("inv_freq", inv_freq) # Get n-d magnitude scaling corrected for interpolation
# ## LongGLM Utils
@dataclass
class LongGLMMemCache:
    """
    LongLLaMA-style memory cache.
    Args:
        key (`torch.FloatTensor` of shape `(batch_size, mem_length, head_nums, embed_size_per_head)`)
        value (`torch.FloatTensor` of shape `(batch_size, mem_length, head_nums, embed_size_per_head)`)
        masks (`torch.FloatTensor` of shape `(batch_size, 1, mem_length, 1)`)
            For masking out parts of memory
    """
key: torch.FloatTensor
value: torch.FloatTensor
masks: torch.FloatTensor
def mem_apply_update(prev_external_mem_cache: LongGLMMemCache, new_mem_content: LongGLMMemCache):
def update_one(prev, new, dim=1):
if len(prev.shape) != len(new.shape):
            raise ValueError(f"Memory cache content should be consistent in shape; got {prev.shape} and {new.shape}")
return torch.concat([prev, new], dim=dim)
insert_size = new_mem_content.key.shape[1]
assert new_mem_content.key.shape[1] == new_mem_content.value.shape[1]
if new_mem_content.masks.shape[-2] != insert_size:
raise ValueError("Inconsistent mem_length in new_mem_content")
return LongGLMMemCache(
key=update_one(prev_external_mem_cache.key, new_mem_content.key),
value=update_one(prev_external_mem_cache.value, new_mem_content.value),
masks=update_one(prev_external_mem_cache.masks, new_mem_content.masks, dim=-2),
)
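# Illustrative sketch of mem_apply_update(): keys/values are concatenated along
# the memory-length dim (dim 1) and the masks along dim -2, matching the shapes
# documented in LongGLMMemCache.
def _demo_mem_cache_update():
    b, heads, d = 2, 4, 8
    prev = LongGLMMemCache(
        key=torch.randn(b, 3, heads, d),
        value=torch.randn(b, 3, heads, d),
        masks=torch.zeros(b, 1, 3, 1),
    )
    new = LongGLMMemCache(
        key=torch.randn(b, 2, heads, d),
        value=torch.randn(b, 2, heads, d),
        masks=torch.zeros(b, 1, 2, 1),
    )
    merged = mem_apply_update(prev, new)
    assert merged.key.shape[1] == 5 and merged.masks.shape[-2] == 5
    return merged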
def generate_prompt_keypass(n_garbage: int, seed: int = None):
    """Generates a long distractor text with a pass key inserted at a random position."""
if seed is not None:
rnd_state = random.get_state()
random.seed(seed)
n_garbage_prefix = random.randint(0, n_garbage)
n_garbage_suffix = n_garbage - n_garbage_prefix
    # "Hidden in the large amount of irrelevant text below is a very important piece of
    # information. Find it and remember it; it will be used later." (kept in Chinese)
    task_description = "在下文的大量无关紧要的文字中隐藏着一个非常重要的信息,请找到并记住它们,后面将使用到这个信息。"
    # "The grass is green. The sky is blue. The sun is yellow. We go. We leave and come back again." (kept in Chinese)
    garbage = "草是绿色的。天空是蓝色的。太阳是黄色的。我们走。我们离开又回来了。"
garbage_inf = "".join([garbage] * 5000)
assert len(garbage_inf) >= n_garbage
garbage_prefix = garbage_inf[:n_garbage_prefix]
garbage_suffix = garbage_inf[:n_garbage_suffix]
pass_key = random.randint(1, 50000)
    # "Here is the important information in this text: 'The pass key is {pass_key}.
    # This is very important; remember that {pass_key} is the pass key.'" (kept in Chinese)
    information_line = (
        f"以下是本段文本的重要信息: “通行密码是'{pass_key}',这是非常重要的信息,请记住'{pass_key}'是通行密码。”"
    )
    information_line = "\n".join([information_line] * 3)
    # "What is the pass key?" (kept in Chinese)
    final_question = "请问通行密码是多少?"
lines = [
task_description,
garbage_prefix,
information_line,
garbage_suffix,
final_question,
]
if seed is not None:
random.set_state(rnd_state)
return "\n".join(lines), str(pass_key)
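# Illustrative usage of generate_prompt_keypass(): build a passkey-retrieval
# prompt with roughly 1000 characters of distractor text and a fixed seed for
# reproducibility; the returned pass key always appears inside the prompt.
def _demo_prompt_keypass():
    prompt, pass_key = generate_prompt_keypass(n_garbage=1000, seed=0)
    assert pass_key in prompt
    return prompt, pass_key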
# ## Loss Functions
def _unpack_router_logits(router_outputs):
    """
    Unpack the per-layer router tuples for balance loss calculation.
    """
total_router_logits = []
total_expert_indexes = []
for router_output in router_outputs:
if router_output[0] is not None:
router_logits, expert_indexes = router_output
total_router_logits.append(router_logits.unsqueeze(0))
total_expert_indexes.append(expert_indexes.unsqueeze(0))
# return torch.cat(total_router_logits, dim=0), torch.cat(total_expert_indexes, dim=0)
return torch.cat(total_router_logits, dim=0), total_expert_indexes
def load_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor, labels: torch.Tensor) -> float:
r"""
Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.
See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss
function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between
experts is too unbalanced.
Args:
        router_probs (`torch.Tensor`):
            Probability assigned to each expert per token. Shape: [num_layers, batch_size, sequence_length, num_experts].
        expert_indices (`torch.Tensor`):
            Indices tensor of shape [num_layers, batch_size, sequence_length] identifying the selected expert for a given token.
Returns:
The auxiliary loss.
"""
    num_layers, _, seq_len, num_experts = router_probs.shape
    new_labels = labels.clone().detach()
    # Keep only the trailing run of -100 (padding) masked: every position before the
    # last run of -100 is re-marked as valid (0) so it still contributes to the loss.
    for batch_tensor in new_labels:
        neg_mask = batch_tensor == -100
        diff_neg_ones = torch.diff(neg_mask.float())
        start_pos = torch.where(diff_neg_ones == 1.0)[0]  # positions where a run of -100 begins
        if start_pos.nelement() == 0:  # no run found; may need adjusting for the actual data
            pass
        else:
            last_start = start_pos[-1]  # start of the last run of -100
            batch_tensor[:last_start] = 0  # re-mark everything before it as valid
new_labels = new_labels.to(torch.int64)
# cast the expert indices to int64, otherwise one-hot encoding will fail
if expert_indices.dtype != torch.int64:
expert_indices = expert_indices.to(torch.int64)
if len(expert_indices.shape) == 3:
expert_indices = expert_indices.unsqueeze(3)
expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts)
# For a given token, determine if it was routed to a given expert.
expert_mask = torch.max(expert_mask, axis=-2).values
# cast to float32 otherwise mean will fail
expert_mask = expert_mask.to(torch.float32)
labels_mask = (new_labels[None, ..., None].expand_as(expert_mask) != -100).long()
# sample level balance loss
tokens_per_group_and_expert = torch.sum(expert_mask * labels_mask, dim=-2) / torch.sum(labels_mask, dim=-2)
router_prob_per_group_and_expert = torch.sum(router_probs * labels_mask, dim=-2) / torch.sum(labels_mask, dim=-2)
return torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert) * (num_experts**2)
'''
# batch level balance loss
expert_mask = expert_mask.view(num_layers, -1, num_experts).detach()
labels_mask = labels_mask.view(num_layers, -1, num_experts).detach()
origin_mask = labels_mask.clone()
router_probs = router_probs.view(num_layers, -1, num_experts)
from antllm.utils import mpu
torch.distributed.all_reduce(expert_mask, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(labels_mask, group=mpu.get_data_parallel_group())
labels_mask = labels_mask.bool().long()
world_size = torch.distributed.get_world_size()
tokens_per_group_and_expert = (
torch.sum(expert_mask * labels_mask, dim=-2) / torch.sum(labels_mask, dim=-2) / world_size
)
router_prob_per_group_and_expert = torch.sum(router_probs * origin_mask, dim=-2) / torch.sum(origin_mask, dim=-2)
layer_loss = tokens_per_group_and_expert * router_prob_per_group_and_expert
loss = layer_loss.sum(-1).mean() * num_experts
return loss
'''
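# Illustrative sketch of the balance loss on random router outputs: with top-1
# routing over 4 experts, a perfectly balanced router scores about 1.0 (the
# num_experts**2 factor cancels the 1/num_experts**2 of the uniform case), and
# more skewed routing scores higher.
def _demo_balance_loss():
    num_layers, bsz, seq_len, num_experts = 1, 2, 8, 4
    logits = torch.randn(num_layers, bsz, seq_len, num_experts)
    probs = torch.softmax(logits, dim=-1)
    expert_index = probs.argmax(dim=-1)  # [num_layers, bsz, seq_len]
    labels = torch.ones(bsz, seq_len, dtype=torch.long)
    labels[:, -2:] = -100  # trailing padding stays masked out
    loss = load_balancing_loss_func(probs, expert_index, labels)
    assert loss.ndim == 0
    return loss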
def group_level_device_balancing_loss_func(router_probs: torch.Tensor, expert_indices: torch.Tensor) -> float:
    r"""
    Computes a group-level (device-level) auxiliary load balancing loss, following the
    Switch Transformer formulation (https://arxiv.org/abs/2101.03961, equations (4) - (6)),
    aggregated over the experts hosted on each device of the expert-parallel group.
    Args:
        router_probs (`torch.Tensor`):
            Probability assigned to each expert per token. Shape: [num_layers, batch_size, sequence_length, num_experts].
        expert_indices (`torch.Tensor`):
            Indices tensor of shape [num_layers, batch_size, sequence_length] identifying the selected expert for a given token.
    Returns:
        The auxiliary loss.
    """
assert parallel_group is not None and parallel_group_size is not None
num_layers, _, seq_len, num_experts = router_probs.shape
# cast the expert indices to int64, otherwise one-hot encoding will fail
if expert_indices.dtype != torch.int64:
expert_indices = expert_indices.to(torch.int64)
if len(expert_indices.shape) == 3:
expert_indices = expert_indices.unsqueeze(3)
expert_mask = torch.nn.functional.one_hot(expert_indices, num_experts)
# For a given token, determine if it was routed to a given expert.
expert_mask = torch.max(expert_mask, axis=-2).values
# cast to float32 otherwise mean will fail
expert_mask = expert_mask.to(torch.float32)
torch.distributed.all_reduce(expert_mask, group=parallel_group("expert"))
expert_parallel_size = parallel_group_size("expert")
num_experts_per_device = num_experts / expert_parallel_size
# sample level balance loss
expert_mask = torch.sum(
torch.cat(torch.chunk(expert_mask.unsqueeze(-2), expert_parallel_size, dim=-1), dim=-2), dim=-1
)
tokens_per_group_and_device = torch.mean(expert_mask, axis=-2) / expert_parallel_size
router_probs = torch.sum(
torch.cat(torch.chunk(router_probs.unsqueeze(-2), expert_parallel_size, dim=-1), dim=-2), dim=-1
)
router_prob_per_group_and_device = torch.mean(router_probs, axis=-2)
device_loss = tokens_per_group_and_device * router_prob_per_group_and_device * expert_parallel_size
loss = device_loss.sum(-1).mean()
return loss
def router_z_loss_func(router_logits: torch.Tensor, labels: torch.Tensor) -> float:
r"""
Compute the router z-loss implemented in PyTorch.
The router z-loss was introduced in [Designing Effective Sparse Expert Models](https://arxiv.org/abs/2202.08906).
It encourages router logits to remain small in an effort to improve stability.
Args:
        router_logits (`torch.Tensor`):
Input logits of shape [num_layers, batch_size, sequence_length, num_experts]
Returns:
Scalar router z-loss.
"""
num_layers, num_groups, tokens_per_group, _ = router_logits.shape
labels_mask = (labels[None, ..., None].expand_as(router_logits) != -100).long()
ori_dtype = router_logits.dtype
if ori_dtype == torch.bfloat16:
loss_func_inputs = (router_logits * labels_mask).to(torch.float32)
else:
loss_func_inputs = router_logits * labels_mask
log_z = torch.logsumexp(loss_func_inputs, dim=-1).to(ori_dtype)
z_loss = log_z**2
# log_z = torch.logsumexp(router_logits * labels_mask, dim=-1)
# z_loss = log_z**2
return torch.sum(z_loss) / (num_layers * num_groups * tokens_per_group)
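# Illustrative sketch of the z-loss: it is the average squared logsumexp of the
# router logits, with label-masked positions zeroed out before the logsumexp.
def _demo_router_z_loss():
    num_layers, bsz, seq_len, num_experts = 2, 2, 6, 4
    router_logits = torch.randn(num_layers, bsz, seq_len, num_experts)
    labels = torch.ones(bsz, seq_len, dtype=torch.long)
    z_loss = router_z_loss_func(router_logits, labels)
    assert z_loss.ndim == 0 and z_loss.item() >= 0.0
    return z_loss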
def auxiliary_loss(outputs, labels):
router_tuple = outputs.router_tuple
balance_loss, z_loss, last_logits_l2_loss = 0.0, 0.0, 0.0
loss = 0
if router_tuple is not None:
router_logits, layer_router_index = _unpack_router_logits(router_tuple)
top1_expert_index = torch.cat(layer_router_index, dim=0)
outputs["layer_expert_index"] = top1_expert_index
z_loss = router_z_loss_func(router_logits, labels)
router_probs = torch.nn.Softmax(dim=-1)(router_logits)
balance_loss = load_balancing_loss_func(router_probs, top1_expert_index, labels)
num_layers = router_probs.shape[0]
num_experts = router_probs.shape[-1]
router_probs_log = router_probs.detach().view(num_layers, -1, num_experts)
router_probs_mean = router_probs_log.mean(1)
router_probs_sort_mean = router_probs_log.sort(-1, descending=True)[0].mean(1)
router_probs_log = torch.stack([router_probs_mean, router_probs_sort_mean], dim=1)
dist.all_reduce(router_probs_log, dist.ReduceOp.SUM)
router_probs_log = router_probs_log / torch.distributed.get_world_size()
if dist.get_rank() == 0:
router_probs_log = router_probs_log.float()
router_probs_log /= router_probs_log.sum(-1, keepdim=True)
outputs["layer_expert_probs"] = router_probs_log.float().cpu()
group_balance_loss = 0
if float(outputs["router_group_balance_loss_alpha"]) > 0:
group_balance_loss = group_level_device_balancing_loss_func(router_probs, top1_expert_index)
loss = (
float(outputs["router_z_loss_alpha"]) * z_loss
+ float(outputs["router_balance_loss_alpha"]) * balance_loss
+ float(outputs["router_group_balance_loss_alpha"]) * group_balance_loss
)
last_logits_l2_loss = 0.0
if float(outputs["last_logits_l2_alpha"]) >= 0:
logits = outputs.logits.view(-1, outputs.logits.size(-1))
labels_mask = (labels.view(-1) != -100).long()
last_logits_l2_loss = torch.sum(torch.linalg.norm(logits.float(), 2.0, dim=-1) * labels_mask) / torch.sum(
labels_mask
)
loss += float(outputs["last_logits_l2_alpha"]) * last_logits_l2_loss
last_logits_l2_loss = last_logits_l2_loss.item()
return loss, balance_loss, z_loss, last_logits_l2_loss
def expert_balanced_auxiliary_cross_entropy(outputs, labels, *args, **kwargs):
"""FOR PRETRAIN ONLY"""
    # Output per-token losses without reduction, for computing per-dataset loss.
if kwargs.get("output_losses", False):
lm_loss, losses = cross_entropy_loss(outputs.logits, labels, *args, **kwargs)
else:
lm_loss = cross_entropy_loss(outputs.logits, labels, *args, **kwargs)
aux_loss, balance_loss, z_loss, last_logits_l2_loss = auxiliary_loss(outputs, labels)
loss = lm_loss + aux_loss
if kwargs.get("output_losses", False):
return loss, lm_loss, balance_loss, z_loss, last_logits_l2_loss, losses
return loss, lm_loss, balance_loss, z_loss, last_logits_l2_loss
def expert_balanced_auxiliary_cross_entropy_for_sft(outputs, labels, *args, **kwargs):
"""FOR SFT ONLY"""
lm_loss = sample_level_cross_entropy(outputs, labels, **kwargs)
aux_loss, balance_loss, z_loss, last_logits_l2_loss = auxiliary_loss(outputs, labels)
loss = lm_loss + aux_loss
return loss
def expert_balanced_auxiliary_global_level_cross_entropy(outputs, labels, *args, **kwargs):
"""FOR SFT ONLY"""
lm_loss = global_token_level_cross_entropy(outputs, labels, **kwargs)
aux_loss, balance_loss, z_loss, last_logits_l2_loss = auxiliary_loss(outputs, labels)
loss = lm_loss + aux_loss
return [
loss,
{
'aux_loss': aux_loss,
'balance_loss': balance_loss,
'z_loss': z_loss,
'last_logits_l2_loss': last_logits_l2_loss,
},
]
def cross_entropy_loss(logits, labels, loss_mask, *args, **kwargs):
if kwargs["use_atorch_cross_entropy"]:
from atorch.modules.transformer import losses as atorch_loss
losses = atorch_loss.CrossEntropyLoss(reduction="none")(logits.view(-1, logits.size(-1)), labels.view(-1))
else:
losses = torch.nn.CrossEntropyLoss(reduction="none")(logits.view(-1, logits.size(-1)), labels.view(-1))
loss = torch.sum(losses * loss_mask.view(-1))
if loss_mask.sum().item() > 0:
loss = loss / loss_mask.sum()
if kwargs.get("output_losses", False):
return loss, losses
return loss
def local_token_level_cross_entropy(outputs, labels, **kwargs):
# return outputs.loss / torch.distributed.get_world_size()
    # Token-level average within each batch, then average across batches.
# return outputs.loss
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(outputs.logits.contiguous().view(-1, outputs.logits.size(-1)), labels.contiguous().view(-1))
return loss
def mini_batch_token_level_cross_entropy(outputs, labels, mini_batch=1, **kwargs):
    # Split the batch into mini-batches, take a token-level average within each
    # mini-batch, then average across all devices.
loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
if labels.shape[0] % mini_batch != 0:
        # If batch % mini_batch != 0, compute without splitting. This can happen in
        # the last batch of an epoch for some dataset sizes.
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(outputs.logits.contiguous().view(-1, outputs.logits.size(-1)), labels.contiguous().view(-1))
else:
loss = loss_fct(
outputs.logits.contiguous().view(-1, outputs.logits.size(-1)), labels.contiguous().view(-1)
).reshape(labels.shape[0] // mini_batch, -1)
labels = labels.reshape(labels.shape[0] // mini_batch, -1)
loss = loss.sum(-1) / (labels != -100).sum(-1)
loss = loss.mean()
return loss
def sample_level_cross_entropy(outputs, labels, **kwargs):
    # Token-level average within each sample first, then average across samples.
loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
loss = loss_fct(
outputs.logits.contiguous().view(-1, outputs.logits.size(-1)), labels.contiguous().view(-1)
).reshape(labels.shape[0], -1)
loss = loss.sum(-1) / (labels != -100).sum(-1)
loss = loss.mean()
return loss
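# Illustrative comparison of the averaging schemes above: sample-level CE first
# averages within each sample and then across samples, while local token-level CE
# averages over all valid tokens at once; the two differ when samples have unequal
# numbers of valid tokens. SimpleNamespace stands in for a model output object
# exposing `.logits`.
def _demo_sample_vs_token_level_ce():
    from types import SimpleNamespace
    bsz, seq_len, vocab = 2, 5, 11
    outputs = SimpleNamespace(logits=torch.randn(bsz, seq_len, vocab))
    labels = torch.randint(0, vocab, (bsz, seq_len))
    labels[0, -3:] = -100  # sample 0 has fewer valid tokens than sample 1
    sample_loss = sample_level_cross_entropy(outputs, labels)
    token_loss = local_token_level_cross_entropy(outputs, labels)
    return sample_loss, token_loss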
def global_token_level_cross_entropy(outputs, labels, **kwargs):
    # Token-level average over all samples jointly.
loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
loss = loss_fct(
outputs.logits.contiguous().view(-1, outputs.logits.size(-1)), labels.contiguous().view(-1)
).reshape(labels.shape[0], -1)
num_tokens = (loss != 0).sum()
loss = loss.sum()
num_tokens_tensor = torch.zeros([1], device=loss.device, dtype=loss.dtype)
num_tokens_tensor[0] = num_tokens.item()
torch.distributed.all_reduce(num_tokens_tensor)
global_num_tokens = num_tokens_tensor.sum()
torch.distributed.barrier()
    # global_num_tokens is the global token count; since gradients are mean-reduced
    # across all devices during the update, we multiply by world_size here.
    loss = loss / global_num_tokens * torch.distributed.get_world_size()
return loss
LOSS_MAP = {
'local_token_level_cross_entropy': local_token_level_cross_entropy,
'mini_batch_token_level_cross_entropy': mini_batch_token_level_cross_entropy,
'sample_level_cross_entropy': sample_level_cross_entropy,
'global_token_level_cross_entropy': global_token_level_cross_entropy,
"moe_auxiliary": expert_balanced_auxiliary_cross_entropy,
"moe_auxiliary_sft": expert_balanced_auxiliary_cross_entropy_for_sft,
"pretrain_default": cross_entropy_loss,
"moe_auxiliary_global_token_level": expert_balanced_auxiliary_global_level_cross_entropy,
}
class Transpose(nn.Module):
def __init__(self, dim0: int, dim1: int):
super().__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
return x.transpose(self.dim0, self.dim1)
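# Illustrative use of Transpose: it lets dimension swaps live inside an
# nn.Sequential, e.g. wrapping a Conv1d that expects channels-first input while
# the surrounding model works with [B, T, D] tensors.
def _demo_transpose_module():
    layer = nn.Sequential(
        Transpose(1, 2),  # [B, T, D] -> [B, D, T] for Conv1d
        nn.Conv1d(8, 8, kernel_size=3, padding=1),
        Transpose(1, 2),  # back to [B, T, D]
    )
    out = layer(torch.randn(2, 10, 8))
    assert out.shape == (2, 10, 8)
    return out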
def unwrap_feats(feats: torch.Tensor, feats_lengths: torch.Tensor):
"""
The input feats are in the "wrapped" format, which means that features from (at most) N audios are concatenated
as a single sample feats[i]. In this case, each row of feats_lengths contains the lengths of the concatenated
feature. This function unwraps the features.
    For samples with fewer than N segments, pad feats_lengths with 0. The result will contain valid
    segments only.
feats: torch.Tensor, size = [B, L1 + L2 + ... + LN, ...]
feats_lengths: torch.LongTensor, size = [B, N]
Example ('X' for padding):
Inputs:
feats = [[A, A, A, A, X],
[B, B, C, C, C]]
feats_lengths = [[4, 0],
[2, 3]]
Outputs:
feat_segs = [[A, A, A, A],
[B, B, X, X],
[C, C, C, X]]
feat_seg_lengths = [4, 2, 3]
"""
feat_segs = []
feat_seg_lengths = []
for i in range(feats_lengths.shape[0]):
feat_index = 0
for j in range(feats_lengths.shape[1]):
feat_len = feats_lengths[i, j].item()
            if feat_len == 0:
                break
feat_segs.append(feats[i, feat_index:feat_index + feat_len])
feat_seg_lengths.append(feat_len)
feat_index += feat_len
feat_segs_batch = torch.nn.utils.rnn.pad_sequence(feat_segs, True).to(feats.device)
feat_seg_lengths = torch.tensor(feat_seg_lengths, dtype=torch.long, device=feats.device)
return feat_segs_batch, feat_seg_lengths
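# Runnable version of the example in the unwrap_feats() docstring: two wrapped
# rows yield three independent segments, padded to the longest segment.
def _demo_unwrap_feats():
    feats = torch.tensor([[1.0, 2.0, 3.0, 4.0, 0.0],
                          [5.0, 6.0, 7.0, 8.0, 9.0]])
    feats_lengths = torch.tensor([[4, 0], [2, 3]])
    feat_segs, feat_seg_lengths = unwrap_feats(feats, feats_lengths)
    assert feat_segs.shape == (3, 4)
    assert feat_seg_lengths.tolist() == [4, 2, 3]
    return feat_segs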
def wrap_feats(feat_segs: torch.Tensor, feats_lengths: torch.Tensor, feats_seg_lengths: Optional[torch.Tensor] = None):
"""
Wrap segmented features back to the wrapped format.
This function is the inverse operation of unwrap_feats(). See its documentation for details.
    Note that when feats_seg_lengths is given, feats_lengths only matters up to the location of the
    first 0 in each row, which determines the number of feature segments.
"""
feat_idx = 0
feats_buffer = []
feats_locs_buffer = []
feats_lengths_buffer = []
for i in range(feats_lengths.shape[0]):
feat_buffer = []
feat_locs_buffer = []
feat_lengths_buffer = []
feat_total_len = 0
for j in range(feats_lengths.shape[1]):
feat_len = feats_lengths[i, j].item()
if feat_len == 0:
break
if feats_seg_lengths is not None:
feat_len = feats_seg_lengths[feat_idx].item()
feat_buffer.append(feat_segs[feat_idx, :feat_len])
feat_locs_buffer.append(feat_total_len)
feat_lengths_buffer.append(feat_len)
feat_idx += 1
feat_total_len += feat_len
feats_buffer.append(torch.cat(feat_buffer))
feats_locs_buffer.append(torch.tensor(feat_locs_buffer, dtype=torch.long))
feats_lengths_buffer.append(torch.tensor(feat_lengths_buffer, dtype=torch.long))
feats = torch.nn.utils.rnn.pad_sequence(feats_buffer, True).to(feat_segs.device)
feats_locs = torch.nn.utils.rnn.pad_sequence(feats_locs_buffer, True).to(feats_lengths.device)
feats_new_lengths = torch.nn.utils.rnn.pad_sequence(feats_lengths_buffer, True).to(feats_lengths.device)
return feats, feats_locs, feats_new_lengths
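# Round-trip sketch: wrap_feats() inverts unwrap_feats(), recovering the wrapped
# layout along with the start offset and length of each segment per row.
def _demo_wrap_feats_roundtrip():
    feats = torch.tensor([[1.0, 2.0, 3.0, 4.0, 0.0],
                          [5.0, 6.0, 7.0, 8.0, 9.0]])
    feats_lengths = torch.tensor([[4, 0], [2, 3]])
    feat_segs, feat_seg_lengths = unwrap_feats(feats, feats_lengths)
    wrapped, locs, lengths = wrap_feats(feat_segs, feats_lengths, feat_seg_lengths)
    assert torch.equal(wrapped, feats)
    assert locs.tolist() == [[0, 0], [0, 2]]
    assert torch.equal(lengths, feats_lengths)
    return wrapped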
def encode_audio_segments(
encoder,
proj_layer,
wav_feats=None,
wav_feats_lengths=None,
waveforms=None,
waveforms_lengths=None,
use_waveform=False,
audio_config=None,
whisper_config=None,
use_whisper_encoder=False
):
"""
Apply audio encoder to input audio features in wrapped format.
See the documentation of unwrap_feats() for details about 'wrapped format'.
"""
# Forward audio encoder.
if use_waveform:
assert waveforms is not None and waveforms_lengths is not None
# Unwrap the waveforms so each waveform is placed at an independent row.
waveform_segs_batch, waveform_seg_lengths = unwrap_feats(waveforms, waveforms_lengths)
audio_feats_seg, audio_feat_seg_lengths = encoder(waveform_segs_batch, waveform_seg_lengths)[:2]
else:
assert wav_feats is not None and wav_feats_lengths is not None
# Unwrap the features so the feature of each waveform is placed at an independent row.
feat_segs_batch, feat_seg_lengths = unwrap_feats(wav_feats, wav_feats_lengths)
if use_whisper_encoder:
assert isinstance(encoder, AudioEncoder)
assert whisper_config is not None
# for whisper encoder
# feat_segs_batch: [B, T, n_mels]
# feat_seg_lengths: [B]
audio_feats_seg = encoder(feat_segs_batch)
audio_feats_seg_proj = proj_layer(audio_feats_seg.transpose(-1, -2)).transpose(-1, -2)
feat_seg_lengths = feat_seg_lengths.to(feat_segs_batch.device)
# whisper encoder conv
audio_feat_seg_lengths = (feat_seg_lengths - 3 + 2 * 1) // 2 + 1
# project layer conv
audio_feat_seg_lengths = (audio_feat_seg_lengths - whisper_config.ds_kernel_size + 2 *
(whisper_config.ds_kernel_size//2)) // whisper_config.ds_stride + 1
else:
audio_feats_seg, audio_feat_seg_lengths = encoder(feat_segs_batch, feat_seg_lengths)[:2]
audio_feats_seg_proj = proj_layer(audio_feats_seg.transpose(-1, -2)).transpose(-1, -2)
# project layer conv
audio_feat_seg_lengths = (audio_feat_seg_lengths - audio_config.ds_kernel_size + 2 * (
audio_config.ds_kernel_size // 2)) // audio_config.ds_stride + 1
# Wrap the features so the 1st dim represents batch_size.
input_lengths = waveforms_lengths if use_waveform else wav_feats_lengths
assert input_lengths is not None
audio_feats, _, audio_feats_lengths = wrap_feats(audio_feats_seg, input_lengths, audio_feat_seg_lengths)
audio_feats_proj, _, audio_feats_lengths2 = wrap_feats(audio_feats_seg_proj, input_lengths, audio_feat_seg_lengths)
assert torch.all(audio_feats_lengths == audio_feats_lengths2), f"{audio_feats_lengths}, {audio_feats_lengths2}"
return audio_feats_proj, audio_feats, audio_feats_lengths
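# Illustrative check of the conv output-length arithmetic used above: with
# kernel size k, stride s and padding k // 2, the output length is
# (L - k + 2 * (k // 2)) // s + 1, which matches nn.Conv1d exactly.
def _demo_conv_length_formula():
    k, s, length = 5, 2, 100
    out_len = (length - k + 2 * (k // 2)) // s + 1
    conv = nn.Conv1d(1, 1, kernel_size=k, stride=s, padding=k // 2)
    out = conv(torch.randn(1, 1, length))
    assert out.shape[-1] == out_len
    return out_len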
def patch_continuous_features(
input_embeddings: torch.Tensor,
placeholder_loc_lens: torch.Tensor,
encoded_feats: torch.Tensor,
encoded_feat_lens: torch.Tensor,
):
"""
Patch continuous features into input embeddings, while keeping a valid gradient flow.
    input_embeddings: torch.Tensor, size = [B, T, D]
placeholder_loc_lens: torch.LongTensor, size = [B, N, 2]
Each 2-tuple represents (start, length) of a placeholder.
encoded_feats: torch.Tensor, size = [B, L1 + L2 + ... + LN, ...]
encoded_feat_lens: torch.LongTensor, size = [B, N]
Example ('X' for patch placeholder tokens):
Inputs:
input_embeddings = [[1, 2, 3, X, X, X, 4, 5, 6, X, X, X, 7, 8]]
placeholder_loc_lens = [[3, 3], [9, 3]]
encoded_feats = [[A, A, A, B, B]]
encoded_feat_lens = [[3, 2]]
Outputs:
embeddings = [[1, 2, 3, A, A, A, 4, 5, 6, B, B, X, 7, 8]]
"""
batch_size = input_embeddings.size(0)
audio_feats_mask = torch.zeros_like(input_embeddings, dtype=torch.bool)
audio_feats_buffer = []
for i in range(batch_size):
sample_len = 0
audio_feat_start = 0
audio_feat_buffer = []
for j in range(placeholder_loc_lens.shape[1]):
placeholder_start: int = int(placeholder_loc_lens[i, j, 0].item())
placeholder_len: int = int(placeholder_loc_lens[i, j, 1].item())
if placeholder_len <= 0:
break
feat_len = int(encoded_feat_lens[i, j].item())
real_feat_len = feat_len
if feat_len > placeholder_len:
logging.warning(
f"Feature length ({feat_len}) > placeholder length ({placeholder_len}). This is not expected. Please "
"check the implementation of estimate_audio_feature_length(). We truncate the feature to avoid errors."
)
feat_len = placeholder_len
if placeholder_start > sample_len:
audio_feat_buffer.append(input_embeddings.new_zeros((placeholder_start - sample_len, input_embeddings.shape[2])))
sample_len = placeholder_start
audio_feat_buffer.append(encoded_feats[i, audio_feat_start:audio_feat_start + feat_len])
if feat_len < placeholder_len:
                audio_feat_buffer.append(encoded_feats.new_zeros((placeholder_len - feat_len, encoded_feats.shape[-1])))
audio_feats_mask[i, sample_len:sample_len + feat_len] = 1
audio_feat_start += real_feat_len
sample_len += placeholder_len
if sample_len < input_embeddings.shape[1]:
audio_feat_buffer.append(
input_embeddings.new_zeros((input_embeddings.shape[1] - sample_len, input_embeddings.shape[2]))
)
audio_feats_buffer.append(torch.cat(audio_feat_buffer))
audio_feats_buffer = torch.stack(audio_feats_buffer, dim=0)
embeddings = audio_feats_buffer * audio_feats_mask + input_embeddings * ~audio_feats_mask
return embeddings
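# Runnable version of the example in the patch_continuous_features() docstring:
# two placeholders of length 3 receive features of length 3 and 2; the unfilled
# placeholder position keeps its original embedding.
def _demo_patch_continuous_features():
    B, T, D = 1, 14, 4
    input_embeddings = torch.zeros(B, T, D)
    placeholder_loc_lens = torch.tensor([[[3, 3], [9, 3]]])
    encoded_feats = torch.ones(B, 5, D)
    encoded_feat_lens = torch.tensor([[3, 2]])
    out = patch_continuous_features(input_embeddings, placeholder_loc_lens,
                                    encoded_feats, encoded_feat_lens)
    assert torch.all(out[0, 3:6] == 1) and torch.all(out[0, 9:11] == 1)
    assert torch.all(out[0, 11] == 0)  # leftover placeholder position untouched
    return out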
def build_modality_mask(placeholder_loc_lens: torch.Tensor, shape: torch.Size):
mask = torch.zeros(shape, dtype=torch.bool)
for i in range(placeholder_loc_lens.shape[0]):
for j in range(placeholder_loc_lens.shape[1]):
start: int = int(placeholder_loc_lens[i, j, 0].item())
length: int = int(placeholder_loc_lens[i, j, 1].item())
if length <= 0:
break
mask[i, start:start + length] = True
return mask
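# Illustrative sketch of build_modality_mask(): mark the placeholder spans given
# as (start, length) pairs as True in a [B, T] boolean mask.
def _demo_build_modality_mask():
    placeholder_loc_lens = torch.tensor([[[3, 3], [9, 2]]])
    mask = build_modality_mask(placeholder_loc_lens, torch.Size([1, 14]))
    assert mask[0, 3:6].all() and mask[0, 9:11].all()
    assert mask.sum().item() == 5
    return mask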