|
|
|
|
|
|
|
|
|
|
|
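# DeepSpeed FPDT (fully pipelined distributed transformer) kernels: chunked sequence-parallel
# attention (with an optional CPU-offloading variant), a chunked FFN, and a chunked logits loss.
# Long sequences are processed one chunk at a time, exchanged between ranks with all-to-all,
# and attended with FlashAttention so that activation memory stays bounded.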
import torch |
|
|
|
from typing import Optional, Any, Tuple |
|
from torch import Tensor |
|
from packaging import version |
|
import deepspeed.comm as dist |
|
from deepspeed.accelerator import get_accelerator |
|
|
|
try: |
|
import flash_attn |
|
from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward |
|
flash_attn_version = version.parse(flash_attn.__version__) |
|
except ImportError:
    flash_attn = None
    flash_attn_version = None
    _flash_attn_forward = None
    _flash_attn_backward = None
|
|
|
from einops import rearrange |
|
from .layer import single_all_to_all, apply_rotary_pos_emb |
|
|
|
|
|
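# Adjoint of the rotate_half() operation used by rotary position embeddings:
# given an upstream gradient split into halves (g1, g2), returns (g2, -g1).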
def _rotate_half_backward(x): |
|
x = rearrange(x, '... (j d) -> ... j d', j=2) |
|
x1, x2 = x.unbind(dim=-2) |
|
return torch.cat((x2, -x1), dim=-1) |
|
|
|
|
|
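# Backward of apply_rotary_pos_emb: the rotated slice of the gradient is multiplied by cos and
# by the transposed rotation of sin; dimensions beyond rot_dim pass through unchanged.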
def apply_rotary_pos_emb_backward(grad_output, freqs_cos, freqs_sin): |
|
rot_dim = freqs_cos.shape[-1] |
|
grad, grad_pass = grad_output[..., :rot_dim], grad_output[..., rot_dim:] |
|
grad_t = (grad * freqs_cos) + (_rotate_half_backward(grad * freqs_sin)) |
|
grad = grad_t if grad_pass.shape[-1] == 0 else torch.cat((grad_t, grad_pass), dim=-1) |
|
return grad |
|
|
|
|
|
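# Merges one FlashAttention block result into the running output with the online-softmax
# recurrence:
#   new_lse = lse + log1p(exp(block_lse - lse))
#   out     = exp(lse - new_lse) * out + exp(block_lse - new_lse) * block_out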
def _update_out_and_lse( |
|
out: torch.Tensor, |
|
lse: torch.Tensor, |
|
block_out: torch.Tensor, |
|
block_lse: torch.Tensor, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
block_out = block_out.to(torch.float32) |
|
block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) |
|
|
|
new_lse = lse + torch.log1p(torch.exp(block_lse - lse)) |
|
|
|
out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out |
|
|
|
lse = new_lse |
|
return out, lse |
|
|
|
|
|
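# Accumulator wrapper around _update_out_and_lse: on the first block it initializes out/lse in
# fp32 (lse reshaped to [batch, seq, heads, 1]); later calls can optionally update only a slice.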
def update_out_and_lse( |
|
out: Optional[torch.Tensor], |
|
lse: Optional[torch.Tensor], |
|
block_out: torch.Tensor, |
|
block_lse: torch.Tensor, |
|
slice_=None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
if out is None: |
|
if slice_ is not None: |
|
raise RuntimeError("first update_out_and_lse should not pass slice_ args") |
|
out = block_out.to(torch.float32) |
|
lse = block_lse.permute(0, 2, 1).contiguous().unsqueeze(dim=-1).contiguous() |
|
elif slice_ is not None: |
|
slice_out, slice_lse = out[slice_], lse[slice_] |
|
slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse) |
|
out[slice_], lse[slice_] = slice_out, slice_lse |
|
else: |
|
out, lse = _update_out_and_lse(out, lse, block_out, block_lse) |
|
return out, lse |
|
|
|
|
|
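# Reorders the global token sequence so that each sequence-parallel rank receives an equal number
# of chunks drawn from across the sequence (balancing causal-attention work), then slices
# tokens/labels/position_ids down to this rank's local chunks; the loss mask stays global but is
# reordered consistently.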
class FPDT_InputConstruct(torch.nn.Module): |
|
|
|
def __init__(self, tokens, labels, loss_mask, attention_mask, position_ids, args, sp_size, sp_rank) -> None: |
|
|
|
super(FPDT_InputConstruct, self).__init__() |
|
self.tokens = tokens |
|
self.labels = labels |
|
self.loss_mask = loss_mask |
|
self.attention_mask = attention_mask |
|
self.position_ids = position_ids |
|
global_seq_len = tokens.shape[1] |
|
batch_size = tokens.shape[0] |
|
assert global_seq_len % sp_size == 0 |
|
assert global_seq_len % args.ds_sequence_parallel_fpdt_chunk_size == 0 |
|
num_chunk_per_gpu = global_seq_len // args.ds_sequence_parallel_fpdt_chunk_size |
|
local_seq_len = global_seq_len // sp_size |
|
assert local_seq_len % num_chunk_per_gpu == 0 |
|
|
|
self.num_chunk_per_gpu = num_chunk_per_gpu |
|
self.chunk_size = local_seq_len // num_chunk_per_gpu |
|
self.sp_size = sp_size |
|
self.sp_rank = sp_rank |
|
self.global_seq_len = global_seq_len |
|
self.local_seq_len = local_seq_len |
|
self.batch_size = batch_size |
|
self.device = tokens.device |
|
|
|
def generate(self): |
|
device = self.device |
|
        total_chunks = self.global_seq_len // self.chunk_size
        token_chunk_idx = torch.arange(self.global_seq_len, device=device, dtype=torch.int) // self.chunk_size
        chunk_to_gpu = torch.arange(total_chunks, device=device, dtype=torch.int)
|
chunk_to_gpu = chunk_to_gpu.reshape(self.num_chunk_per_gpu, -1).t().contiguous() |
|
|
|
gather_chunk = chunk_to_gpu.flatten().unsqueeze(1).contiguous() |
|
mask = gather_chunk == token_chunk_idx |
|
|
|
indices = mask.nonzero(as_tuple=False) |
|
gather_indices = indices[:, 0] |
|
token_chunk_indices = indices[:, 1] |
|
indices = torch.cat([token_chunk_indices[gather_indices == i] for i in range(gather_chunk.shape[0])]) |
|
load_balanced_loss_mask = self.loss_mask[:, indices] if self.loss_mask is not None else self.loss_mask |
|
|
|
indices = indices.reshape(-1, self.chunk_size)[self.num_chunk_per_gpu * self.sp_rank:self.num_chunk_per_gpu * |
|
(self.sp_rank + 1)].flatten().contiguous() |
|
load_balanced_tokens = self.tokens[:, indices] |
|
load_balanced_labels = self.labels[:, indices] if self.labels is not None else self.labels |
|
|
|
        load_balanced_attention_mask = self.attention_mask
|
        load_balanced_position_ids = self.position_ids[:, indices] if self.position_ids is not None else None
|
|
|
return load_balanced_tokens, load_balanced_labels, load_balanced_loss_mask, load_balanced_attention_mask, load_balanced_position_ids |
|
|
|
|
|
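# Sequence-parallel chunked attention without CPU offloading. forward() projects QKV one chunk at
# a time, all-to-alls each chunk into head-parallel layout, applies RoPE, and accumulates blockwise
# causal FlashAttention against the key/value chunks produced so far; all chunks stay resident on
# the GPU for the whole forward/backward pass.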
class _FPDTGPUAttentionImpl_(torch.autograd.Function): |
|
generate_vmap_rule = False |
|
|
|
@staticmethod |
|
def forward(ctx: Any, |
|
layernorm_output, |
|
attention_mask, |
|
inference_params, |
|
rotary_pos_emb, |
|
spg, |
|
scatter_idx, |
|
gather_idx, |
|
hidden_size, |
|
projection_size, |
|
hidden_size_per_attention_head, |
|
kv_projection_size, |
|
qkv_linear_weight, |
|
qkv_linear_bias, |
|
dropout, |
|
num_chunks=8, |
|
cpu_offloading=True): |
|
|
|
do_save = layernorm_output.requires_grad |
|
|
|
if rotary_pos_emb is not None: |
|
pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) |
|
ctx.pos_emb_cos = pos_emb_cos |
|
ctx.pos_emb_sin = pos_emb_sin |
|
else: |
|
ctx.pos_emb_cos = None |
|
ctx.pos_emb_sin = None |
|
|
|
with torch.no_grad(): |
|
per_gpu_seq_len = layernorm_output.shape[0] |
|
chunk_size = per_gpu_seq_len // num_chunks |
|
assert chunk_size * num_chunks == per_gpu_seq_len |
|
assert attention_mask is None |
|
ctx.num_chunks = num_chunks |
|
ctx.cpu_offloading = cpu_offloading |
|
ctx.spg = spg |
|
ctx.scatter_idx = scatter_idx |
|
ctx.gather_idx = gather_idx |
|
|
|
device = get_accelerator().current_device_name() |
|
ctx.device = device |
|
ctx.dtype = layernorm_output.dtype |
|
ctx.projection_size = projection_size |
|
ctx.kv_projection_size = kv_projection_size |
|
|
|
global_q = [] |
|
global_k = [] |
|
global_v = [] |
|
|
|
ctx.softmax_scale = hidden_size_per_attention_head**(-0.5) |
|
|
|
ctx.dropout_p = dropout |
|
ctx.window_size = (-1, -1) |
|
ctx.alibi_slopes = None |
|
|
|
batch_size = layernorm_output.shape[1] |
|
|
|
global_o = [None for _ in range(num_chunks)] |
|
global_lse = [None for _ in range(num_chunks)] |
|
|
|
for i in range(num_chunks): |
|
|
|
st = chunk_size * i |
|
ed = st + chunk_size |
|
|
|
qkv_chunk = torch.matmul(layernorm_output[st:ed], qkv_linear_weight.t()) + qkv_linear_bias |
|
|
|
q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape( |
|
qkv_chunk.shape[0], qkv_chunk.shape[1], -1, |
|
hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() |
|
q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, 0, spg) |
|
global_q_chunk_len = q_chunk.shape[1] |
|
if rotary_pos_emb is not None: |
|
q_chunk = apply_rotary_pos_emb(q_chunk, |
|
pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], |
|
pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) |
|
global_q.append(q_chunk) |
|
|
|
k_chunk = qkv_chunk[:, :, projection_size:projection_size + kv_projection_size].contiguous().reshape( |
|
qkv_chunk.shape[0], qkv_chunk.shape[1], -1, |
|
hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() |
|
k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, 0, spg) |
|
if rotary_pos_emb is not None: |
|
k_chunk = apply_rotary_pos_emb(k_chunk, |
|
pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)], |
|
pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]) |
|
global_k.append(k_chunk) |
|
|
|
v_chunk = qkv_chunk[:, :, projection_size + kv_projection_size:].contiguous().reshape( |
|
qkv_chunk.shape[0], qkv_chunk.shape[1], -1, |
|
hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() |
|
v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, 0, spg) |
|
global_v.append(v_chunk) |
|
|
|
for k_i in range(len(global_k)): |
|
causal_chunk = i == k_i |
|
if flash_attn_version >= version.parse("2.6.0"): |
|
block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i], |
|
global_k[k_i], |
|
global_v[k_i], |
|
ctx.dropout_p, |
|
ctx.softmax_scale, |
|
causal=causal_chunk, |
|
window_size=ctx.window_size, |
|
softcap=0.0, |
|
alibi_slopes=ctx.alibi_slopes, |
|
return_softmax=False) |
|
else: |
|
block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i], |
|
global_k[k_i], |
|
global_v[k_i], |
|
ctx.dropout_p, |
|
ctx.softmax_scale, |
|
causal=causal_chunk, |
|
window_size=ctx.window_size, |
|
alibi_slopes=ctx.alibi_slopes, |
|
return_softmax=False) |
|
|
|
global_o[i], global_lse[i] = update_out_and_lse(global_o[i], global_lse[i], block_out, block_lse) |
|
|
|
global_o[i] = global_o[i].to(q_chunk.dtype) |
|
|
|
output = [None for i in range(num_chunks)] |
|
|
|
for i in range(num_chunks): |
|
global_lse[i] = global_lse[i][:, :, :, 0].permute(0, 2, 1).contiguous() |
|
output[i] = single_all_to_all(global_o[i].to(ctx.dtype).contiguous(), gather_idx, scatter_idx, 0, spg) |
|
output = torch.cat(output, dim=1) |
|
|
|
head_dim = output.shape[-1] |
|
|
|
if do_save: |
|
ctx.save_for_backward(layernorm_output) |
|
ctx.global_q = global_q |
|
ctx.global_k = global_k |
|
ctx.global_v = global_v |
|
ctx.attn_output = global_o |
|
ctx.attn_lse = global_lse |
|
ctx.head_dim = head_dim |
|
ctx.batch_size = batch_size |
|
|
|
ctx.qkv_linear_weight = qkv_linear_weight |
|
ctx.qkv_linear_bias = qkv_linear_bias |
|
|
|
return output |
|
|
|
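    # backward() reuses the saved global Q/K/V chunks, runs _flash_attn_backward for every
    # (key/value chunk i, query chunk q_i) pair with q_i >= i, and folds the results into dQ/dK/dV,
    # the QKV weight/bias gradients, and the gradient of the layernorm output.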
@staticmethod |
|
def backward(ctx, grad_output): |
|
|
|
num_chunks = ctx.num_chunks |
|
device = ctx.device |
|
dtype = ctx.dtype |
|
spg = ctx.spg |
|
scatter_idx = ctx.scatter_idx |
|
gather_idx = ctx.gather_idx |
|
softmax_scale = ctx.softmax_scale |
|
dropout_p = ctx.dropout_p |
|
window_size = ctx.window_size |
|
alibi_slopes = ctx.alibi_slopes |
|
|
|
projection_size = ctx.projection_size |
|
kv_projection_size = ctx.kv_projection_size |
|
|
|
layernorm_output = ctx.saved_tensors[0] |
|
|
|
global_q = ctx.global_q |
|
global_k = ctx.global_k |
|
global_v = ctx.global_v |
|
attn_output = ctx.attn_output |
|
lse = ctx.attn_lse |
|
|
|
qkv_linear_weight = ctx.qkv_linear_weight |
|
qkv_linear_bias = ctx.qkv_linear_bias |
|
|
|
input_chunk_size = layernorm_output.shape[0] // num_chunks |
|
grad_layernorm_output = [ |
|
torch.zeros((input_chunk_size, layernorm_output.shape[1], layernorm_output.shape[2]), |
|
device=device, |
|
dtype=dtype) for _ in range(num_chunks) |
|
] |
|
|
|
grad_global_attn_output = [] |
|
chunk_size = grad_output.shape[1] // num_chunks |
|
|
|
for i in range(num_chunks): |
|
st = chunk_size * i |
|
ed = st + chunk_size |
|
grad_global_attn_output.append( |
|
single_all_to_all(grad_output[:, st:ed].contiguous(), scatter_idx, gather_idx, 0, spg)) |
|
|
|
del grad_output |
|
|
|
        dq = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)]
        dk = [torch.zeros(global_k[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)]
        dv = [torch.zeros(global_v[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)]
|
|
|
grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape, |
|
device=qkv_linear_weight.device, |
|
dtype=torch.float) |
|
grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, device=qkv_linear_weight.device, dtype=torch.float) |
|
|
|
for i in range(num_chunks): |
|
k_chunk = global_k[i] |
|
v_chunk = global_v[i] |
|
|
|
for q_i in range(num_chunks): |
|
no_computation = q_i < i |
|
if no_computation: |
|
continue |
|
|
|
causal_chunk = q_i == i |
|
|
|
q_chunk = global_q[q_i] |
|
attn_output_chunk = attn_output[q_i] |
|
lse_chunk = lse[q_i] |
|
d_out = grad_global_attn_output[q_i] |
|
|
|
dq_this = torch.zeros(global_q[0].shape, dtype=dtype, device=device) |
|
dk_this = torch.zeros(global_k[0].shape, dtype=dtype, device=device) |
|
dv_this = torch.zeros(global_v[0].shape, dtype=dtype, device=device) |
|
|
|
if flash_attn_version >= version.parse("2.6.0"): |
|
_flash_attn_backward(d_out, |
|
q_chunk, |
|
k_chunk, |
|
v_chunk, |
|
attn_output_chunk, |
|
lse_chunk, |
|
dq_this, |
|
dk_this, |
|
dv_this, |
|
dropout_p, |
|
softmax_scale, |
|
causal_chunk, |
|
window_size, |
|
softcap=0.0, |
|
alibi_slopes=alibi_slopes, |
|
deterministic=False, |
|
rng_state=None) |
|
else: |
|
_flash_attn_backward(d_out, |
|
q_chunk, |
|
k_chunk, |
|
v_chunk, |
|
attn_output_chunk, |
|
lse_chunk, |
|
dq_this, |
|
dk_this, |
|
dv_this, |
|
dropout_p, |
|
softmax_scale, |
|
causal_chunk, |
|
window_size, |
|
alibi_slopes=alibi_slopes, |
|
deterministic=False, |
|
rng_state=None) |
|
|
|
dq[q_i].add_(dq_this.to(torch.float)) |
|
dk[i].add_(dk_this.to(torch.float)) |
|
dv[i].add_(dv_this.to(torch.float)) |
|
|
|
dk_seq_len = dk[i].shape[1] |
|
|
|
if ctx.pos_emb_cos is not None: |
|
dk[i] = apply_rotary_pos_emb_backward(dk[i].to(dtype), |
|
ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], |
|
ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) |
|
else: |
|
dk[i] = dk[i].to(dtype) |
|
dv[i] = dv[i].to(dtype) |
|
dk[i] = single_all_to_all(dk[i].contiguous(), gather_idx, scatter_idx, 0, spg) |
|
dv[i] = single_all_to_all(dv[i].contiguous(), gather_idx, scatter_idx, 0, spg) |
|
|
|
input_st = i * input_chunk_size |
|
input_ed = input_st + input_chunk_size |
|
|
|
input_chunk = layernorm_output[input_st:input_ed].reshape(-1, layernorm_output.shape[-1]) |
|
|
|
dk[i] = dk[i].flatten(2).permute(1, 0, 2) |
|
dv[i] = dv[i].flatten(2).permute(1, 0, 2) |
|
l, b = dk[i].shape[0], dk[i].shape[1] |
|
grad_qkv_linear_weight[projection_size:projection_size + kv_projection_size].add_( |
|
torch.matmul(dk[i].reshape(l * b, -1).t(), input_chunk)) |
|
grad_qkv_linear_weight[projection_size + kv_projection_size:].add_( |
|
torch.matmul(dv[i].reshape(l * b, -1).t(), input_chunk)) |
|
grad_qkv_linear_bias[projection_size:projection_size + kv_projection_size].add_(dk[i].sum(0).sum(0)) |
|
grad_qkv_linear_bias[projection_size + kv_projection_size:].add_(dv[i].sum(0).sum(0)) |
|
|
|
grad_layernorm_output[i].add_( |
|
torch.matmul(dk[i], qkv_linear_weight[projection_size:projection_size + kv_projection_size])) |
|
grad_layernorm_output[i].add_(torch.matmul(dv[i], |
|
qkv_linear_weight[projection_size + kv_projection_size:])) |
|
|
|
dk[i] = None |
|
dv[i] = None |
|
|
|
for i in range(num_chunks): |
|
dq_seq_len = dq[i].shape[1] |
|
if ctx.pos_emb_cos is not None: |
|
dq[i] = apply_rotary_pos_emb_backward(dq[i].to(dtype), |
|
ctx.pos_emb_cos[:, dq_seq_len * i:dq_seq_len * (i + 1)], |
|
ctx.pos_emb_sin[:, dq_seq_len * i:dq_seq_len * (i + 1)]) |
|
else: |
|
dq[i] = dq[i].to(dtype) |
|
dq[i] = single_all_to_all(dq[i].to(dtype).contiguous(), gather_idx, scatter_idx, 0, spg) |
|
|
|
input_chunk = layernorm_output[:input_chunk_size].reshape(-1, layernorm_output.shape[-1]) |
|
layernorm_output = layernorm_output[input_chunk_size:] |
|
|
|
dq[i] = dq[i].flatten(2).permute(1, 0, 2) |
|
l, b = dq[i].shape[0], dq[i].shape[1] |
|
grad_qkv_linear_weight[:projection_size].add_(torch.matmul(dq[i].reshape(l * b, -1).t(), input_chunk)) |
|
grad_qkv_linear_bias[:projection_size].add_(dq[i].sum(0).sum(0)) |
|
|
|
grad_layernorm_output[i].add_(torch.matmul(dq[i], qkv_linear_weight[:projection_size])) |
|
|
|
dq[i] = None |
|
|
|
        grad_input = torch.cat(grad_layernorm_output, dim=0).to(dtype)
        return (grad_input, None, None, None, None, None, None, None, None, None, None,
                grad_qkv_linear_weight.to(dtype), grad_qkv_linear_bias.to(dtype), None, None, None)
|
|
|
|
|
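# Wraps one sequence chunk with a pinned CPU copy and an optional GPU copy so that the offloading
# attention implementation can move chunks between host and device on side streams while
# FlashAttention runs on the compute stream.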
class SequenceChunk: |
|
|
|
def __init__(self, chunk: torch.Tensor, device=None, is_in_use=False): |
|
|
|
self.chunk_shape = chunk.shape |
|
self.chunk_dtype = chunk.dtype |
|
self.device = chunk.device if device is None else device |
|
|
|
cpu_chunk = torch.empty(chunk.shape, dtype=chunk.dtype, device='cpu', pin_memory=True) |
|
|
|
if get_accelerator().on_accelerator(chunk): |
|
cpu_chunk.copy_(chunk, non_blocking=True) |
|
else: |
|
cpu_chunk = chunk |
|
|
|
self.cpu_chunk = cpu_chunk |
|
|
|
self.gpu_chunk = chunk if is_in_use else None |
|
|
|
    def load_to_gpu(self):
        assert self.gpu_chunk is None
        gpu_chunk = torch.empty(self.chunk_shape, device=self.device, dtype=self.chunk_dtype)
        gpu_chunk.copy_(self.cpu_chunk, non_blocking=True)
        self.gpu_chunk = gpu_chunk
|
|
|
def get_gpu_chunk(self): |
|
assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device |
|
return self.gpu_chunk |
|
|
|
    def check_gpu_chunk(self):
        assert self.gpu_chunk is not None, f"gpu_chunk should be on {self.device}, but it has not been loaded"
        assert self.gpu_chunk.device == self.device, f"gpu_chunk should be on {self.device}, but it is on {self.gpu_chunk.device}"
        return True
|
|
|
def offload(self): |
|
assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device |
|
del self.gpu_chunk |
|
self.gpu_chunk = None |
|
|
|
def overwrite_to_cpu(self): |
|
assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device |
|
self.cpu_chunk.copy_(self.gpu_chunk, non_blocking=True) |
|
|
|
|
|
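# Offloading variant of the chunked sequence-parallel attention: Q/K/V, attention outputs, LSE,
# and layernorm inputs are stored as SequenceChunk objects in pinned host memory and are prefetched
# back to the GPU one step ahead of the chunk that needs them, overlapping host-device transfers
# (offload_stream) with FlashAttention compute (compute_stream).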
class _FPDTGPUOffloadingAttentionImpl_(torch.autograd.Function): |
|
generate_vmap_rule = False |
|
|
|
@staticmethod |
|
def forward(ctx: Any, |
|
layernorm_output, |
|
attention_mask, |
|
inference_params, |
|
rotary_pos_emb, |
|
spg, |
|
scatter_idx, |
|
gather_idx, |
|
hidden_size, |
|
projection_size, |
|
hidden_size_per_attention_head, |
|
kv_projection_size, |
|
qkv_linear_weight, |
|
qkv_linear_bias, |
|
dropout, |
|
num_chunks=8, |
|
cpu_offloading=True): |
|
|
|
do_save = layernorm_output.requires_grad |
|
|
|
if rotary_pos_emb is not None: |
|
pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3) |
|
ctx.pos_emb_cos = pos_emb_cos |
|
ctx.pos_emb_sin = pos_emb_sin |
|
else: |
|
ctx.pos_emb_cos = None |
|
ctx.pos_emb_sin = None |
|
with torch.no_grad(): |
|
per_gpu_seq_len = layernorm_output.shape[0] |
|
chunk_size = per_gpu_seq_len // num_chunks |
|
assert chunk_size * num_chunks == per_gpu_seq_len |
|
assert attention_mask is None |
|
ctx.num_chunks = num_chunks |
|
ctx.cpu_offloading = cpu_offloading |
|
ctx.spg = spg |
|
ctx.scatter_idx = scatter_idx |
|
ctx.gather_idx = gather_idx |
|
|
|
ctx.chunk_size = chunk_size |
|
device = get_accelerator().current_device_name() |
|
ctx.device = device |
|
ctx.dtype = layernorm_output.dtype |
|
ctx.projection_size = projection_size |
|
ctx.kv_projection_size = kv_projection_size |
|
|
|
global_q = [] |
|
global_k = [] |
|
global_v = [] |
|
|
|
ctx.softmax_scale = hidden_size_per_attention_head**(-0.5) |
|
|
|
ctx.dropout_p = dropout |
|
ctx.window_size = (-1, -1) |
|
ctx.alibi_slopes = None |
|
|
|
batch_size = layernorm_output.shape[1] |
|
|
|
global_o = [] |
|
global_lse = [] |
|
|
|
layernorm_output_cpu = [] |
|
final_output = [] |
|
|
|
offload_stream = get_accelerator().Stream() |
|
general_offload_stream = get_accelerator().Stream() |
|
compute_stream = get_accelerator().default_stream() |
|
|
|
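            # q_compute_chunk_idx / kv_compute_chunk_idx track which chunks are currently resident
            # on the GPU; the next needed chunk is prefetched on offload_stream while the current
            # one is consumed on compute_stream.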
q_compute_chunk_idx = 0 |
|
kv_compute_chunk_idx = 0 |
|
for i in range(num_chunks): |
|
|
|
qkv_chunk = torch.matmul(layernorm_output[:chunk_size], |
|
qkv_linear_weight.t()) + qkv_linear_bias |
|
|
|
with get_accelerator().stream(general_offload_stream): |
|
layernorm_output_cpu.append(SequenceChunk(layernorm_output[:chunk_size])) |
|
|
|
layernorm_output = layernorm_output[chunk_size:] |
|
|
|
q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape( |
|
qkv_chunk.shape[0], qkv_chunk.shape[1], -1, |
|
hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() |
|
q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, 0, spg) |
|
global_q_chunk_len = q_chunk.shape[1] |
|
|
|
k_chunk = qkv_chunk[:, :, projection_size:projection_size + kv_projection_size].contiguous().reshape( |
|
qkv_chunk.shape[0], qkv_chunk.shape[1], -1, |
|
hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() |
|
k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, 0, spg) |
|
|
|
v_chunk = qkv_chunk[:, :, projection_size + kv_projection_size:].contiguous().reshape( |
|
qkv_chunk.shape[0], qkv_chunk.shape[1], -1, |
|
hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous() |
|
v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, 0, spg) |
|
|
|
dist.barrier() |
|
|
|
if ctx.pos_emb_cos is not None: |
|
pos_emb_cos_chunk = pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] |
|
pos_emb_sin_chunk = pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)] |
|
|
|
q_chunk = apply_rotary_pos_emb(q_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk) |
|
k_chunk = apply_rotary_pos_emb(k_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk) |
|
|
|
compute_stream.wait_stream(offload_stream) |
|
compute_stream.synchronize() |
|
with get_accelerator().stream(offload_stream): |
|
global_q.append(SequenceChunk(q_chunk, is_in_use=True)) |
|
global_k.append(SequenceChunk(k_chunk, is_in_use=True)) |
|
global_v.append(SequenceChunk(v_chunk, is_in_use=True)) |
|
|
|
del qkv_chunk |
|
|
|
cur_attn_output = None |
|
cur_attn_lse = None |
|
for k_i in range(len(global_k)): |
|
causal_chunk = i == k_i |
|
with get_accelerator().stream(compute_stream): |
|
if flash_attn_version >= version.parse("2.6.0"): |
|
block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( |
|
global_q[q_compute_chunk_idx].get_gpu_chunk(), |
|
global_k[kv_compute_chunk_idx].get_gpu_chunk(), |
|
global_v[kv_compute_chunk_idx].get_gpu_chunk(), |
|
ctx.dropout_p, |
|
ctx.softmax_scale, |
|
causal=causal_chunk, |
|
window_size=ctx.window_size, |
|
softcap=0.0, |
|
alibi_slopes=ctx.alibi_slopes, |
|
return_softmax=False) |
|
else: |
|
block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward( |
|
global_q[q_compute_chunk_idx].get_gpu_chunk(), |
|
global_k[kv_compute_chunk_idx].get_gpu_chunk(), |
|
global_v[kv_compute_chunk_idx].get_gpu_chunk(), |
|
ctx.dropout_p, |
|
ctx.softmax_scale, |
|
causal=causal_chunk, |
|
window_size=ctx.window_size, |
|
alibi_slopes=ctx.alibi_slopes, |
|
return_softmax=False) |
|
cur_attn_output, cur_attn_lse = update_out_and_lse(cur_attn_output, cur_attn_lse, block_out, |
|
block_lse) |
|
|
|
can_offload_kv = True |
|
if k_i != (len(global_k) - 1) or i != (num_chunks - 1): |
|
if k_i != (len(global_k) - 1): |
|
next_kv_compute_chunk_idx = k_i + 1 |
|
else: |
|
next_kv_compute_chunk_idx = 0 |
|
|
|
if next_kv_compute_chunk_idx == kv_compute_chunk_idx: |
|
can_offload_kv = False |
|
else: |
|
if next_kv_compute_chunk_idx != (len(global_k) - 1): |
|
with get_accelerator().stream(offload_stream): |
|
global_k[next_kv_compute_chunk_idx].load_to_gpu() |
|
global_v[next_kv_compute_chunk_idx].load_to_gpu() |
|
|
|
if i == num_chunks - 1 and k_i == num_chunks - 1: |
|
with get_accelerator().stream(offload_stream): |
|
global_q[0].load_to_gpu() |
|
global_k[0].load_to_gpu() |
|
global_v[0].load_to_gpu() |
|
global_o[0].load_to_gpu() |
|
global_lse[0].load_to_gpu() |
|
|
|
compute_stream.wait_stream(offload_stream) |
|
compute_stream.synchronize() |
|
|
|
if can_offload_kv: |
|
global_k[kv_compute_chunk_idx].offload() |
|
global_v[kv_compute_chunk_idx].offload() |
|
kv_compute_chunk_idx = next_kv_compute_chunk_idx |
|
|
|
global_q[q_compute_chunk_idx].offload() |
|
q_compute_chunk_idx += 1 |
|
|
|
all2all_output = single_all_to_all( |
|
cur_attn_output.to(ctx.dtype).contiguous(), gather_idx, scatter_idx, 0, spg) |
|
final_output.append(all2all_output) |
|
with get_accelerator().stream(general_offload_stream): |
|
global_o.append(SequenceChunk(cur_attn_output.to(ctx.dtype))) |
|
global_lse.append(SequenceChunk(cur_attn_lse[:, :, :, 0].permute(0, 2, 1).contiguous())) |
|
|
|
compute_stream.wait_stream(general_offload_stream) |
|
compute_stream.synchronize() |
|
|
|
final_output = torch.cat(final_output, dim=1) |
|
|
|
head_dim = final_output.shape[-1] |
|
|
|
if do_save: |
|
ctx.layernorm_output = layernorm_output_cpu |
|
ctx.global_q = global_q |
|
ctx.global_k = global_k |
|
ctx.global_v = global_v |
|
ctx.attn_output = global_o |
|
ctx.attn_lse = global_lse |
|
ctx.head_dim = head_dim |
|
ctx.batch_size = batch_size |
|
|
|
ctx.qkv_linear_weight = qkv_linear_weight |
|
ctx.qkv_linear_bias = qkv_linear_bias |
|
|
|
return final_output |
|
|
|
@staticmethod |
|
def backward(ctx, grad_output): |
|
num_chunks = ctx.num_chunks |
|
device = grad_output.device |
|
dtype = ctx.dtype |
|
spg = ctx.spg |
|
scatter_idx = ctx.scatter_idx |
|
gather_idx = ctx.gather_idx |
|
softmax_scale = ctx.softmax_scale |
|
dropout_p = ctx.dropout_p |
|
window_size = ctx.window_size |
|
alibi_slopes = ctx.alibi_slopes |
|
|
|
projection_size = ctx.projection_size |
|
kv_projection_size = ctx.kv_projection_size |
|
|
|
layernorm_output = ctx.layernorm_output |
|
|
|
global_q = ctx.global_q |
|
global_k = ctx.global_k |
|
global_v = ctx.global_v |
|
attn_output = ctx.attn_output |
|
lse = ctx.attn_lse |
|
|
|
qkv_linear_weight = ctx.qkv_linear_weight |
|
qkv_linear_bias = ctx.qkv_linear_bias |
|
|
|
offload_stream = get_accelerator().Stream() |
|
general_offload_stream = get_accelerator().Stream() |
|
compute_stream = get_accelerator().default_stream() |
|
|
|
chunk_size = grad_output.shape[1] // num_chunks |
|
assert chunk_size == layernorm_output[0].cpu_chunk.shape[0] |
|
|
|
grad_layernorm_output = [ |
|
torch.zeros(layernorm_output[0].chunk_shape, device=device, dtype=dtype) for _ in range(num_chunks) |
|
] |
|
|
|
grad_global_attn_output = [None for _ in range(num_chunks)] |
|
|
|
q_compute_chunk_idx = 0 |
|
kv_compute_chunk_idx = 0 |
|
last_q_accum_idx = 0 |
|
|
|
with get_accelerator().stream(general_offload_stream): |
|
layernorm_output[0].load_to_gpu() |
|
grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape, |
|
device=qkv_linear_weight.device, |
|
dtype=torch.float) |
|
grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, |
|
device=qkv_linear_weight.device, |
|
dtype=torch.float) |
|
|
|
grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx, |
|
gather_idx, 0, spg) |
|
get_accelerator().synchronize() |
|
grad_output = grad_output[:, chunk_size:] |
|
|
|
with get_accelerator().stream(offload_stream): |
|
grad_global_attn_output[0] = SequenceChunk(grad_global_attn_output_chunk, is_in_use=True) |
|
dq = [ |
|
SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device=device), is_in_use=True) |
|
] + [ |
|
SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device='cpu', pin_memory=True), |
|
device) for _ in range(num_chunks - 1) |
|
] |
|
dk_accum = torch.zeros(global_k[0].chunk_shape, dtype=torch.float, device=device) |
|
dv_accum = torch.zeros(global_v[0].chunk_shape, dtype=torch.float, device=device) |
|
|
|
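        # Outer loop: key/value chunks; inner loop: query chunks that attend to them (q_i >= i).
        # dK/dV for the current key/value chunk are accumulated on the GPU, while each dQ chunk
        # lives in a pinned CPU buffer (dq) that is paged in and out around its use.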
for i in range(num_chunks): |
|
for q_i in range(num_chunks): |
|
no_computation = q_i < i |
|
if no_computation: |
|
continue |
|
|
|
causal_chunk = q_i == i |
|
|
|
dq_this = torch.zeros(global_q[0].chunk_shape, dtype=dtype, device=device) |
|
dk_this = torch.zeros(global_k[0].chunk_shape, dtype=dtype, device=device) |
|
dv_this = torch.zeros(global_v[0].chunk_shape, dtype=dtype, device=device) |
|
|
|
with get_accelerator().stream(compute_stream): |
|
if flash_attn_version >= version.parse("2.6.0"): |
|
_flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(), |
|
global_q[q_compute_chunk_idx].get_gpu_chunk(), |
|
global_k[kv_compute_chunk_idx].get_gpu_chunk(), |
|
global_v[kv_compute_chunk_idx].get_gpu_chunk(), |
|
attn_output[q_compute_chunk_idx].get_gpu_chunk(), |
|
lse[q_compute_chunk_idx].get_gpu_chunk(), |
|
dq_this, |
|
dk_this, |
|
dv_this, |
|
dropout_p, |
|
softmax_scale, |
|
causal_chunk, |
|
window_size, |
|
softcap=0.0, |
|
alibi_slopes=alibi_slopes, |
|
deterministic=False, |
|
rng_state=None) |
|
else: |
|
_flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(), |
|
global_q[q_compute_chunk_idx].get_gpu_chunk(), |
|
global_k[kv_compute_chunk_idx].get_gpu_chunk(), |
|
global_v[kv_compute_chunk_idx].get_gpu_chunk(), |
|
attn_output[q_compute_chunk_idx].get_gpu_chunk(), |
|
lse[q_compute_chunk_idx].get_gpu_chunk(), |
|
dq_this, |
|
dk_this, |
|
dv_this, |
|
dropout_p, |
|
softmax_scale, |
|
causal_chunk, |
|
window_size, |
|
alibi_slopes=alibi_slopes, |
|
deterministic=False, |
|
rng_state=None) |
|
|
|
if i != (len(global_k) - 1): |
|
if q_i != (len(global_q) - 1): |
|
next_q_compute_chunk_idx = q_i + 1 |
|
else: |
|
next_q_compute_chunk_idx = i + 1 |
|
|
|
can_offload_q = True |
|
|
|
if next_q_compute_chunk_idx == q_compute_chunk_idx: |
|
can_offload_q = False |
|
else: |
|
with get_accelerator().stream(offload_stream): |
|
if i > 0 or q_i > 0: |
|
if can_offload_q and last_q_accum_idx != i: |
|
dq[last_q_accum_idx].offload() |
|
dq[next_q_compute_chunk_idx].load_to_gpu() |
|
global_q[next_q_compute_chunk_idx].load_to_gpu() |
|
attn_output[next_q_compute_chunk_idx].load_to_gpu() |
|
lse[next_q_compute_chunk_idx].load_to_gpu() |
|
if grad_global_attn_output[next_q_compute_chunk_idx] is not None: |
|
grad_global_attn_output[next_q_compute_chunk_idx].load_to_gpu() |
|
|
|
if grad_global_attn_output[next_q_compute_chunk_idx] is None: |
|
grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), |
|
scatter_idx, gather_idx, 0, spg) |
|
dist.barrier() |
|
grad_output = grad_output[:, chunk_size:] |
|
grad_global_attn_output[next_q_compute_chunk_idx] = SequenceChunk( |
|
grad_global_attn_output_chunk, is_in_use=True) |
|
|
|
compute_stream.wait_stream(offload_stream) |
|
compute_stream.synchronize() |
|
|
|
with get_accelerator().stream(compute_stream): |
|
dq[q_compute_chunk_idx].check_gpu_chunk() |
|
dq[q_compute_chunk_idx].gpu_chunk.add_(dq_this) |
|
dk_accum.add_(dk_this) |
|
dv_accum.add_(dv_this) |
|
|
|
offload_stream.wait_stream(compute_stream) |
|
with get_accelerator().stream(offload_stream): |
|
dq[q_compute_chunk_idx].overwrite_to_cpu() |
|
|
|
if can_offload_q: |
|
global_q[q_compute_chunk_idx].offload() |
|
attn_output[q_compute_chunk_idx].offload() |
|
lse[q_compute_chunk_idx].offload() |
|
grad_global_attn_output[q_compute_chunk_idx].offload() |
|
|
|
last_q_accum_idx = q_compute_chunk_idx |
|
q_compute_chunk_idx = next_q_compute_chunk_idx |
|
|
|
compute_stream.wait_stream(offload_stream) |
|
compute_stream.synchronize() |
|
|
|
dk_seq_len = dk_accum.shape[1] |
|
|
|
if ctx.pos_emb_cos is not None: |
|
dq_accum = apply_rotary_pos_emb_backward(dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype), |
|
ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], |
|
ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) |
|
dk_accum = apply_rotary_pos_emb_backward(dk_accum.to(dtype), |
|
ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)], |
|
ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)]) |
|
else: |
|
dq_accum = dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype) |
|
dk_accum = dk_accum.to(dtype) |
|
dv_accum = dv_accum.to(dtype) |
|
|
|
dq_accum = single_all_to_all(dq_accum.contiguous(), gather_idx, scatter_idx, 0, spg) |
|
dk_accum = single_all_to_all(dk_accum.contiguous(), gather_idx, scatter_idx, 0, spg) |
|
dv_accum = single_all_to_all(dv_accum.contiguous(), gather_idx, scatter_idx, 0, spg) |
|
|
|
general_offload_stream.synchronize() |
|
compute_stream.wait_stream(general_offload_stream) |
|
dist.barrier() |
|
|
|
with get_accelerator().stream(compute_stream): |
|
input_chunk = layernorm_output[i].get_gpu_chunk().reshape(-1, layernorm_output[i].chunk_shape[-1]) |
|
|
|
dq_accum = dq_accum.flatten(2).permute(1, 0, 2) |
|
dk_accum = dk_accum.flatten(2).permute(1, 0, 2) |
|
dv_accum = dv_accum.flatten(2).permute(1, 0, 2) |
|
|
|
l, b = dk_accum.shape[0], dk_accum.shape[1] |
|
|
|
grad_qkv_linear_weight[:projection_size].add_( |
|
torch.matmul(dq_accum.reshape(l * b, -1).t(), input_chunk)) |
|
grad_qkv_linear_weight[projection_size:projection_size + kv_projection_size].add_( |
|
torch.matmul(dk_accum.reshape(l * b, -1).t(), input_chunk)) |
|
grad_qkv_linear_weight[projection_size + kv_projection_size:].add_( |
|
torch.matmul(dv_accum.reshape(l * b, -1).t(), input_chunk)) |
|
|
|
grad_qkv_linear_bias[:projection_size].add_(dq_accum.sum(0).sum(0)) |
|
grad_qkv_linear_bias[projection_size:projection_size + kv_projection_size].add_(dk_accum.sum(0).sum(0)) |
|
grad_qkv_linear_bias[projection_size + kv_projection_size:].add_(dv_accum.sum(0).sum(0)) |
|
|
|
grad_layernorm_output[i].add_(torch.matmul(dq_accum, qkv_linear_weight[:projection_size])) |
|
grad_layernorm_output[i].add_( |
|
torch.matmul(dk_accum, qkv_linear_weight[projection_size:projection_size + kv_projection_size])) |
|
grad_layernorm_output[i].add_( |
|
torch.matmul(dv_accum, qkv_linear_weight[projection_size + kv_projection_size:])) |
|
|
|
del dq_accum, dk_accum, dv_accum |
|
dk_accum = torch.zeros(global_k[i].chunk_shape, dtype=torch.float, device=device) |
|
dv_accum = torch.zeros(global_v[i].chunk_shape, dtype=torch.float, device=device) |
|
dq[kv_compute_chunk_idx].offload() |
|
dq[kv_compute_chunk_idx] = None |
|
|
|
if i != (len(global_k) - 1): |
|
next_kv_compute_chunk_idx = kv_compute_chunk_idx + 1 |
|
with get_accelerator().stream(offload_stream): |
|
global_k[next_kv_compute_chunk_idx].load_to_gpu() |
|
global_v[next_kv_compute_chunk_idx].load_to_gpu() |
|
|
|
with get_accelerator().stream(general_offload_stream): |
|
layernorm_output[next_kv_compute_chunk_idx].load_to_gpu() |
|
|
|
compute_stream.wait_stream(offload_stream) |
|
compute_stream.synchronize() |
|
|
|
layernorm_output[kv_compute_chunk_idx].offload() |
|
global_k[kv_compute_chunk_idx].offload() |
|
global_v[kv_compute_chunk_idx].offload() |
|
kv_compute_chunk_idx = next_kv_compute_chunk_idx |
|
|
|
        grad_input = torch.cat(grad_layernorm_output, dim=0).to(dtype)
        return (grad_input, None, None, None, None, None, None, None, None, None, None,
                grad_qkv_linear_weight.to(dtype), grad_qkv_linear_bias.to(dtype), None, None, None)
|
|
|
|
|
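# Module wrapper that dispatches to the offloading or non-offloading autograd function depending
# on cpu_offloading and the number of chunks, then applies the output dense projection. Both paths
# require flash-attn.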
class FPDT_Attention(torch.nn.Module): |
|
|
|
def __init__(self, |
|
config, |
|
first_weight, |
|
first_bias, |
|
second_weight, |
|
second_bias, |
|
sequence_process_group, |
|
gather_idx: int = 0, |
|
scatter_idx: int = 2, |
|
return_bias=True, |
|
chunk_size=65536, |
|
enable_offloading=True) -> None: |
|
|
|
super(FPDT_Attention, self).__init__() |
|
if _flash_attn_forward is None or _flash_attn_backward is None: |
|
raise ImportError( |
|
"DeepSpeed FPDT requires flash-attn 2.6.3. Please install it with `pip install flash-attn --no-build-isolation`." |
|
) |
|
|
|
self.spg = sequence_process_group |
|
self.scatter_idx = scatter_idx |
|
self.gather_idx = gather_idx |
|
self.config = config |
|
|
|
self.projection_size = config.kv_channels * config.num_attention_heads |
|
self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads |
|
self.kv_projection_size = config.kv_channels * config.num_key_value_heads |
|
self.hidden_size = config.hidden_size |
|
|
|
self.qkv_linear_weight = first_weight |
|
self.qkv_linear_bias = first_bias |
|
self.qkv_dense_weight = second_weight |
|
self.qkv_dense_bias = second_bias |
|
|
|
        self.return_bias = return_bias
|
self.dropout = config.attention_dropout |
|
|
|
self.chunk_size = chunk_size |
|
self.double_buffer = enable_offloading |
|
|
|
def forward(self, |
|
layernorm_output, |
|
attention_mask, |
|
inference_params, |
|
rotary_pos_emb, |
|
                cpu_offloading=True) -> Tuple[Tensor, Optional[Tensor]]:
|
self.num_chunks_attn = layernorm_output.shape[0] * dist.get_world_size(self.spg) // self.chunk_size |
|
|
|
if not cpu_offloading or self.num_chunks_attn == 1: |
|
output = _FPDTGPUAttentionImpl_.apply(layernorm_output, attention_mask, inference_params, rotary_pos_emb, |
|
self.spg, self.scatter_idx, self.gather_idx, self.hidden_size, |
|
self.projection_size, self.hidden_size_per_attention_head, |
|
self.kv_projection_size, self.qkv_linear_weight, |
|
self.qkv_linear_bias, self.dropout, self.num_chunks_attn, |
|
cpu_offloading) |
|
else: |
|
output = _FPDTGPUOffloadingAttentionImpl_.apply( |
|
layernorm_output, attention_mask, inference_params, rotary_pos_emb, self.spg, self.scatter_idx, |
|
self.gather_idx, self.hidden_size, self.projection_size, self.hidden_size_per_attention_head, |
|
self.kv_projection_size, self.qkv_linear_weight, self.qkv_linear_bias, self.dropout, |
|
self.num_chunks_attn, cpu_offloading) |
|
|
|
output = output.flatten(2).permute(1, 0, 2).contiguous() |
|
|
|
        output = torch.matmul(output, self.qkv_dense_weight.t())
        if not self.return_bias:
            output += self.qkv_dense_bias
        return output, self.qkv_dense_bias if self.return_bias else None
|
|
|
|
|
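# Tanh approximation of GELU; despite the name, the bias is expected to have been added to x by
# the caller.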
@torch.jit.script |
|
def bias_gelu(x): |
|
return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))) |
|
|
|
|
|
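# Derivative of the tanh-approximated GELU above, multiplied by the incoming gradient g.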
@torch.jit.script |
|
def bias_gelu_back(g, x): |
|
tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) |
|
|
|
ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) |
|
return ff * g |
|
|
|
|
|
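# Chunked MLP with activation recomputation: forward materializes only one chunk of the
# intermediate activations at a time, and backward recomputes them to form weight, bias, and input
# gradients (the input gradient is written back into x in place to save memory).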
class FPDT_FFN(torch.autograd.Function): |
|
generate_vmap_rule = False |
|
|
|
@staticmethod |
|
def forward(ctx: Any, x, w1, b1, w2, b2, add_bias, chunk_size): |
|
do_save = x.requires_grad |
|
ctx.add_bias = add_bias |
|
device = x.device |
|
|
|
with torch.no_grad(): |
|
num_chunk = x.shape[0] // chunk_size |
|
ctx.num_chunk = num_chunk |
|
result = torch.empty(x.shape, device=device, dtype=x.dtype) |
|
assert chunk_size * num_chunk == x.shape[0] |
|
for i in range(num_chunk): |
|
st = i * chunk_size |
|
ed = st + chunk_size |
|
x_ = torch.matmul(x[st:ed], w1.t()) + b1 |
|
x_ = bias_gelu(x_) |
|
if add_bias: |
|
result[st:ed] = torch.matmul(x_, w2.t()) + b2 |
|
else: |
|
result[st:ed] = torch.matmul(x_, w2.t()) |
|
|
|
del x_ |
|
|
|
if do_save: |
|
ctx.device = device |
|
ctx.dtype = x.dtype |
|
ctx.save_for_backward(x, w1, b1, w2, b2) |
|
ctx.grad_x_shape = x.shape |
|
return result.to(x.dtype), b2 if not add_bias else None |
|
|
|
@staticmethod |
|
def backward(ctx, grad_output, grad_bias): |
|
x, w1, b1, w2, b2 = ctx.saved_tensors |
|
device = ctx.device |
|
dtype = ctx.dtype |
|
add_bias = ctx.add_bias |
|
|
|
num_chunk = ctx.num_chunk |
|
chunk_size = x.shape[0] // num_chunk |
|
assert chunk_size * num_chunk == grad_output.shape[0] |
|
|
|
grad_w2 = torch.zeros(w2.shape, device=device, dtype=torch.float) |
|
grad_b2 = torch.zeros(b2.shape, device=device, dtype=torch.float) |
|
grad_w1 = torch.zeros(w1.shape, device=device, dtype=torch.float) |
|
grad_b1 = torch.zeros(b1.shape, device=device, dtype=torch.float) |
|
|
|
for i in range(num_chunk): |
|
st = i * chunk_size |
|
ed = st + chunk_size |
|
x_chunk = x[st:ed] |
|
|
|
before_act = (torch.matmul(x_chunk, w1.t()) + b1) |
|
before_act_2 = before_act**2 |
|
tanh_out = torch.tanh(0.79788456 * before_act * (1 + 0.044715 * before_act_2)) |
|
ff = 0.5 * before_act * ((1 - tanh_out * tanh_out) * |
|
(0.79788456 + 0.1070322243 * before_act_2)) + 0.5 * (1 + tanh_out) |
|
grad_w2.add_( |
|
torch.matmul(grad_output[st:ed].reshape(-1, grad_output.shape[2]).t(), |
|
(before_act * 0.5 * (1 + tanh_out)).reshape(-1, before_act.shape[2]))) |
|
del before_act, before_act_2, tanh_out |
|
|
|
grad_inter = torch.matmul(grad_output[st:ed], w2) * ff |
|
del ff |
|
|
|
grad_w1.add_(torch.matmul( |
|
grad_inter.reshape(-1, grad_inter.shape[2]).t(), x_chunk.reshape(-1, x.shape[2]))) |
|
grad_b1.add_(grad_inter.sum(0).sum(0)) |
|
|
|
x[st:ed].copy_(torch.matmul(grad_inter, w1)) |
|
|
|
del grad_inter |
|
|
|
if add_bias: |
|
grad_b2.add_(grad_output[st:ed].sum(0).sum(0)) |
|
|
|
return x, grad_w1.to(dtype), grad_b1.to(dtype), grad_w2.to(dtype), grad_b2.to(dtype), None, None |
|
|
|
|
|
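# Chunked vocabulary projection and cross-entropy over the sequence-parallel shard: logits are
# never materialized for the full local sequence, lm_output is saved on the CPU, and the per-token
# losses are all-gathered across the sequence-parallel group.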
class FPDT_LogitsLoss(torch.autograd.Function): |
|
generate_vmap_rule = False |
|
|
|
@staticmethod |
|
def forward(ctx: Any, lm_output, labels, logit_weights, rank, spg_size, spg, num_chunk): |
|
labels = labels.t() |
|
chunk_size = lm_output.shape[0] // num_chunk |
|
assert chunk_size * num_chunk == lm_output.shape[0] |
|
batch_size, local_seq_len = lm_output.shape[1], lm_output.shape[0] |
|
loss = torch.empty((batch_size, local_seq_len), dtype=torch.float, device=lm_output.device) |
|
|
|
ctx.num_chunk = num_chunk |
|
ctx.chunk_size = chunk_size |
|
ctx.device = lm_output.device |
|
ctx.dtype = lm_output.dtype |
|
|
|
ctx.rank = rank |
|
ctx.local_seq_len = local_seq_len |
|
with torch.no_grad(): |
|
for i in range(num_chunk): |
|
st = i * chunk_size |
|
ed = st + chunk_size |
|
logits_chunk = torch.matmul(lm_output[st:ed], logit_weights.t()).float() |
|
|
|
vocab_size = logits_chunk.size(2) |
|
|
|
                log_probs = torch.nn.functional.log_softmax(logits_chunk, dim=-1)
                loss_chunk = torch.nn.functional.nll_loss(log_probs.reshape(-1, vocab_size).contiguous(),
                                                          labels[st:ed, :].reshape(-1).contiguous(),
                                                          reduction='none')
|
loss[:, st:ed] = loss_chunk.reshape(chunk_size, batch_size).t() |
|
|
|
del logits_chunk |
|
ctx.save_for_backward(lm_output.to('cpu'), labels) |
|
ctx.logit_weights = logit_weights |
|
|
|
seqlen = local_seq_len * spg_size |
|
batch_size = loss.size(0) |
|
loss = loss.t().contiguous() |
|
loss_all = torch.empty(seqlen, batch_size, dtype=loss.dtype, device=loss.device).contiguous() |
|
|
|
dist.allgather_fn(loss_all, loss, group=spg) |
|
|
|
return loss_all |
|
|
|
@staticmethod |
|
def backward(ctx, grad_output): |
|
lm_output, labels = ctx.saved_tensors |
|
logit_weights = ctx.logit_weights |
|
device = ctx.device |
|
dtype = ctx.dtype |
|
num_chunk = ctx.num_chunk |
|
chunk_size = ctx.chunk_size |
|
|
|
rank = ctx.rank |
|
local_seq_len = ctx.local_seq_len |
|
|
|
grad_output = grad_output[rank * local_seq_len:(rank + 1) * local_seq_len] |
|
grad_lm_output = [None for _ in range(num_chunk)] |
|
grad_logit_weights = torch.zeros(logit_weights.shape, device=grad_output.device, dtype=torch.float) |
|
for i in range(num_chunk): |
|
st = i * chunk_size |
|
ed = st + chunk_size |
|
lm_output_chunk = lm_output[st:ed].to(device) |
|
logits_chunk = torch.matmul(lm_output_chunk, logit_weights.t()).float() |
|
|
|
|
|
softmax = torch.nn.functional.softmax(logits_chunk, dim=-1) |
|
vocab_size = logits_chunk.size(2) |
|
|
|
grad_input = softmax |
|
grad_2d = grad_input.reshape(-1, vocab_size).contiguous() |
|
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=device) |
|
|
|
grad_2d[arange_1d, labels[st:ed, :].reshape(-1).contiguous()] -= 1 |
|
grad_input.mul_(grad_output[:chunk_size, :].unsqueeze(dim=-1)) |
|
grad_input = grad_input.to(dtype) |
|
|
|
grad_output = grad_output[chunk_size:].contiguous() |
|
|
|
grad_lm_output_chunk = torch.matmul(grad_input, logit_weights) |
|
grad_lm_output[i] = grad_lm_output_chunk |
|
|
|
grad_logit_weights.add_( |
|
torch.matmul( |
|
grad_input.reshape(-1, grad_input.shape[2]).t(), |
|
lm_output_chunk.reshape(-1, lm_output_chunk.shape[2]))) |
|
|
|
return torch.cat(grad_lm_output, dim=0).to(dtype), None, grad_logit_weights.to(dtype), None, None, None, None |
|
|