# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import torch
from typing import Optional, Any, Tuple
from torch import Tensor
from packaging import version

import deepspeed.comm as dist
from deepspeed.accelerator import get_accelerator

try:
    import flash_attn
    from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward
    flash_attn_version = version.parse(flash_attn.__version__)
except ImportError:
    _flash_attn_forward = None
    _flash_attn_backward = None

from einops import rearrange

from .layer import single_all_to_all, apply_rotary_pos_emb


def _rotate_half_backward(x):
    x = rearrange(x, '... (j d) -> ... j d', j=2)
    x1, x2 = x.unbind(dim=-2)
    return torch.cat((x2, -x1), dim=-1)


def apply_rotary_pos_emb_backward(grad_output, freqs_cos, freqs_sin):
    rot_dim = freqs_cos.shape[-1]
    grad, grad_pass = grad_output[..., :rot_dim], grad_output[..., rot_dim:]
    grad_t = (grad * freqs_cos) + (_rotate_half_backward(grad * freqs_sin))
    grad = grad_t if grad_pass.shape[-1] == 0 else torch.cat((grad_t, grad_pass), dim=-1)
    return grad


def _update_out_and_lse(
    out: torch.Tensor,
    lse: torch.Tensor,
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:

    block_out = block_out.to(torch.float32)
    block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)

    new_lse = lse + torch.log1p(torch.exp(block_lse - lse))

    out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out

    lse = new_lse
    return out, lse


def update_out_and_lse(
    out: Optional[torch.Tensor],
    lse: Optional[torch.Tensor],
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
    slice_=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    if out is None:
        if slice_ is not None:
            raise RuntimeError("first update_out_and_lse should not pass slice_ args")
        out = block_out.to(torch.float32)
        lse = block_lse.permute(0, 2, 1).contiguous().unsqueeze(dim=-1).contiguous()
    elif slice_ is not None:
        slice_out, slice_lse = out[slice_], lse[slice_]
        slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse)
        out[slice_], lse[slice_] = slice_out, slice_lse
    else:
        out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
    return out, lse
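
# `_update_out_and_lse` / `update_out_and_lse` merge per-block flash-attention results in an
# online-softmax fashion: given the running output `out` and running log-sum-exp `lse` over the
# key/value blocks seen so far, and a new block's (`block_out`, `block_lse`), they rescale both
# sides by their softmax weights:
#     new_lse = lse + log(1 + exp(block_lse - lse))
#     out     = exp(lse - new_lse) * out + exp(block_lse - new_lse) * block_out
# so that after the last block `out` equals attention computed over the concatenation of all
# key/value blocks. A minimal single-head sanity-check sketch, kept as a comment so it does not
# run on import (the shapes and the reference computation below are illustrative assumptions,
# not part of this module):
#
#   q, k, v = (torch.randn(1, 4, 1, 8) for _ in range(3))      # (batch, seq, heads, head_dim)
#   scores = (q.squeeze(2) @ k.squeeze(2).transpose(-1, -2)) * 8 ** -0.5
#   ref = (scores.softmax(-1) @ v.squeeze(2)).unsqueeze(2)     # full (unchunked) attention
#   # Splitting k/v into two blocks, running attention per block, and folding the per-block
#   # outputs and lse values through update_out_and_lse reproduces `ref` up to float error.
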

class FPDT_InputConstruct(torch.nn.Module):

    def __init__(self, tokens, labels, loss_mask, attention_mask, position_ids, args, sp_size, sp_rank) -> None:
        super(FPDT_InputConstruct, self).__init__()
        self.tokens = tokens
        self.labels = labels
        self.loss_mask = loss_mask
        self.attention_mask = attention_mask
        self.position_ids = position_ids
        global_seq_len = tokens.shape[1]
        batch_size = tokens.shape[0]
        assert global_seq_len % sp_size == 0
        assert global_seq_len % args.ds_sequence_parallel_fpdt_chunk_size == 0
        num_chunk_per_gpu = global_seq_len // args.ds_sequence_parallel_fpdt_chunk_size
        local_seq_len = global_seq_len // sp_size
        assert local_seq_len % num_chunk_per_gpu == 0

        self.num_chunk_per_gpu = num_chunk_per_gpu
        self.chunk_size = local_seq_len // num_chunk_per_gpu
        self.sp_size = sp_size
        self.sp_rank = sp_rank
        self.global_seq_len = global_seq_len
        self.local_seq_len = local_seq_len
        self.batch_size = batch_size
        self.device = tokens.device

    def generate(self):
        device = self.device
        totalChunks = self.global_seq_len // self.chunk_size
        token_chunk_idx = torch.arange(self.global_seq_len, device=device, dtype=torch.int) // self.chunk_size
        chunk_to_gpu = torch.arange(totalChunks, device=device, dtype=torch.int)
        chunk_to_gpu = chunk_to_gpu.reshape(self.num_chunk_per_gpu, -1).t().contiguous()

        gather_chunk = chunk_to_gpu.flatten().unsqueeze(1).contiguous()
        mask = gather_chunk == token_chunk_idx
        indices = mask.nonzero(as_tuple=False)
        gather_indices = indices[:, 0]
        token_chunk_indices = indices[:, 1]
        indices = torch.cat([token_chunk_indices[gather_indices == i] for i in range(gather_chunk.shape[0])])
        load_balanced_loss_mask = self.loss_mask[:, indices] if self.loss_mask is not None else self.loss_mask
        indices = indices.reshape(-1, self.chunk_size)[self.num_chunk_per_gpu * self.sp_rank:self.num_chunk_per_gpu *
                                                       (self.sp_rank + 1)].flatten().contiguous()
        load_balanced_tokens = self.tokens[:, indices]
        load_balanced_labels = self.labels[:, indices] if self.labels is not None else self.labels

        load_balanced_attention_mask = self.attention_mask if self.attention_mask is not None else self.attention_mask
        load_balanced_position_ids = self.position_ids[:, indices] if self.position_ids is not None \
            else self.position_ids

        return load_balanced_tokens, load_balanced_labels, load_balanced_loss_mask, load_balanced_attention_mask, load_balanced_position_ids
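
# FPDT_InputConstruct reorders the global sequence so that each sequence-parallel rank owns an
# interleaved (strided) set of chunks rather than one contiguous slab. With a contiguous split
# and a causal mask, the last rank would attend to almost the whole sequence while the first
# rank attends to almost nothing; striding chunks across ranks balances that work. A small
# worked example with assumed sizes (sp_size=2, num_chunk_per_gpu=2, chunk_size=2,
# global_seq_len=8): chunks are c0=[0,1], c1=[2,3], c2=[4,5], c3=[6,7]; rank 0 receives
# (c0, c2), i.e. token positions [0, 1, 4, 5], and rank 1 receives (c1, c3), i.e. [2, 3, 6, 7].
# Note that `load_balanced_loss_mask` keeps the full permuted sequence, while tokens, labels
# and position_ids keep only this rank's slice.
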

class _FPDTGPUAttentionImpl_(torch.autograd.Function):
    generate_vmap_rule = False

    @staticmethod
    def forward(ctx: Any,
                layernorm_output,
                attention_mask,
                inference_params,
                rotary_pos_emb,
                spg,
                scatter_idx,
                gather_idx,
                hidden_size,
                projection_size,
                hidden_size_per_attention_head,
                kv_projection_size,
                qkv_linear_weight,
                qkv_linear_bias,
                dropout,
                num_chunks=8,
                cpu_offloading=True):
        do_save = layernorm_output.requires_grad

        if rotary_pos_emb is not None:
            pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3)
            ctx.pos_emb_cos = pos_emb_cos
            ctx.pos_emb_sin = pos_emb_sin
        else:
            ctx.pos_emb_cos = None
            ctx.pos_emb_sin = None

        with torch.no_grad():
            per_gpu_seq_len = layernorm_output.shape[0]
            chunk_size = per_gpu_seq_len // num_chunks
            assert chunk_size * num_chunks == per_gpu_seq_len
            assert attention_mask is None

            ctx.num_chunks = num_chunks
            ctx.cpu_offloading = cpu_offloading
            ctx.spg = spg
            ctx.scatter_idx = scatter_idx
            ctx.gather_idx = gather_idx
            device = get_accelerator().current_device_name()
            ctx.device = device
            ctx.dtype = layernorm_output.dtype
            ctx.projection_size = projection_size
            ctx.kv_projection_size = kv_projection_size

            global_q = []
            global_k = []
            global_v = []

            ctx.softmax_scale = hidden_size_per_attention_head**(-0.5)
            ctx.dropout_p = dropout
            ctx.window_size = (-1, -1)
            ctx.alibi_slopes = None

            batch_size = layernorm_output.shape[1]

            global_o = [None for _ in range(num_chunks)]
            global_lse = [None for _ in range(num_chunks)]
            for i in range(num_chunks):
                st = chunk_size * i
                ed = st + chunk_size
                qkv_chunk = torch.matmul(layernorm_output[st:ed], qkv_linear_weight.t()) + qkv_linear_bias

                q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape(
                    qkv_chunk.shape[0], qkv_chunk.shape[1], -1,
                    hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous()  # b, l, nh, hd
                q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, 0, spg)
                global_q_chunk_len = q_chunk.shape[1]
                if rotary_pos_emb is not None:
                    q_chunk = apply_rotary_pos_emb(q_chunk,
                                                   pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)],
                                                   pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)])
                global_q.append(q_chunk)

                k_chunk = qkv_chunk[:, :, projection_size:projection_size + kv_projection_size].contiguous().reshape(
                    qkv_chunk.shape[0], qkv_chunk.shape[1], -1,
                    hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous()  # b, l, nh, hd
                k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, 0, spg)
                if rotary_pos_emb is not None:
                    k_chunk = apply_rotary_pos_emb(k_chunk,
                                                   pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)],
                                                   pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)])
                global_k.append(k_chunk)

                v_chunk = qkv_chunk[:, :, projection_size + kv_projection_size:].contiguous().reshape(
                    qkv_chunk.shape[0], qkv_chunk.shape[1], -1,
                    hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous()  # b, l, nh, hd
                v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, 0, spg)
                global_v.append(v_chunk)

                for k_i in range(len(global_k)):
                    causal_chunk = i == k_i
                    if flash_attn_version >= version.parse("2.6.0"):
                        block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i],
                                                                                     global_k[k_i],
                                                                                     global_v[k_i],
                                                                                     ctx.dropout_p,
                                                                                     ctx.softmax_scale,
                                                                                     causal=causal_chunk,
                                                                                     window_size=ctx.window_size,
                                                                                     softcap=0.0,
                                                                                     alibi_slopes=ctx.alibi_slopes,
                                                                                     return_softmax=False)
                    else:
                        block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(global_q[i],
                                                                                     global_k[k_i],
                                                                                     global_v[k_i],
                                                                                     ctx.dropout_p,
                                                                                     ctx.softmax_scale,
                                                                                     causal=causal_chunk,
                                                                                     window_size=ctx.window_size,
                                                                                     alibi_slopes=ctx.alibi_slopes,
                                                                                     return_softmax=False)
                    global_o[i], global_lse[i] = update_out_and_lse(global_o[i], global_lse[i], block_out, block_lse)

                global_o[i] = global_o[i].to(q_chunk.dtype)

            output = [None for i in range(num_chunks)]
            for i in range(num_chunks):
                global_lse[i] = global_lse[i][:, :, :, 0].permute(0, 2, 1).contiguous()
                output[i] = single_all_to_all(global_o[i].to(ctx.dtype).contiguous(), gather_idx, scatter_idx, 0, spg)

            output = torch.cat(output, dim=1)
            head_dim = output.shape[-1]

        if do_save:
            ctx.save_for_backward(layernorm_output)
            ctx.global_q = global_q
            ctx.global_k = global_k
            ctx.global_v = global_v
            ctx.attn_output = global_o
            ctx.attn_lse = global_lse
            ctx.head_dim = head_dim
            ctx.batch_size = batch_size
            ctx.qkv_linear_weight = qkv_linear_weight
            ctx.qkv_linear_bias = qkv_linear_bias

        return output

    @staticmethod
    def backward(ctx, grad_output):
        num_chunks = ctx.num_chunks
        device = ctx.device
        dtype = ctx.dtype
        spg = ctx.spg
        scatter_idx = ctx.scatter_idx
        gather_idx = ctx.gather_idx

        softmax_scale = ctx.softmax_scale
        dropout_p = ctx.dropout_p
        window_size = ctx.window_size
        alibi_slopes = ctx.alibi_slopes

        projection_size = ctx.projection_size
        kv_projection_size = ctx.kv_projection_size

        layernorm_output = ctx.saved_tensors[0]
        global_q = ctx.global_q
        global_k = ctx.global_k
        global_v = ctx.global_v
        attn_output = ctx.attn_output
        lse = ctx.attn_lse

        qkv_linear_weight = ctx.qkv_linear_weight
        qkv_linear_bias = ctx.qkv_linear_bias

        input_chunk_size = layernorm_output.shape[0] // num_chunks
        grad_layernorm_output = [
            torch.zeros((input_chunk_size, layernorm_output.shape[1], layernorm_output.shape[2]),
                        device=device,
                        dtype=dtype) for _ in range(num_chunks)
        ]

        grad_global_attn_output = []

        chunk_size = grad_output.shape[1] // num_chunks
        for i in range(num_chunks):
            st = chunk_size * i
            ed = st + chunk_size
            grad_global_attn_output.append(
                single_all_to_all(grad_output[:, st:ed].contiguous(), scatter_idx, gather_idx, 0, spg))

        del grad_output

        dq = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)]
        dk = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)]
        dv = [torch.zeros(global_q[0].shape, dtype=torch.float, device=device) for _ in range(num_chunks)]

        grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape,
                                             device=qkv_linear_weight.device,
                                             dtype=torch.float)
        grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, device=qkv_linear_weight.device, dtype=torch.float)
        for i in range(num_chunks):
            k_chunk = global_k[i]
            v_chunk = global_v[i]
            for q_i in range(num_chunks):
                no_computation = q_i < i
                if no_computation:
                    continue

                causal_chunk = q_i == i

                q_chunk = global_q[q_i]
                attn_output_chunk = attn_output[q_i]
                lse_chunk = lse[q_i]
                d_out = grad_global_attn_output[q_i]

                dq_this = torch.zeros(global_q[0].shape, dtype=dtype, device=device)
                dk_this = torch.zeros(global_k[0].shape, dtype=dtype, device=device)
                dv_this = torch.zeros(global_v[0].shape, dtype=dtype, device=device)

                if flash_attn_version >= version.parse("2.6.0"):
                    _flash_attn_backward(d_out,
                                         q_chunk,
                                         k_chunk,
                                         v_chunk,
                                         attn_output_chunk,
                                         lse_chunk,
                                         dq_this,
                                         dk_this,
                                         dv_this,
                                         dropout_p,
                                         softmax_scale,
                                         causal_chunk,
                                         window_size,
                                         softcap=0.0,
                                         alibi_slopes=alibi_slopes,
                                         deterministic=False,
                                         rng_state=None)
                else:
                    _flash_attn_backward(d_out,
                                         q_chunk,
                                         k_chunk,
                                         v_chunk,
                                         attn_output_chunk,
                                         lse_chunk,
                                         dq_this,
                                         dk_this,
                                         dv_this,
                                         dropout_p,
                                         softmax_scale,
                                         causal_chunk,
                                         window_size,
                                         alibi_slopes=alibi_slopes,
                                         deterministic=False,
                                         rng_state=None)

                dq[q_i].add_(dq_this.to(torch.float))
                dk[i].add_(dk_this.to(torch.float))
                dv[i].add_(dv_this.to(torch.float))

            dk_seq_len = dk[i].shape[1]
            if ctx.pos_emb_cos is not None:
                dk[i] = apply_rotary_pos_emb_backward(dk[i].to(dtype),
                                                      ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)],
                                                      ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)])
            else:
                dk[i] = dk[i].to(dtype)
            dv[i] = dv[i].to(dtype)

            dk[i] = single_all_to_all(dk[i].contiguous(), gather_idx, scatter_idx, 0, spg)
            dv[i] = single_all_to_all(dv[i].contiguous(), gather_idx, scatter_idx, 0, spg)

            input_st = i * input_chunk_size
            input_ed = input_st + input_chunk_size
            input_chunk = layernorm_output[input_st:input_ed].reshape(-1, layernorm_output.shape[-1])

            dk[i] = dk[i].flatten(2).permute(1, 0, 2)
            dv[i] = dv[i].flatten(2).permute(1, 0, 2)

            l, b = dk[i].shape[0], dk[i].shape[1]
            grad_qkv_linear_weight[projection_size:projection_size + kv_projection_size].add_(
                torch.matmul(dk[i].reshape(l * b, -1).t(), input_chunk))
            grad_qkv_linear_weight[projection_size + kv_projection_size:].add_(
                torch.matmul(dv[i].reshape(l * b, -1).t(), input_chunk))

            grad_qkv_linear_bias[projection_size:projection_size + kv_projection_size].add_(dk[i].sum(0).sum(0))
            grad_qkv_linear_bias[projection_size + kv_projection_size:].add_(dv[i].sum(0).sum(0))

            grad_layernorm_output[i].add_(
                torch.matmul(dk[i], qkv_linear_weight[projection_size:projection_size + kv_projection_size]))
            grad_layernorm_output[i].add_(
                torch.matmul(dv[i], qkv_linear_weight[projection_size + kv_projection_size:]))

            dk[i] = None
            dv[i] = None

        for i in range(num_chunks):
            dq_seq_len = dq[i].shape[1]
            if ctx.pos_emb_cos is not None:
                dq[i] = apply_rotary_pos_emb_backward(dq[i].to(dtype),
                                                      ctx.pos_emb_cos[:, dq_seq_len * i:dq_seq_len * (i + 1)],
                                                      ctx.pos_emb_sin[:, dq_seq_len * i:dq_seq_len * (i + 1)])
            else:
                dq[i] = dq[i].to(dtype)
            dq[i] = single_all_to_all(dq[i].to(dtype).contiguous(), gather_idx, scatter_idx, 0, spg)

            input_chunk = layernorm_output[:input_chunk_size].reshape(-1, layernorm_output.shape[-1])
            layernorm_output = layernorm_output[input_chunk_size:]

            dq[i] = dq[i].flatten(2).permute(1, 0, 2)
            l, b = dq[i].shape[0], dq[i].shape[1]
            grad_qkv_linear_weight[:projection_size].add_(torch.matmul(dq[i].reshape(l * b, -1).t(), input_chunk))
            grad_qkv_linear_bias[:projection_size].add_(dq[i].sum(0).sum(0))
            grad_layernorm_output[i].add_(torch.matmul(dq[i], qkv_linear_weight[:projection_size]))
            dq[i] = None

        return torch.cat(
            grad_layernorm_output,
            dim=0).to(dtype), None, None, None, None, None, None, None, None, None, None, grad_qkv_linear_weight.to(
                dtype), grad_qkv_linear_bias.to(dtype), None, None, None
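
# Scheduling note for _FPDTGPUAttentionImpl_: query chunk i only ever attends to key/value
# chunks 0..i. Off-diagonal blocks (k_i < i) are computed without a mask, the diagonal block
# (k_i == i) with causal=True, and the backward pass skips the (q_i < i) pairs symmetrically,
# so the per-chunk flash-attention calls folded together by update_out_and_lse are equivalent
# to one causal attention over the full gathered sequence. This non-offloading path keeps all
# projected q/k/v chunks resident on the accelerator for the backward pass; the offloading
# variant below trades that memory for host<->device copies.
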

class SequenceChunk:

    def __init__(self, chunk: torch.Tensor, device=None, is_in_use=False):
        self.chunk_shape = chunk.shape
        self.chunk_dtype = chunk.dtype
        self.device = chunk.device if device is None else device

        cpu_chunk = torch.empty(chunk.shape, dtype=chunk.dtype, device='cpu', pin_memory=True)
        if get_accelerator().on_accelerator(chunk):
            cpu_chunk.copy_(chunk, non_blocking=True)
        else:
            cpu_chunk = chunk

        self.cpu_chunk = cpu_chunk
        self.gpu_chunk = chunk if is_in_use else None

    def load_to_gpu(self):
        assert self.gpu_chunk is None
        if self.gpu_chunk is not None:
            pass
        else:
            gpu_chunk = torch.empty(self.chunk_shape, device=self.device, dtype=self.chunk_dtype)
            gpu_chunk.copy_(self.cpu_chunk, non_blocking=True)
            self.gpu_chunk = gpu_chunk

    def get_gpu_chunk(self):
        assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device
        return self.gpu_chunk

    def check_gpu_chunk(self):
        assert (self.gpu_chunk is not None) and (
            self.gpu_chunk.device == self.device
        ), f"gpu_chunk {self.gpu_chunk is not None} should be on {self.device}, but it is now on {self.gpu_chunk.device}"
        return True

    def offload(self):
        assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device
        del self.gpu_chunk
        self.gpu_chunk = None

    def overwrite_to_cpu(self):
        assert self.gpu_chunk is not None and self.gpu_chunk.device == self.device
        self.cpu_chunk.copy_(self.gpu_chunk, non_blocking=True)


class _FPDTGPUOffloadingAttentionImpl_(torch.autograd.Function):
    generate_vmap_rule = False

    @staticmethod
    def forward(ctx: Any,
                layernorm_output,
                attention_mask,
                inference_params,
                rotary_pos_emb,
                spg,
                scatter_idx,
                gather_idx,
                hidden_size,
                projection_size,
                hidden_size_per_attention_head,
                kv_projection_size,
                qkv_linear_weight,
                qkv_linear_bias,
                dropout,
                num_chunks=8,
                cpu_offloading=True):
        do_save = layernorm_output.requires_grad

        if rotary_pos_emb is not None:
            pos_emb_cos, pos_emb_sin = rotary_pos_emb[0].permute(1, 0, 2, 3), rotary_pos_emb[1].permute(1, 0, 2, 3)
            ctx.pos_emb_cos = pos_emb_cos
            ctx.pos_emb_sin = pos_emb_sin
        else:
            ctx.pos_emb_cos = None
            ctx.pos_emb_sin = None

        with torch.no_grad():
            per_gpu_seq_len = layernorm_output.shape[0]
            chunk_size = per_gpu_seq_len // num_chunks
            assert chunk_size * num_chunks == per_gpu_seq_len
            assert attention_mask is None

            ctx.num_chunks = num_chunks
            ctx.cpu_offloading = cpu_offloading
            ctx.spg = spg
            ctx.scatter_idx = scatter_idx
            ctx.gather_idx = gather_idx
            ctx.chunk_size = chunk_size
            device = get_accelerator().current_device_name()
            ctx.device = device
            ctx.dtype = layernorm_output.dtype
            ctx.projection_size = projection_size
            ctx.kv_projection_size = kv_projection_size

            global_q = []
            global_k = []
            global_v = []

            ctx.softmax_scale = hidden_size_per_attention_head**(-0.5)
            ctx.dropout_p = dropout
            ctx.window_size = (-1, -1)
            ctx.alibi_slopes = None

            batch_size = layernorm_output.shape[1]

            global_o = []
            global_lse = []

            layernorm_output_cpu = []
            final_output = []

            offload_stream = get_accelerator().Stream()
            general_offload_stream = get_accelerator().Stream()
            compute_stream = get_accelerator().default_stream()

            q_compute_chunk_idx = 0
            kv_compute_chunk_idx = 0
            for i in range(num_chunks):
                qkv_chunk = torch.matmul(
                    layernorm_output[:chunk_size],
                    qkv_linear_weight.t()) + qkv_linear_bias  # (chunk, batch, projection_size + 2 * kv_projection_size)

                with get_accelerator().stream(general_offload_stream):
                    layernorm_output_cpu.append(SequenceChunk(layernorm_output[:chunk_size]))

                layernorm_output = layernorm_output[chunk_size:]

                q_chunk = qkv_chunk[:, :, :projection_size].contiguous().reshape(
                    qkv_chunk.shape[0], qkv_chunk.shape[1], -1,
                    hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous()  # b, l, nh, hd
                q_chunk = single_all_to_all(q_chunk, scatter_idx, gather_idx, 0, spg)
                global_q_chunk_len = q_chunk.shape[1]

                k_chunk = qkv_chunk[:, :, projection_size:projection_size + kv_projection_size].contiguous().reshape(
                    qkv_chunk.shape[0], qkv_chunk.shape[1], -1,
                    hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous()  # b, l, nh, hd
                k_chunk = single_all_to_all(k_chunk, scatter_idx, gather_idx, 0, spg)

                v_chunk = qkv_chunk[:, :, projection_size + kv_projection_size:].contiguous().reshape(
                    qkv_chunk.shape[0], qkv_chunk.shape[1], -1,
                    hidden_size_per_attention_head).permute(1, 0, 2, 3).contiguous()  # b, l, nh, hd
                v_chunk = single_all_to_all(v_chunk, scatter_idx, gather_idx, 0, spg)

                dist.barrier()

                if ctx.pos_emb_cos is not None:
                    pos_emb_cos_chunk = pos_emb_cos[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]
                    pos_emb_sin_chunk = pos_emb_sin[:, global_q_chunk_len * i:global_q_chunk_len * (i + 1)]
                    q_chunk = apply_rotary_pos_emb(q_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk)
                    k_chunk = apply_rotary_pos_emb(k_chunk, pos_emb_cos_chunk, pos_emb_sin_chunk)

                compute_stream.wait_stream(offload_stream)
                compute_stream.synchronize()

                with get_accelerator().stream(offload_stream):
                    global_q.append(SequenceChunk(q_chunk, is_in_use=True))
                    global_k.append(SequenceChunk(k_chunk, is_in_use=True))
                    global_v.append(SequenceChunk(v_chunk, is_in_use=True))

                del qkv_chunk

                cur_attn_output = None
                cur_attn_lse = None
                for k_i in range(len(global_k)):
                    causal_chunk = i == k_i
                    with get_accelerator().stream(compute_stream):
                        if flash_attn_version >= version.parse("2.6.0"):
                            block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
                                global_q[q_compute_chunk_idx].get_gpu_chunk(),
                                global_k[kv_compute_chunk_idx].get_gpu_chunk(),
                                global_v[kv_compute_chunk_idx].get_gpu_chunk(),
                                ctx.dropout_p,
                                ctx.softmax_scale,
                                causal=causal_chunk,
                                window_size=ctx.window_size,
                                softcap=0.0,
                                alibi_slopes=ctx.alibi_slopes,
                                return_softmax=False)
                        else:
                            block_out, _, _, _, _, block_lse, _, _ = _flash_attn_forward(
                                global_q[q_compute_chunk_idx].get_gpu_chunk(),
                                global_k[kv_compute_chunk_idx].get_gpu_chunk(),
                                global_v[kv_compute_chunk_idx].get_gpu_chunk(),
                                ctx.dropout_p,
                                ctx.softmax_scale,
                                causal=causal_chunk,
                                window_size=ctx.window_size,
                                alibi_slopes=ctx.alibi_slopes,
                                return_softmax=False)

                        cur_attn_output, cur_attn_lse = update_out_and_lse(cur_attn_output, cur_attn_lse, block_out,
                                                                           block_lse)

                    can_offload_kv = True
                    if k_i != (len(global_k) - 1) or i != (num_chunks - 1):
                        if k_i != (len(global_k) - 1):
                            next_kv_compute_chunk_idx = k_i + 1
                        else:
                            next_kv_compute_chunk_idx = 0

                        if next_kv_compute_chunk_idx == kv_compute_chunk_idx:
                            can_offload_kv = False
                        else:
                            if next_kv_compute_chunk_idx != (len(global_k) - 1):
                                with get_accelerator().stream(offload_stream):
                                    global_k[next_kv_compute_chunk_idx].load_to_gpu()
                                    global_v[next_kv_compute_chunk_idx].load_to_gpu()

                    if i == num_chunks - 1 and k_i == num_chunks - 1:
                        with get_accelerator().stream(offload_stream):
                            global_q[0].load_to_gpu()
                            global_k[0].load_to_gpu()
                            global_v[0].load_to_gpu()
                            global_o[0].load_to_gpu()
                            global_lse[0].load_to_gpu()

                    compute_stream.wait_stream(offload_stream)
                    compute_stream.synchronize()

                    if can_offload_kv:
                        global_k[kv_compute_chunk_idx].offload()
                        global_v[kv_compute_chunk_idx].offload()

                    kv_compute_chunk_idx = next_kv_compute_chunk_idx

                global_q[q_compute_chunk_idx].offload()
                q_compute_chunk_idx += 1
                all2all_output = single_all_to_all(cur_attn_output.to(ctx.dtype).contiguous(), gather_idx, scatter_idx,
                                                   0, spg)
                final_output.append(all2all_output)

                with get_accelerator().stream(general_offload_stream):
                    global_o.append(SequenceChunk(cur_attn_output.to(ctx.dtype)))
                    global_lse.append(SequenceChunk(cur_attn_lse[:, :, :, 0].permute(0, 2, 1).contiguous()))

            compute_stream.wait_stream(general_offload_stream)
            compute_stream.synchronize()

            final_output = torch.cat(final_output, dim=1)
            head_dim = final_output.shape[-1]

        if do_save:
            ctx.layernorm_output = layernorm_output_cpu
            ctx.global_q = global_q
            ctx.global_k = global_k
            ctx.global_v = global_v
            ctx.attn_output = global_o
            ctx.attn_lse = global_lse
            ctx.head_dim = head_dim
            ctx.batch_size = batch_size
            ctx.qkv_linear_weight = qkv_linear_weight
            ctx.qkv_linear_bias = qkv_linear_bias

        return final_output

    @staticmethod
    def backward(ctx, grad_output):
        num_chunks = ctx.num_chunks
        device = grad_output.device
        dtype = ctx.dtype
        spg = ctx.spg
        scatter_idx = ctx.scatter_idx
        gather_idx = ctx.gather_idx

        softmax_scale = ctx.softmax_scale
        dropout_p = ctx.dropout_p
        window_size = ctx.window_size
        alibi_slopes = ctx.alibi_slopes

        projection_size = ctx.projection_size
        kv_projection_size = ctx.kv_projection_size

        layernorm_output = ctx.layernorm_output
        global_q = ctx.global_q
        global_k = ctx.global_k
        global_v = ctx.global_v
        attn_output = ctx.attn_output
        lse = ctx.attn_lse

        qkv_linear_weight = ctx.qkv_linear_weight
        qkv_linear_bias = ctx.qkv_linear_bias

        offload_stream = get_accelerator().Stream()
        general_offload_stream = get_accelerator().Stream()
        compute_stream = get_accelerator().default_stream()

        chunk_size = grad_output.shape[1] // num_chunks
        assert chunk_size == layernorm_output[0].cpu_chunk.shape[0]

        grad_layernorm_output = [
            torch.zeros(layernorm_output[0].chunk_shape, device=device, dtype=dtype) for _ in range(num_chunks)
        ]

        grad_global_attn_output = [None for _ in range(num_chunks)]

        q_compute_chunk_idx = 0
        kv_compute_chunk_idx = 0
        last_q_accum_idx = 0

        with get_accelerator().stream(general_offload_stream):
            layernorm_output[0].load_to_gpu()

        grad_qkv_linear_weight = torch.zeros(qkv_linear_weight.shape,
                                             device=qkv_linear_weight.device,
                                             dtype=torch.float)
        grad_qkv_linear_bias = torch.zeros(qkv_linear_bias.shape, device=qkv_linear_weight.device, dtype=torch.float)

        grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(), scatter_idx,
                                                          gather_idx, 0, spg)
        get_accelerator().synchronize()

        grad_output = grad_output[:, chunk_size:]

        with get_accelerator().stream(offload_stream):
            grad_global_attn_output[0] = SequenceChunk(grad_global_attn_output_chunk, is_in_use=True)

        dq = [
            SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device=device), is_in_use=True)
        ] + [
            SequenceChunk(torch.zeros(global_q[0].chunk_shape, dtype=torch.float, device='cpu', pin_memory=True),
                          device) for _ in range(num_chunks - 1)
        ]

        dk_accum = torch.zeros(global_k[0].chunk_shape, dtype=torch.float, device=device)
        dv_accum = torch.zeros(global_v[0].chunk_shape, dtype=torch.float, device=device)

        for i in range(num_chunks):
            for q_i in range(num_chunks):
                no_computation = q_i < i
                if no_computation:
                    continue

                causal_chunk = q_i == i

                dq_this = torch.zeros(global_q[0].chunk_shape, dtype=dtype, device=device)
                dk_this = torch.zeros(global_k[0].chunk_shape, dtype=dtype, device=device)
                dv_this = torch.zeros(global_v[0].chunk_shape, dtype=dtype, device=device)
                with get_accelerator().stream(compute_stream):
                    if flash_attn_version >= version.parse("2.6.0"):
                        _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(),
                                             global_q[q_compute_chunk_idx].get_gpu_chunk(),
                                             global_k[kv_compute_chunk_idx].get_gpu_chunk(),
                                             global_v[kv_compute_chunk_idx].get_gpu_chunk(),
                                             attn_output[q_compute_chunk_idx].get_gpu_chunk(),
                                             lse[q_compute_chunk_idx].get_gpu_chunk(),
                                             dq_this,
                                             dk_this,
                                             dv_this,
                                             dropout_p,
                                             softmax_scale,
                                             causal_chunk,
                                             window_size,
                                             softcap=0.0,
                                             alibi_slopes=alibi_slopes,
                                             deterministic=False,
                                             rng_state=None)
                    else:
                        _flash_attn_backward(grad_global_attn_output[q_compute_chunk_idx].get_gpu_chunk(),
                                             global_q[q_compute_chunk_idx].get_gpu_chunk(),
                                             global_k[kv_compute_chunk_idx].get_gpu_chunk(),
                                             global_v[kv_compute_chunk_idx].get_gpu_chunk(),
                                             attn_output[q_compute_chunk_idx].get_gpu_chunk(),
                                             lse[q_compute_chunk_idx].get_gpu_chunk(),
                                             dq_this,
                                             dk_this,
                                             dv_this,
                                             dropout_p,
                                             softmax_scale,
                                             causal_chunk,
                                             window_size,
                                             alibi_slopes=alibi_slopes,
                                             deterministic=False,
                                             rng_state=None)

                if i != (len(global_k) - 1):
                    if q_i != (len(global_q) - 1):
                        next_q_compute_chunk_idx = q_i + 1
                    else:
                        next_q_compute_chunk_idx = i + 1

                    can_offload_q = True
                    if next_q_compute_chunk_idx == q_compute_chunk_idx:
                        can_offload_q = False
                    else:
                        with get_accelerator().stream(offload_stream):
                            if i > 0 or q_i > 0:
                                if can_offload_q and last_q_accum_idx != i:
                                    # the first q chunk calculated in the loop will be sent out,
                                    # therefore we do not offload it
                                    dq[last_q_accum_idx].offload()
                            dq[next_q_compute_chunk_idx].load_to_gpu()

                            global_q[next_q_compute_chunk_idx].load_to_gpu()
                            attn_output[next_q_compute_chunk_idx].load_to_gpu()
                            lse[next_q_compute_chunk_idx].load_to_gpu()
                            if grad_global_attn_output[next_q_compute_chunk_idx] is not None:
                                grad_global_attn_output[next_q_compute_chunk_idx].load_to_gpu()

                    if grad_global_attn_output[next_q_compute_chunk_idx] is None:
                        grad_global_attn_output_chunk = single_all_to_all(grad_output[:, :chunk_size].contiguous(),
                                                                          scatter_idx, gather_idx, 0, spg)
                        dist.barrier()
                        grad_output = grad_output[:, chunk_size:]
                        grad_global_attn_output[next_q_compute_chunk_idx] = SequenceChunk(
                            grad_global_attn_output_chunk, is_in_use=True)

                compute_stream.wait_stream(offload_stream)
                compute_stream.synchronize()

                with get_accelerator().stream(compute_stream):
                    dq[q_compute_chunk_idx].check_gpu_chunk()
                    dq[q_compute_chunk_idx].gpu_chunk.add_(dq_this)
                    dk_accum.add_(dk_this)
                    dv_accum.add_(dv_this)

                offload_stream.wait_stream(compute_stream)
                with get_accelerator().stream(offload_stream):
                    dq[q_compute_chunk_idx].overwrite_to_cpu()

                if can_offload_q:
                    global_q[q_compute_chunk_idx].offload()
                    attn_output[q_compute_chunk_idx].offload()
                    lse[q_compute_chunk_idx].offload()
                    grad_global_attn_output[q_compute_chunk_idx].offload()
                last_q_accum_idx = q_compute_chunk_idx
                q_compute_chunk_idx = next_q_compute_chunk_idx

            compute_stream.wait_stream(offload_stream)
            compute_stream.synchronize()

            dk_seq_len = dk_accum.shape[1]

            if ctx.pos_emb_cos is not None:
                dq_accum = apply_rotary_pos_emb_backward(dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype),
                                                         ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)],
                                                         ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)])
                dk_accum = apply_rotary_pos_emb_backward(dk_accum.to(dtype),
                                                         ctx.pos_emb_cos[:, dk_seq_len * i:dk_seq_len * (i + 1)],
                                                         ctx.pos_emb_sin[:, dk_seq_len * i:dk_seq_len * (i + 1)])
            else:
                dq_accum = dq[kv_compute_chunk_idx].get_gpu_chunk().to(dtype)
                dk_accum = dk_accum.to(dtype)
            dv_accum = dv_accum.to(dtype)

            dq_accum = single_all_to_all(dq_accum.contiguous(), gather_idx, scatter_idx, 0, spg)
            dk_accum = single_all_to_all(dk_accum.contiguous(), gather_idx, scatter_idx, 0, spg)
            dv_accum = single_all_to_all(dv_accum.contiguous(), gather_idx, scatter_idx, 0, spg)

            general_offload_stream.synchronize()
            compute_stream.wait_stream(general_offload_stream)
            dist.barrier()

            with get_accelerator().stream(compute_stream):
                input_chunk = layernorm_output[i].get_gpu_chunk().reshape(-1, layernorm_output[i].chunk_shape[-1])

                dq_accum = dq_accum.flatten(2).permute(1, 0, 2)
                dk_accum = dk_accum.flatten(2).permute(1, 0, 2)
                dv_accum = dv_accum.flatten(2).permute(1, 0, 2)

                l, b = dk_accum.shape[0], dk_accum.shape[1]
                grad_qkv_linear_weight[:projection_size].add_(
                    torch.matmul(dq_accum.reshape(l * b, -1).t(), input_chunk))
                grad_qkv_linear_weight[projection_size:projection_size + kv_projection_size].add_(
                    torch.matmul(dk_accum.reshape(l * b, -1).t(), input_chunk))
                grad_qkv_linear_weight[projection_size + kv_projection_size:].add_(
                    torch.matmul(dv_accum.reshape(l * b, -1).t(), input_chunk))

                grad_qkv_linear_bias[:projection_size].add_(dq_accum.sum(0).sum(0))
                grad_qkv_linear_bias[projection_size:projection_size + kv_projection_size].add_(
                    dk_accum.sum(0).sum(0))
                grad_qkv_linear_bias[projection_size + kv_projection_size:].add_(dv_accum.sum(0).sum(0))

                grad_layernorm_output[i].add_(torch.matmul(dq_accum, qkv_linear_weight[:projection_size]))
                grad_layernorm_output[i].add_(
                    torch.matmul(dk_accum, qkv_linear_weight[projection_size:projection_size + kv_projection_size]))
                grad_layernorm_output[i].add_(
                    torch.matmul(dv_accum, qkv_linear_weight[projection_size + kv_projection_size:]))

            del dq_accum, dk_accum, dv_accum
            dk_accum = torch.zeros(global_k[i].chunk_shape, dtype=torch.float, device=device)
            dv_accum = torch.zeros(global_v[i].chunk_shape, dtype=torch.float, device=device)

            dq[kv_compute_chunk_idx].offload()
            dq[kv_compute_chunk_idx] = None

            if i != (len(global_k) - 1):
                next_kv_compute_chunk_idx = kv_compute_chunk_idx + 1

                with get_accelerator().stream(offload_stream):
                    global_k[next_kv_compute_chunk_idx].load_to_gpu()
                    global_v[next_kv_compute_chunk_idx].load_to_gpu()

                with get_accelerator().stream(general_offload_stream):
                    layernorm_output[next_kv_compute_chunk_idx].load_to_gpu()

                compute_stream.wait_stream(offload_stream)
                compute_stream.synchronize()

                layernorm_output[kv_compute_chunk_idx].offload()
                global_k[kv_compute_chunk_idx].offload()
                global_v[kv_compute_chunk_idx].offload()
                kv_compute_chunk_idx = next_kv_compute_chunk_idx

        return torch.cat(
            grad_layernorm_output,
            dim=0).to(dtype), None, None, None, None, None, None, None, None, None, None, grad_qkv_linear_weight.to(
                dtype), grad_qkv_linear_bias.to(dtype), None, None, None
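
# The offloading variant above keeps only the chunks needed by the current block on the
# accelerator and parks everything else in pinned host memory through SequenceChunk. Prefetch
# of the next q/k/v chunks (and, in backward, of the saved output/lse/grad chunks) is issued on
# `offload_stream` / `general_offload_stream` while flash-attention for the current block runs
# on the default compute stream, so host<->device copies overlap with compute; the explicit
# wait_stream()/synchronize() calls mark the points where the two pipelines must meet before a
# chunk is reused or evicted.
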

class FPDT_Attention(torch.nn.Module):

    def __init__(self,
                 config,
                 first_weight,
                 first_bias,
                 second_weight,
                 second_bias,
                 sequence_process_group,
                 gather_idx: int = 0,
                 scatter_idx: int = 2,
                 return_bias=True,
                 chunk_size=65536,
                 enable_offloading=True) -> None:
        super(FPDT_Attention, self).__init__()
        if _flash_attn_forward is None or _flash_attn_backward is None:
            raise ImportError(
                "DeepSpeed FPDT requires flash-attn 2.6.3. Please install it with `pip install flash-attn --no-build-isolation`."
            )
        self.spg = sequence_process_group
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

        self.config = config

        self.projection_size = config.kv_channels * config.num_attention_heads
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.kv_projection_size = config.kv_channels * config.num_key_value_heads
        self.hidden_size = config.hidden_size

        self.qkv_linear_weight = first_weight
        self.qkv_linear_bias = first_bias

        self.qkv_dense_weight = second_weight
        self.qkv_dense_bias = second_bias

        self.return_bias = return_bias
        self.dropout = config.attention_dropout

        self.chunk_size = chunk_size
        self.double_buffer = enable_offloading

    def forward(self,
                layernorm_output,
                attention_mask,
                inference_params,
                rotary_pos_emb,
                cpu_offloading=True) -> Tensor:
        self.num_chunks_attn = layernorm_output.shape[0] * dist.get_world_size(self.spg) // self.chunk_size
        if not cpu_offloading or self.num_chunks_attn == 1:
            output = _FPDTGPUAttentionImpl_.apply(layernorm_output, attention_mask, inference_params, rotary_pos_emb,
                                                  self.spg, self.scatter_idx, self.gather_idx, self.hidden_size,
                                                  self.projection_size, self.hidden_size_per_attention_head,
                                                  self.kv_projection_size, self.qkv_linear_weight,
                                                  self.qkv_linear_bias, self.dropout, self.num_chunks_attn,
                                                  cpu_offloading)
        else:
            output = _FPDTGPUOffloadingAttentionImpl_.apply(layernorm_output, attention_mask, inference_params,
                                                            rotary_pos_emb, self.spg, self.scatter_idx,
                                                            self.gather_idx, self.hidden_size, self.projection_size,
                                                            self.hidden_size_per_attention_head,
                                                            self.kv_projection_size, self.qkv_linear_weight,
                                                            self.qkv_linear_bias, self.dropout, self.num_chunks_attn,
                                                            cpu_offloading)

        output = output.flatten(2).permute(1, 0, 2).contiguous()

        output = torch.matmul(output, self.qkv_dense_weight.t())
        if not self.return_bias:
            output += self.qkv_dense_bias
        return output, self.qkv_dense_bias if self.return_bias else None


@torch.jit.script
def bias_gelu(x):
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)))


@torch.jit.script
def bias_gelu_back(g, x):
    tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)
    return ff * g
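
# Minimal sanity-check sketch for the two jit-scripted helpers above, kept as a comment so it
# is not executed on import. `bias_gelu` is the tanh approximation of GELU, so it should agree
# with torch.nn.functional.gelu(x, approximate='tanh'), and `bias_gelu_back` should agree with
# the autograd derivative of that expression; the tensor names and tolerances below are
# illustrative assumptions, not part of this module:
#
#   x = torch.randn(64, dtype=torch.double, requires_grad=True)
#   assert torch.allclose(bias_gelu(x), torch.nn.functional.gelu(x, approximate='tanh'), atol=1e-6)
#   g = torch.randn_like(x)
#   (bias_gelu(x) * g.detach()).sum().backward()
#   assert torch.allclose(bias_gelu_back(g, x.detach()), x.grad, atol=1e-6)
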

class FPDT_FFN(torch.autograd.Function):
    generate_vmap_rule = False

    @staticmethod
    def forward(ctx: Any, x, w1, b1, w2, b2, add_bias, chunk_size):
        do_save = x.requires_grad
        ctx.add_bias = add_bias
        device = x.device
        with torch.no_grad():
            num_chunk = x.shape[0] // chunk_size
            ctx.num_chunk = num_chunk

            result = torch.empty(x.shape, device=device, dtype=x.dtype)
            assert chunk_size * num_chunk == x.shape[0]
            for i in range(num_chunk):
                st = i * chunk_size
                ed = st + chunk_size
                x_ = torch.matmul(x[st:ed], w1.t()) + b1
                x_ = bias_gelu(x_)
                if add_bias:
                    result[st:ed] = torch.matmul(x_, w2.t()) + b2
                else:
                    result[st:ed] = torch.matmul(x_, w2.t())
                del x_

        if do_save:
            ctx.device = device
            ctx.dtype = x.dtype
            ctx.save_for_backward(x, w1, b1, w2, b2)
            ctx.grad_x_shape = x.shape

        return result.to(x.dtype), b2 if not add_bias else None

    @staticmethod
    def backward(ctx, grad_output, grad_bias):
        x, w1, b1, w2, b2 = ctx.saved_tensors
        device = ctx.device
        dtype = ctx.dtype
        add_bias = ctx.add_bias

        num_chunk = ctx.num_chunk
        chunk_size = x.shape[0] // num_chunk
        assert chunk_size * num_chunk == grad_output.shape[0]

        grad_w2 = torch.zeros(w2.shape, device=device, dtype=torch.float)
        grad_b2 = torch.zeros(b2.shape, device=device, dtype=torch.float)
        grad_w1 = torch.zeros(w1.shape, device=device, dtype=torch.float)
        grad_b1 = torch.zeros(b1.shape, device=device, dtype=torch.float)

        for i in range(num_chunk):
            st = i * chunk_size
            ed = st + chunk_size
            x_chunk = x[st:ed]

            before_act = (torch.matmul(x_chunk, w1.t()) + b1)
            before_act_2 = before_act**2
            tanh_out = torch.tanh(0.79788456 * before_act * (1 + 0.044715 * before_act_2))
            ff = 0.5 * before_act * ((1 - tanh_out * tanh_out) *
                                     (0.79788456 + 0.1070322243 * before_act_2)) + 0.5 * (1 + tanh_out)

            grad_w2.add_(
                torch.matmul(grad_output[st:ed].reshape(-1, grad_output.shape[2]).t(),
                             (before_act * 0.5 * (1 + tanh_out)).reshape(-1, before_act.shape[2])))

            del before_act, before_act_2, tanh_out

            grad_inter = torch.matmul(grad_output[st:ed], w2) * ff
            del ff

            grad_w1.add_(
                torch.matmul(grad_inter.reshape(-1, grad_inter.shape[2]).t(), x_chunk.reshape(-1, x.shape[2])))
            grad_b1.add_(grad_inter.sum(0).sum(0))

            x[st:ed].copy_(torch.matmul(grad_inter, w1))
            del grad_inter

            if add_bias:
                grad_b2.add_(grad_output[st:ed].sum(0).sum(0))

        return x, grad_w1.to(dtype), grad_b1.to(dtype), grad_w2.to(dtype), grad_b2.to(dtype), None, None
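
# Note on FPDT_FFN: the backward pass recomputes the pre-activation (matmul + bias and the tanh
# terms) chunk by chunk instead of saving the intermediate activation during forward, and it
# reuses the saved input tensor `x` as the buffer for the returned input gradient
# (x[st:ed].copy_(...)). This trades extra FLOPs for activation memory, which is the point of
# chunking the FFN for very long sequences.
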

class FPDT_LogitsLoss(torch.autograd.Function):
    generate_vmap_rule = False

    @staticmethod
    def forward(ctx: Any, lm_output, labels, logit_weights, rank, spg_size, spg, num_chunk):
        labels = labels.t()
        chunk_size = lm_output.shape[0] // num_chunk
        assert chunk_size * num_chunk == lm_output.shape[0]
        batch_size, local_seq_len = lm_output.shape[1], lm_output.shape[0]
        loss = torch.empty((batch_size, local_seq_len), dtype=torch.float, device=lm_output.device)

        ctx.num_chunk = num_chunk
        ctx.chunk_size = chunk_size
        ctx.device = lm_output.device
        ctx.dtype = lm_output.dtype
        ctx.rank = rank
        ctx.local_seq_len = local_seq_len

        with torch.no_grad():
            for i in range(num_chunk):
                st = i * chunk_size
                ed = st + chunk_size
                logits_chunk = torch.matmul(lm_output[st:ed], logit_weights.t()).float()
                vocab_size = logits_chunk.size(2)

                # nll
                softmax = torch.nn.functional.softmax(logits_chunk, dim=-1)
                loss_chunk = torch.nn.functional.nll_loss(softmax.log().reshape(-1, vocab_size).contiguous(),
                                                          labels[st:ed, :].reshape(-1).contiguous(),
                                                          reduction='none')
                loss[:, st:ed] = loss_chunk.reshape(chunk_size, batch_size).t()
                del logits_chunk

        ctx.save_for_backward(lm_output.to('cpu'), labels)
        ctx.logit_weights = logit_weights

        seqlen = local_seq_len * spg_size
        batch_size = loss.size(0)
        loss = loss.t().contiguous()
        loss_all = torch.empty(seqlen, batch_size, dtype=loss.dtype, device=loss.device).contiguous()

        dist.allgather_fn(loss_all, loss, group=spg)

        return loss_all

    @staticmethod
    def backward(ctx, grad_output):
        lm_output, labels = ctx.saved_tensors
        logit_weights = ctx.logit_weights
        device = ctx.device
        dtype = ctx.dtype
        num_chunk = ctx.num_chunk
        chunk_size = ctx.chunk_size
        rank = ctx.rank
        local_seq_len = ctx.local_seq_len

        grad_output = grad_output[rank * local_seq_len:(rank + 1) * local_seq_len]

        grad_lm_output = [None for _ in range(num_chunk)]
        grad_logit_weights = torch.zeros(logit_weights.shape, device=grad_output.device, dtype=torch.float)

        for i in range(num_chunk):
            st = i * chunk_size
            ed = st + chunk_size

            lm_output_chunk = lm_output[st:ed].to(device)
            logits_chunk = torch.matmul(lm_output_chunk, logit_weights.t()).float()

            # nll
            softmax = torch.nn.functional.softmax(logits_chunk, dim=-1)
            vocab_size = logits_chunk.size(2)

            grad_input = softmax
            grad_2d = grad_input.reshape(-1, vocab_size).contiguous()
            arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=device)

            grad_2d[arange_1d, labels[st:ed, :].reshape(-1).contiguous()] -= 1

            grad_input.mul_(grad_output[:chunk_size, :].unsqueeze(dim=-1))
            grad_input = grad_input.to(dtype)
            grad_output = grad_output[chunk_size:].contiguous()

            grad_lm_output_chunk = torch.matmul(grad_input, logit_weights)
            grad_lm_output[i] = grad_lm_output_chunk

            grad_logit_weights.add_(
                torch.matmul(grad_input.reshape(-1, grad_input.shape[2]).t(),
                             lm_output_chunk.reshape(-1, lm_output_chunk.shape[2])))

        return torch.cat(grad_lm_output, dim=0).to(dtype), None, grad_logit_weights.to(dtype), None, None, None, None