|
import math |
|
|
|
import pytest |
|
import torch |
|
import torch.nn.functional as F |
|
from einops import rearrange, repeat |
|
from flash_attn import ( |
|
flash_attn_func, |
|
flash_attn_kvpacked_func, |
|
flash_attn_qkvpacked_func, |
|
flash_attn_varlen_func, |
|
flash_attn_varlen_kvpacked_func, |
|
flash_attn_varlen_qkvpacked_func, |
|
flash_attn_with_kvcache, |
|
) |
|
from flash_attn.bert_padding import pad_input, unpad_input |
|
from flash_attn.flash_attn_interface import _get_block_size_n |
|
from flash_attn.layers.rotary import apply_rotary_emb |
|
|
|
MAX_HEADDIM_SM8x = 192 |
|
|
|
|
|
is_sm75 = torch.cuda.get_device_capability("cuda") == (7, 5) |
|
is_sm8x = torch.cuda.get_device_capability("cuda")[0] == 8 |
|
is_sm80 = torch.cuda.get_device_capability("cuda") == (8, 0) |
|
is_sm90 = torch.cuda.get_device_capability("cuda") == (9, 0) |
|
|
|
|
|
def attn_bias_from_alibi_slopes( |
|
slopes, seqlen_q, seqlen_k, query_padding_mask=None, key_padding_mask=None, causal=False, key_leftpad=None |
|
): |
|
batch, nheads = slopes.shape |
|
device = slopes.device |
|
slopes = rearrange(slopes, "b h -> b h 1 1") |
|
if causal: |
|
return torch.arange(-seqlen_k + 1, 1, device=device, dtype=torch.float32) * slopes |
|
else: |
|
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") |
|
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) |
|
if key_leftpad is not None: |
|
key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") |
|
col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) |
|
col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) |
|
sk = ( |
|
seqlen_k |
|
if key_padding_mask is None |
|
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") |
|
) |
|
sq = ( |
|
seqlen_q |
|
if query_padding_mask is None |
|
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") |
|
) |
|
relative_pos = torch.abs(row_idx + sk - sq - col_idx) |
|
return -slopes * relative_pos.to(dtype=slopes.dtype) |
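

# Tiny, hand-runnable sketch (not collected by pytest) of the bias layout this helper
# produces; the batch/head/sequence sizes below are arbitrary and it runs on CPU.
def _example_attn_bias_from_alibi_slopes():
    slopes = torch.rand(2, 3)  # (batch_size, nheads)
    # Non-causal: one bias value per (batch, head, query position, key position).
    bias = attn_bias_from_alibi_slopes(slopes, seqlen_q=4, seqlen_k=6)
    assert bias.shape == (2, 3, 4, 6)
    # Causal: the bias depends only on the key position, so the query dim is broadcast.
    bias_causal = attn_bias_from_alibi_slopes(slopes, seqlen_q=4, seqlen_k=6, causal=True)
    assert bias_causal.shape == (2, 3, 1, 6)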
|
|
|
|
|
def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): |
|
assert mode in ["full", "random", "third"] |
|
if mode == "full": |
|
lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) |
|
elif mode == "random": |
|
lengths = torch.randint( |
|
max(1, max_seqlen - 20), max_seqlen + 1, (batch_size, 1), device=device |
|
) |
|
elif mode == "third": |
|
lengths = torch.randint(max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device) |
|
padding_mask = ( |
|
repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths |
|
) |
|
return padding_mask |
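

# Tiny, hand-runnable sketch (not collected by pytest) of the padding-mask helper,
# using an arbitrary small size on CPU; "full" mode keeps every position.
def _example_generate_random_padding_mask():
    mask = generate_random_padding_mask(8, batch_size=2, device="cpu", mode="full")
    assert mask.shape == (2, 8)
    assert mask.dtype == torch.bool and bool(mask.all())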
|
|
|
|
|
def generate_qkv( |
|
q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False |
|
): |
|
""" |
|
Arguments: |
|
q: (batch_size, seqlen_q, nheads, d) |
|
k: (batch_size, seqlen_k, nheads_k, d) |
|
v: (batch_size, seqlen_k, nheads_k, d) |
|
        query_padding_mask: (batch_size, seqlen_q), bool
        key_padding_mask: (batch_size, seqlen_k), bool
|
""" |
|
assert not (kvpacked and qkvpacked) |
|
batch_size, seqlen_q, nheads, d = q.shape |
|
_, seqlen_k, nheads_k, _ = k.shape |
|
assert k.shape == (batch_size, seqlen_k, nheads_k, d) |
|
assert v.shape == (batch_size, seqlen_k, nheads_k, d) |
|
|
|
if query_padding_mask is not None: |
|
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, _ = unpad_input(q, query_padding_mask) |
|
output_pad_fn = lambda output_unpad: pad_input( |
|
output_unpad, indices_q, batch_size, seqlen_q |
|
) |
|
else: |
|
q_unpad = rearrange(q, "b s h d -> (b s) h d") |
|
cu_seqlens_q = torch.arange( |
|
0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device |
|
) |
|
max_seqlen_q = seqlen_q |
|
output_pad_fn = lambda output_unpad: rearrange( |
|
output_unpad, "(b s) h d -> b s h d", b=batch_size |
|
) |
|
|
|
if key_padding_mask is not None: |
|
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, _ = unpad_input(k, key_padding_mask) |
|
v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask) |
|
else: |
|
k_unpad = rearrange(k, "b s h d -> (b s) h d") |
|
v_unpad = rearrange(v, "b s h d -> (b s) h d") |
|
cu_seqlens_k = torch.arange( |
|
0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device |
|
) |
|
max_seqlen_k = seqlen_k |
|
|
|
if qkvpacked: |
|
assert (query_padding_mask == key_padding_mask).all() |
|
assert nheads == nheads_k |
|
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) |
|
qkv = torch.stack([q, k, v], dim=2) |
|
if query_padding_mask is not None: |
|
dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) |
|
else: |
|
dqkv_pad_fn = lambda dqkv_unpad: rearrange( |
|
dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size |
|
) |
|
return ( |
|
qkv_unpad.detach().requires_grad_(), |
|
cu_seqlens_q, |
|
max_seqlen_q, |
|
qkv.detach().requires_grad_(), |
|
output_pad_fn, |
|
dqkv_pad_fn, |
|
) |
|
elif kvpacked: |
|
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) |
|
kv = torch.stack([k, v], dim=2) |
|
dq_pad_fn = output_pad_fn |
|
if key_padding_mask is not None: |
|
dkv_pad_fn = lambda dkv_unpad: pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) |
|
else: |
|
dkv_pad_fn = lambda dkv_unpad: rearrange( |
|
dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size |
|
) |
|
return ( |
|
q_unpad.detach().requires_grad_(), |
|
kv_unpad.detach().requires_grad_(), |
|
cu_seqlens_q, |
|
cu_seqlens_k, |
|
max_seqlen_q, |
|
max_seqlen_k, |
|
q.detach().requires_grad_(), |
|
kv.detach().requires_grad_(), |
|
output_pad_fn, |
|
dq_pad_fn, |
|
dkv_pad_fn, |
|
) |
|
else: |
|
dq_pad_fn = output_pad_fn |
|
if key_padding_mask is not None: |
|
dk_pad_fn = lambda dk_unpad: pad_input(dk_unpad, indices_k, batch_size, seqlen_k) |
|
else: |
|
dk_pad_fn = lambda dk_unpad: rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) |
|
return ( |
|
q_unpad.detach().requires_grad_(), |
|
k_unpad.detach().requires_grad_(), |
|
v_unpad.detach().requires_grad_(), |
|
cu_seqlens_q, |
|
cu_seqlens_k, |
|
max_seqlen_q, |
|
max_seqlen_k, |
|
q.detach().requires_grad_(), |
|
k.detach().requires_grad_(), |
|
v.detach().requires_grad_(), |
|
output_pad_fn, |
|
dq_pad_fn, |
|
dk_pad_fn, |
|
) |
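

# Tiny, hand-runnable sketch (not collected by pytest) of the unpadded layout and
# cumulative sequence lengths generate_qkv returns when no padding masks are given.
# The shapes are arbitrary and everything stays on CPU.
def _example_generate_qkv_no_padding():
    q = torch.randn(2, 4, 3, 8)  # (batch_size, seqlen_q, nheads, d)
    k = torch.randn(2, 6, 3, 8)  # (batch_size, seqlen_k, nheads_k, d)
    v = torch.randn(2, 6, 3, 8)
    (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k,
     max_seqlen_q, max_seqlen_k, *_) = generate_qkv(q, k, v)
    assert q_unpad.shape == (2 * 4, 3, 8) and k_unpad.shape == (2 * 6, 3, 8)
    assert cu_seqlens_q.tolist() == [0, 4, 8] and cu_seqlens_k.tolist() == [0, 6, 12]
    assert (max_seqlen_q, max_seqlen_k) == (4, 6)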
|
|
|
|
|
def construct_local_mask( |
|
seqlen_q, |
|
seqlen_k, |
|
window_size=(-1, -1), |
|
query_padding_mask=None, |
|
key_padding_mask=None, |
|
device=None, |
|
key_leftpad=None, |
|
): |
|
row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") |
|
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) |
|
if key_leftpad is not None: |
|
key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") |
|
col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) |
|
col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) |
|
sk = ( |
|
seqlen_k |
|
if key_padding_mask is None |
|
else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") |
|
) |
|
sq = ( |
|
seqlen_q |
|
if query_padding_mask is None |
|
else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") |
|
) |
|
if window_size[0] < 0: |
|
return col_idx > row_idx + sk - sq + window_size[1] |
|
else: |
|
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk |
|
return torch.logical_or( |
|
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), |
|
col_idx < row_idx + sk - sq - window_size[0], |
|
) |
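

# Tiny, hand-runnable sketch (not collected by pytest) of the local-window mask:
# window_size=(-1, 0) is the causal case, and True marks entries that get masked out.
def _example_construct_local_mask_causal():
    mask = construct_local_mask(3, 3, window_size=(-1, 0), device="cpu")
    expected = torch.tensor(
        [
            [False, True, True],
            [False, False, True],
            [False, False, False],
        ]
    )
    assert torch.equal(mask, expected)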
|
|
|
|
|
def attention_ref( |
|
q, |
|
k, |
|
v, |
|
query_padding_mask=None, |
|
key_padding_mask=None, |
|
attn_bias=None, |
|
dropout_p=0.0, |
|
dropout_mask=None, |
|
causal=False, |
|
window_size=(-1, -1), |
|
softcap=0.0, |
|
upcast=True, |
|
reorder_ops=False, |
|
key_leftpad=None, |
|
): |
|
""" |
|
Arguments: |
|
q: (batch_size, seqlen_q, nheads, head_dim) |
|
k: (batch_size, seqlen_k, nheads_k, head_dim) |
|
v: (batch_size, seqlen_k, nheads_k, head_dim) |
|
query_padding_mask: (batch_size, seqlen_q) |
|
key_padding_mask: (batch_size, seqlen_k) |
|
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) |
|
dropout_p: float |
|
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) |
|
causal: whether to apply causal masking |
|
window_size: (int, int), left and right window size |
|
upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast |
|
output back to fp16/bf16. |
|
reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) |
|
without changing the math. This is to estimate the numerical error from operation |
|
reordering. |
|
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax weights
            (dropout, when used, is applied to the output but not to this returned tensor)
    """
|
if causal: |
|
window_size = (window_size[0], 0) |
|
dtype_og = q.dtype |
|
if upcast: |
|
q, k, v = q.float(), k.float(), v.float() |
|
seqlen_q, seqlen_k = q.shape[1], k.shape[1] |
|
k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) |
|
v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) |
|
d = q.shape[-1] |
|
if not reorder_ops: |
|
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) |
|
else: |
|
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) |
|
if softcap > 0: |
|
scores = scores / softcap |
|
scores = scores.tanh() |
|
scores = scores * softcap |
|
if key_padding_mask is not None: |
|
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) |
|
if window_size[0] >= 0 or window_size[1] >= 0: |
|
local_mask = construct_local_mask( |
|
seqlen_q, |
|
seqlen_k, |
|
window_size, |
|
query_padding_mask, |
|
key_padding_mask, |
|
q.device, |
|
key_leftpad=key_leftpad, |
|
) |
|
scores.masked_fill_(local_mask, float("-inf")) |
|
if attn_bias is not None: |
|
scores = scores + attn_bias |
|
attention = torch.softmax(scores, dim=-1).to(v.dtype) |
|
|
|
if window_size[0] >= 0 or window_size[1] >= 0: |
|
attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) |
|
|
|
|
|
if query_padding_mask is not None: |
|
attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) |
|
dropout_scaling = 1.0 / (1 - dropout_p) |
|
|
|
|
|
if dropout_mask is not None: |
|
attention_drop = attention.masked_fill(~dropout_mask, 0.0) |
|
else: |
|
attention_drop = attention |
|
output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) |
|
if query_padding_mask is not None: |
|
output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) |
|
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) |
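

# Tiny, hand-runnable sketch (not collected by pytest) of the reference attention's
# input/output shapes, using arbitrary small sizes on CPU.
def _example_attention_ref_shapes():
    q = torch.randn(2, 4, 3, 8)  # (batch_size, seqlen_q, nheads, head_dim)
    k = torch.randn(2, 6, 3, 8)  # (batch_size, seqlen_k, nheads_k, head_dim)
    v = torch.randn(2, 6, 3, 8)
    out, attn = attention_ref(q, k, v, causal=True)
    assert out.shape == (2, 4, 3, 8)
    assert attn.shape == (2, 3, 4, 6)  # (batch_size, nheads, seqlen_q, seqlen_k)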
|
|
|
|
|
def attention_kvpacked_ref( |
|
q, |
|
kv, |
|
query_padding_mask=None, |
|
key_padding_mask=None, |
|
attn_bias=None, |
|
dropout_p=0.0, |
|
dropout_mask=None, |
|
causal=False, |
|
window_size=(-1, -1), |
|
softcap=0.0, |
|
upcast=True, |
|
reorder_ops=False, |
|
key_leftpad=None, |
|
): |
|
return attention_ref( |
|
q, |
|
kv[:, :, 0], |
|
kv[:, :, 1], |
|
query_padding_mask, |
|
key_padding_mask, |
|
attn_bias, |
|
dropout_p, |
|
dropout_mask, |
|
upcast=upcast, |
|
causal=causal, |
|
window_size=window_size, |
|
softcap=softcap, |
|
reorder_ops=reorder_ops, |
|
key_leftpad=key_leftpad, |
|
) |
|
|
|
|
|
def attention_qkvpacked_ref( |
|
qkv, |
|
key_padding_mask=None, |
|
attn_bias=None, |
|
dropout_p=0.0, |
|
dropout_mask=None, |
|
causal=False, |
|
window_size=(-1, -1), |
|
softcap=0.0, |
|
upcast=True, |
|
reorder_ops=False, |
|
): |
|
return attention_ref( |
|
qkv[:, :, 0], |
|
qkv[:, :, 1], |
|
qkv[:, :, 2], |
|
key_padding_mask, |
|
key_padding_mask, |
|
attn_bias, |
|
dropout_p, |
|
dropout_mask, |
|
upcast=upcast, |
|
causal=causal, |
|
window_size=window_size, |
|
softcap=softcap, |
|
reorder_ops=reorder_ops, |
|
) |
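

# Tiny, hand-runnable sketch (not collected by pytest) of the packed QKV layout the
# qkvpacked reference expects: dim 2 stacks (q, k, v). Sizes are arbitrary, CPU only.
def _example_attention_qkvpacked_ref_shapes():
    qkv = torch.randn(2, 5, 3, 4, 8)  # (batch_size, seqlen, 3, nheads, head_dim)
    out, attn = attention_qkvpacked_ref(qkv, causal=True)
    assert out.shape == (2, 5, 4, 8)
    assert attn.shape == (2, 4, 5, 5)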
|
|
|
|
|
def generate_sparsity_mask(seqlen, sparsity=0.3): |
|
    repeats = seqlen // 16 // 2  # not used below
    nrow, ncol = seqlen // 16, seqlen // 256
|
mask = torch.rand(nrow, ncol, device="cuda") < sparsity |
|
return mask |
|
|
|
|
|
def attention_blocksparse_ref(qkv, blockmask, attn_mask, dropout_p, dropout_mask): |
|
""" |
|
Arguments: |
|
qkv: (batch_size, seqlen, 3, nheads, head_dim) |
|
blockmask: (seqlen / 16, seqlen / 256) |
|
attn_mask: (batch_size, seqlen) |
|
dropout_p: float |
|
dropout_mask: (batch_size, nheads, seqlen, seqlen) |
|
Output: |
|
output: (batch_size, seqlen, nheads, head_dim) |
|
attention: softmax after dropout |
|
""" |
|
q, k, v = qkv.float().unbind(dim=2) |
|
d = qkv.shape[-1] |
|
seqlen = qkv.shape[1] |
|
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) |
|
scores.masked_fill_(rearrange(~attn_mask, "b s -> b 1 1 s"), float("-inf")) |
|
blockmask = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)") |
|
blockmask = blockmask[:seqlen, :seqlen] |
|
scores.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), float("-inf")) |
|
attention = torch.softmax(scores, dim=-1) |
|
attention = attention.masked_fill(rearrange(~attn_mask, "b s -> b 1 s 1"), 0.0) |
|
attention = attention.masked_fill_(rearrange(~blockmask, "t s -> 1 1 t s"), 0.0) |
|
attention_drop = attention.masked_fill(~dropout_mask, 0.0) / (1 - dropout_p) |
|
output = torch.einsum("bhts,bshd->bthd", attention_drop, v) |
|
output.masked_fill_(rearrange(~attn_mask, "b s -> b s 1 1"), 0) |
|
return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype) |
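

# Tiny, hand-runnable sketch (not collected by pytest) of how the coarse blockmask is
# expanded to a dense (seqlen, seqlen) mask inside attention_blocksparse_ref: each
# blockmask entry covers a 16 x 256 tile of scores.
def _example_blockmask_expansion():
    blockmask = torch.tensor([[True, False]])  # a (1, 2) grid of 16 x 256 tiles
    dense = repeat(blockmask, "s_16 s_256 -> (s_16 16) (s_256 256)")
    assert dense.shape == (16, 512)
    assert bool(dense[:, :256].all()) and not bool(dense[:, 256:].any())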
|
|
|
|
|
def convert_flash_attn_S_to_softmax( |
|
S, |
|
seqlen_q, |
|
seqlen_k, |
|
query_padding_mask, |
|
key_padding_mask, |
|
head_dim, |
|
is_dropout, |
|
causal=False, |
|
window_size=(-1, -1), |
|
): |
|
"""FlashAttention stores the S matrix in a different way. |
|
Arguments: |
|
S: (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded) |
|
query_padding_mask: (batch_size, seqlen_q_rounded) |
|
key_padding_mask: (batch_size, seqlen_k_rounded) |
|
""" |
|
if causal: |
|
window_size = (window_size[0], 0) |
|
seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:] |
|
S_converted = S |
|
if window_size[0] >= 0 or window_size[1] >= 0: |
|
local_mask = construct_local_mask( |
|
seqlen_q, |
|
seqlen_k, |
|
window_size, |
|
query_padding_mask, |
|
key_padding_mask, |
|
S.device, |
|
) |
|
local_mask = F.pad( |
|
local_mask, |
|
(0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q), |
|
value=True, |
|
) |
|
        S_converted = S_converted.masked_fill(local_mask, 0.0)

    seqlen_q_og = (
|
query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded |
|
) |
|
if query_padding_mask is not None: |
|
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og)) |
|
S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) |
|
seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k |
|
if key_padding_mask is not None: |
|
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og)) |
|
S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) |
|
S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded)) |
|
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded)) |
|
return S_converted[:, :, :seqlen_q, :seqlen_k] |
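

# Tiny, hand-runnable sketch (not collected by pytest) of the conversion's shape handling:
# S comes back padded to rounded sequence lengths and is cropped to (seqlen_q, seqlen_k).
# The rounded size 128 and the other sizes here are arbitrary; no dropout, causal mask.
def _example_convert_flash_attn_S_shapes():
    S = torch.randn(1, 2, 128, 128)  # (batch_size, nheads, seqlen_q_rounded, seqlen_k_rounded)
    S_converted = convert_flash_attn_S_to_softmax(
        S, 100, 100, None, None, head_dim=64, is_dropout=False, causal=True
    )
    assert S_converted.shape == (1, 2, 100, 100)
    # Entries above the diagonal were zeroed by the causal (local-window) mask.
    assert bool((S_converted[0, 0].triu(1) == 0).all())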
|
|
|
|
|
def normalize_flash_attn_S( |
|
attn_unnorm, |
|
q, |
|
k, |
|
v, |
|
query_padding_mask=None, |
|
key_padding_mask=None, |
|
attn_bias=None, |
|
is_dropout=False, |
|
causal=False, |
|
window_size=(-1, -1), |
|
): |
|
""" |
|
Arguments: |
|
q: (batch_size, seqlen_q, nheads, head_dim) |
|
k, v: (batch_size, seqlen_k, nheads, head_dim) |
|
key_padding_mask: (batch_size, seqlen_q) |
|
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) |
|
Output: |
|
softmax_lse: (batch_size, nheads, seqlen_q) |
|
softmax_max: (batch_size, nheads, seqlen_q) |
|
""" |
|
if causal: |
|
window_size = (window_size[0], 0) |
|
q, k, v = q.float(), k.float(), v.float() |
|
_, seqlen_q, _, head_dim = q.shape |
|
seqlen_k = k.shape[1] |
|
scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(head_dim), k) |
|
if key_padding_mask is not None: |
|
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) |
|
if window_size[0] >= 0 or window_size[1] >= 0: |
|
local_mask = construct_local_mask( |
|
seqlen_q, |
|
seqlen_k, |
|
window_size, |
|
query_padding_mask, |
|
key_padding_mask, |
|
q.device, |
|
) |
|
scores.masked_fill_(local_mask, float("-inf")) |
|
if attn_bias is not None: |
|
scores = scores + attn_bias.to(dtype=scores.dtype) |
|
block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal) |
|
scores_block = scores.split(block_size_n, dim=-1) |
|
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1) |
|
    lse = torch.logsumexp(lse_block, dim=-1)
    # If a whole row of scores is -inf (fully masked), lse is -inf; set it to inf so that
    # exp(rowmax - lse) below evaluates to 0 instead of NaN.
    lse[lse == float("-inf")] = float("inf")
|
scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1) |
|
cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1) |
|
attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1) |
|
attn_norm = torch.cat( |
|
[ |
|
a * rearrange(torch.exp(m - lse), "b h s -> b h s 1") |
|
for a, m in zip(attn_unnorm_block, cummax_block) |
|
], |
|
dim=-1, |
|
) |
|
if query_padding_mask is not None: |
|
attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) |
|
return attn_norm.to(dtype=attn_unnorm.dtype) |
|
|
|
|
|
def get_dropout_fraction( |
|
dropout_mask, |
|
query_padding_mask=None, |
|
key_padding_mask=None, |
|
causal=False, |
|
window_size=(-1, -1), |
|
): |
|
""" |
|
dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k), bool. True means keep, False means drop. |
|
query_padding_mask: (batch_size, seqlen_q) |
|
key_padding_mask: (batch_size, seqlen_k) |
|
""" |
|
if causal: |
|
window_size = (window_size[0], 0) |
|
batch_size, nheads, seqlen_q, seqlen_k = dropout_mask.shape |
|
dropped = ~dropout_mask |
|
valid = torch.ones_like(dropout_mask) |
|
if query_padding_mask is not None: |
|
dropped.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False) |
|
valid.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), False) |
|
if key_padding_mask is not None: |
|
dropped.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False) |
|
valid.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), False) |
|
if window_size[0] >= 0 or window_size[1] >= 0: |
|
local_mask = construct_local_mask( |
|
seqlen_q, |
|
seqlen_k, |
|
window_size, |
|
query_padding_mask, |
|
key_padding_mask, |
|
dropout_mask.device, |
|
) |
|
dropped.masked_fill_(local_mask, False) |
|
valid.masked_fill_(local_mask, False) |
|
    dropped_total = dropped.sum()
    return dropped_total / valid.sum()
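

# Tiny, hand-runnable sketch (not collected by pytest) of the dropout-fraction helper:
# with no padding or windowing it is simply the share of False (dropped) entries.
def _example_get_dropout_fraction():
    dropout_mask = torch.tensor([[[[True, False], [False, True]]]])  # (1, 1, 2, 2)
    assert get_dropout_fraction(dropout_mask).item() == 0.5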
|
|
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) |
|
|
|
@pytest.mark.parametrize("deterministic", [False, True]) |
|
|
|
@pytest.mark.parametrize("alibi", [False, True]) |
|
|
|
@pytest.mark.parametrize("local", [False, True]) |
|
|
|
@pytest.mark.parametrize("causal", [False, True]) |
|
|
|
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) |
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048]) |
|
|
|
@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) |
|
|
|
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype): |
|
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: |
|
pytest.skip() |
|
device = "cuda" |
|
|
|
torch.random.manual_seed(0) |
|
batch_size = 4 |
|
nheads = 9 |
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,)) |
|
qkv = torch.randn( |
|
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True |
|
) |
|
if alibi: |
|
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 |
|
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen, seqlen, causal=causal) |
|
else: |
|
alibi_slopes, attn_bias = None, None |
|
out, lse, S_dmask = flash_attn_qkvpacked_func( |
|
qkv, |
|
dropout_p, |
|
causal=causal, |
|
window_size=window_size, |
|
alibi_slopes=alibi_slopes, |
|
deterministic=deterministic, |
|
return_attn_probs=True, |
|
) |
|
if dropout_p > 0.0: |
|
S_dmask_converted = convert_flash_attn_S_to_softmax( |
|
S_dmask, |
|
seqlen, |
|
seqlen, |
|
None, |
|
None, |
|
d, |
|
dropout_p > 0.0, |
|
causal=causal, |
|
window_size=window_size, |
|
) |
|
dropout_mask = S_dmask_converted >= 0 |
|
attn_unnorm = S_dmask_converted.abs() |
|
attn = normalize_flash_attn_S( |
|
attn_unnorm, |
|
qkv[:, :, 0], |
|
qkv[:, :, 1], |
|
qkv[:, :, 2], |
|
None, |
|
None, |
|
attn_bias, |
|
dropout_p > 0.0, |
|
causal=causal, |
|
window_size=window_size, |
|
) |
|
dropout_fraction = get_dropout_fraction( |
|
dropout_mask, None, None, causal=causal, window_size=window_size |
|
).item() |
|
print(f"Actual dropout fraction: {dropout_fraction}") |
|
else: |
|
dropout_mask = None |
|
|
|
out_ref, attn_ref = attention_qkvpacked_ref( |
|
qkv, None, attn_bias, dropout_p, dropout_mask, causal=causal, window_size=window_size |
|
) |
|
out_pt, attn_pt = attention_qkvpacked_ref( |
|
qkv, |
|
None, |
|
attn_bias, |
|
dropout_p, |
|
dropout_mask, |
|
causal=causal, |
|
window_size=window_size, |
|
upcast=False, |
|
reorder_ops=True, |
|
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") |
|
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") |
|
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") |
|
if dropout_p > 0.0: |
|
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}") |
|
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") |
|
|
|
    g = torch.randn_like(out)

    if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
|
(dqkv,) = torch.autograd.grad(out, qkv, g) |
|
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) |
|
(dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g) |
|
print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") |
|
print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") |
|
print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") |
|
print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}") |
|
print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") |
|
print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") |
|
print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") |
|
print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}") |
|
|
|
|
|
|
|
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() |
|
|
|
if dropout_p > 0.0: |
|
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() |
|
|
|
if not alibi: |
|
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) |
|
|
|
if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): |
|
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() |
|
|
|
|
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) |
|
|
|
@pytest.mark.parametrize("deterministic", [False, True]) |
|
|
|
@pytest.mark.parametrize("alibi", [False, True]) |
|
|
|
@pytest.mark.parametrize("local", [False, True]) |
|
|
|
@pytest.mark.parametrize("causal", [False, True]) |
|
|
|
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 128, 160, 192, 224, 256]) |
|
|
|
|
|
@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048]) |
|
|
|
@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) |
|
|
|
def test_flash_attn_varlen_qkvpacked( |
|
seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype |
|
): |
|
if seqlen >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30: |
|
pytest.skip() |
|
device = "cuda" |
|
|
|
torch.random.manual_seed(0) |
|
batch_size = 5 |
|
nheads = 6 |
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen, (2,)) |
|
qkv = torch.randn( |
|
batch_size, seqlen, 3, nheads, d, device=device, dtype=dtype, requires_grad=True |
|
) |
|
|
|
key_padding_mask = generate_random_padding_mask(seqlen, batch_size, device, mode="random") |
|
|
|
if alibi: |
|
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 |
|
attn_bias = attn_bias_from_alibi_slopes( |
|
alibi_slopes, seqlen, seqlen, key_padding_mask, key_padding_mask, causal=causal |
|
) |
|
else: |
|
alibi_slopes, attn_bias = None, None |
|
|
|
qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( |
|
*qkv.unbind(dim=2), key_padding_mask, key_padding_mask, qkvpacked=True |
|
) |
|
|
|
out_unpad, sm_lse, S_dmask = flash_attn_varlen_qkvpacked_func( |
|
qkv_unpad, |
|
cu_seqlens, |
|
max_seqlen, |
|
dropout_p, |
|
causal=causal, |
|
window_size=window_size, |
|
alibi_slopes=alibi_slopes, |
|
deterministic=deterministic, |
|
return_attn_probs=True, |
|
) |
|
out = output_pad_fn(out_unpad) |
|
if dropout_p > 0.0: |
|
S_dmask_converted = convert_flash_attn_S_to_softmax( |
|
S_dmask, |
|
seqlen, |
|
seqlen, |
|
key_padding_mask, |
|
key_padding_mask, |
|
d, |
|
dropout_p > 0.0, |
|
causal=causal, |
|
window_size=window_size, |
|
) |
|
dropout_mask = S_dmask_converted >= 0 |
|
attn_unnorm = S_dmask_converted.abs() |
|
attn = normalize_flash_attn_S( |
|
attn_unnorm, |
|
qkv[:, :, 0], |
|
qkv[:, :, 1], |
|
qkv[:, :, 2], |
|
key_padding_mask, |
|
key_padding_mask, |
|
attn_bias, |
|
dropout_p > 0.0, |
|
causal=causal, |
|
window_size=window_size, |
|
) |
|
dropout_fraction = get_dropout_fraction( |
|
dropout_mask, key_padding_mask, key_padding_mask, causal=causal, window_size=window_size |
|
).item() |
|
print(f"Actual dropout fraction: {dropout_fraction}") |
|
else: |
|
dropout_mask = None |
|
|
|
out_ref, attn_ref = attention_qkvpacked_ref( |
|
qkv, |
|
key_padding_mask, |
|
attn_bias, |
|
dropout_p, |
|
dropout_mask, |
|
causal=causal, |
|
window_size=window_size, |
|
) |
|
out_pt, attn_pt = attention_qkvpacked_ref( |
|
qkv, |
|
key_padding_mask, |
|
attn_bias, |
|
dropout_p, |
|
dropout_mask, |
|
causal=causal, |
|
window_size=window_size, |
|
upcast=False, |
|
reorder_ops=True, |
|
) |
|
print(f"Output max diff: {(out - out_ref).abs().max().item()}") |
|
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") |
|
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") |
|
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") |
|
if dropout_p > 0.0: |
|
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}") |
|
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") |
|
|
|
g = torch.randn_like(out) |
|
if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): |
|
(dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g) |
|
dqkv = dqkv_pad_fn(dqkv_unpad) |
|
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g) |
|
(dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g) |
|
print(f"dQ max diff: {(dqkv[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") |
|
print(f"dK max diff: {(dqkv[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") |
|
print(f"dV max diff: {(dqkv[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") |
|
print(f"dQKV mean diff: {(dqkv - dqkv_ref).abs().mean().item()}") |
|
print(f"dQ Pytorch max diff: {(dqkv_pt[:, :, 0] - dqkv_ref[:, :, 0]).abs().max().item()}") |
|
print(f"dK Pytorch max diff: {(dqkv_pt[:, :, 1] - dqkv_ref[:, :, 1]).abs().max().item()}") |
|
print(f"dV Pytorch max diff: {(dqkv_pt[:, :, 2] - dqkv_ref[:, :, 2]).abs().max().item()}") |
|
print(f"dQKV Pytorch mean diff: {(dqkv_pt - dqkv_ref).abs().mean().item()}") |
|
|
|
|
|
|
|
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() |
|
|
|
if dropout_p > 0.0: |
|
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() |
|
|
|
if not alibi: |
|
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) |
|
|
|
if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): |
|
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item() |
|
|
|
|
|
@pytest.mark.parametrize("kvpacked", [True, False]) |
|
|
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) |
|
|
|
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) |
|
|
|
@pytest.mark.parametrize("deterministic", [False, True]) |
|
|
|
@pytest.mark.parametrize("alibi", [False, True]) |
|
|
|
@pytest.mark.parametrize("local", [False, True]) |
|
|
|
@pytest.mark.parametrize("causal", [False, True]) |
|
|
|
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256]) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
"seqlen_q,seqlen_k", |
|
[ |
|
(113, 203), |
|
(128, 217), |
|
(113, 211), |
|
(108, 256), |
|
(256, 512), |
|
(512, 256), |
|
(1024, 1024), |
|
(1023, 1024), |
|
(1024, 1023), |
|
(2048, 2048), |
|
], |
|
) |
|
|
|
@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) |
|
|
|
@pytest.mark.parametrize("softcap", [0.0, 50.0]) |
|
def test_flash_attn_output( |
|
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap |
|
): |
|
if ( |
|
max(seqlen_q, seqlen_k) >= 2048 |
|
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 |
|
): |
|
pytest.skip() |
|
if softcap > 0.0 and dropout_p > 0.0: |
|
pytest.skip("Softcap and dropout not supported together") |
|
device = "cuda" |
|
|
|
torch.random.manual_seed(0) |
|
batch_size = 4 |
|
nheads = 6 if softcap == 0.0 else 4 |
|
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2) |
|
assert nheads % nheads_k == 0 |
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) |
|
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
    if softcap > 0:
        # Scale q up so that the attention scores are large enough for the softcap to take effect.
        q = q * softcap
|
if kvpacked: |
|
kv = torch.randn( |
|
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True |
|
) |
|
else: |
|
k = torch.randn( |
|
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True |
|
) |
|
v = torch.randn( |
|
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True |
|
) |
|
if alibi: |
|
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 |
|
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) |
|
else: |
|
alibi_slopes, attn_bias = None, None |
|
|
|
if kvpacked: |
|
out, lse, S_dmask = flash_attn_kvpacked_func( |
|
q, |
|
kv, |
|
dropout_p, |
|
causal=causal, |
|
window_size=window_size, |
|
softcap=softcap, |
|
alibi_slopes=alibi_slopes, |
|
deterministic=deterministic, |
|
return_attn_probs=True, |
|
) |
|
else: |
|
out, lse, S_dmask = flash_attn_func( |
|
q, |
|
k, |
|
v, |
|
dropout_p, |
|
causal=causal, |
|
window_size=window_size, |
|
softcap=softcap, |
|
alibi_slopes=alibi_slopes, |
|
deterministic=deterministic, |
|
return_attn_probs=True, |
|
) |
|
if dropout_p > 0.0: |
|
S_dmask_converted = convert_flash_attn_S_to_softmax( |
|
S_dmask, |
|
seqlen_q, |
|
seqlen_k, |
|
None, |
|
None, |
|
d, |
|
dropout_p > 0.0, |
|
causal=causal, |
|
window_size=window_size, |
|
) |
|
dropout_mask = S_dmask_converted >= 0 |
|
attn_unnorm = S_dmask_converted.abs() |
|
if kvpacked: |
|
kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k) |
|
k_rep, v_rep = kv_rep.unbind(dim=2) |
|
else: |
|
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) |
|
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k) |
|
attn = normalize_flash_attn_S( |
|
attn_unnorm, |
|
q, |
|
k_rep, |
|
v_rep, |
|
None, |
|
None, |
|
attn_bias, |
|
dropout_p > 0.0, |
|
causal=causal, |
|
window_size=window_size, |
|
) |
|
dropout_fraction = get_dropout_fraction( |
|
dropout_mask, None, None, causal=causal, window_size=window_size |
|
).item() |
|
print(f"Actual dropout fraction: {dropout_fraction}") |
|
else: |
|
dropout_mask = None |
|
|
|
if kvpacked: |
|
out_ref, attn_ref = attention_kvpacked_ref( |
|
q, |
|
kv, |
|
None, |
|
None, |
|
attn_bias, |
|
dropout_p, |
|
dropout_mask, |
|
causal=causal, |
|
window_size=window_size, |
|
softcap=softcap, |
|
) |
|
out_pt, attn_pt = attention_kvpacked_ref( |
|
q, |
|
kv, |
|
None, |
|
None, |
|
attn_bias, |
|
dropout_p, |
|
dropout_mask, |
|
causal=causal, |
|
window_size=window_size, |
|
softcap=softcap, |
|
upcast=False, |
|
reorder_ops=True, |
|
) |
|
else: |
|
out_ref, attn_ref = attention_ref( |
|
q, |
|
k, |
|
v, |
|
None, |
|
None, |
|
attn_bias, |
|
dropout_p, |
|
dropout_mask, |
|
causal=causal, |
|
window_size=window_size, |
|
softcap=softcap, |
|
) |
|
out_pt, attn_pt = attention_ref( |
|
q, |
|
k, |
|
v, |
|
None, |
|
None, |
|
attn_bias, |
|
dropout_p, |
|
dropout_mask, |
|
causal=causal, |
|
window_size=window_size, |
|
softcap=softcap, |
|
upcast=False, |
|
reorder_ops=True, |
|
) |
|
|
|
print(f"Output max diff: {(out - out_ref).abs().max().item()}") |
|
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") |
|
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") |
|
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") |
|
if dropout_p > 0.0: |
|
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}") |
|
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") |
|
|
|
g = torch.randn_like(out) |
|
do_o = (g.float() * out.float()).sum(-1) |
|
if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): |
|
if kvpacked: |
|
( |
|
dq, |
|
dkv, |
|
) = torch.autograd.grad(out, (q, kv), g) |
|
dk, dv = dkv.unbind(2) |
|
( |
|
dq_ref, |
|
dkv_ref, |
|
) = torch.autograd.grad(out_ref, (q, kv), g) |
|
dk_ref, dv_ref = dkv_ref.unbind(2) |
|
( |
|
dq_pt, |
|
dkv_pt, |
|
) = torch.autograd.grad(out_pt, (q, kv), g) |
|
dk_pt, dv_pt = dkv_pt.unbind(2) |
|
else: |
|
( |
|
dq, |
|
dk, |
|
dv, |
|
) = torch.autograd.grad(out, (q, k, v), g) |
|
( |
|
dq_ref, |
|
dk_ref, |
|
dv_ref, |
|
) = torch.autograd.grad(out_ref, (q, k, v), g) |
|
( |
|
dq_pt, |
|
dk_pt, |
|
dv_pt, |
|
) = torch.autograd.grad(out_pt, (q, k, v), g) |
|
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") |
|
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") |
|
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") |
|
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") |
|
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") |
|
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") |
|
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") |
|
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") |
|
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") |
|
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") |
|
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") |
|
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") |
|
|
|
|
|
|
|
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() |
|
|
|
if dropout_p > 0.0: |
|
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() |
|
|
|
if not alibi: |
|
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) |
|
|
|
if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): |
|
assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() |
|
assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() |
|
assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() |
|
|
|
|
|
@pytest.mark.parametrize("kvpacked", [True, False]) |
|
|
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) |
|
|
|
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) |
|
|
|
@pytest.mark.parametrize("deterministic", [False, True]) |
|
|
|
@pytest.mark.parametrize("alibi", [False, True]) |
|
|
|
@pytest.mark.parametrize("local", [False, True]) |
|
|
|
@pytest.mark.parametrize("causal", [False, True]) |
|
|
|
@pytest.mark.parametrize("d", [32, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) |
|
|
|
|
|
@pytest.mark.parametrize( |
|
"seqlen_q,seqlen_k", |
|
[ |
|
(1, 147), |
|
(113, 203), |
|
(128, 217), |
|
(113, 211), |
|
(108, 256), |
|
(256, 512), |
|
(512, 256), |
|
(1024, 1024), |
|
(1023, 1024), |
|
(1024, 1023), |
|
(2048, 2048), |
|
], |
|
) |
|
|
|
@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) |
|
@pytest.mark.parametrize("softcap", [0.0, 50.0]) |
|
|
|
def test_flash_attn_varlen_output( |
|
seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap |
|
): |
|
if ( |
|
max(seqlen_q, seqlen_k) >= 2048 |
|
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 |
|
): |
|
pytest.skip() |
|
if softcap > 0.0 and dropout_p > 0.0: |
|
pytest.skip("Softcap and dropout not supported together") |
|
device = "cuda" |
|
|
|
torch.random.manual_seed(0) |
|
batch_size = 4 |
|
nheads = 6 if softcap == 0.0 else 4 |
|
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2) |
|
assert nheads % nheads_k == 0 |
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) |
|
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
if softcap > 0: |
|
|
|
q = q * softcap |
|
|
|
if kvpacked: |
|
kv = torch.randn( |
|
batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True |
|
) |
|
else: |
|
k = torch.randn( |
|
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True |
|
) |
|
v = torch.randn( |
|
batch_size, seqlen_k, nheads_k, d, device=device, dtype=dtype, requires_grad=True |
|
) |
|
|
|
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") |
|
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") |
|
|
|
if alibi: |
|
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 |
|
attn_bias = attn_bias_from_alibi_slopes( |
|
alibi_slopes, seqlen_q, seqlen_k, query_padding_mask, key_padding_mask, causal=causal |
|
) |
|
else: |
|
alibi_slopes, attn_bias = None, None |
|
|
|
if kvpacked: |
|
( |
|
q_unpad, |
|
kv_unpad, |
|
cu_seqlens_q, |
|
cu_seqlens_k, |
|
max_seqlen_q, |
|
max_seqlen_k, |
|
q, |
|
kv, |
|
output_pad_fn, |
|
dq_pad_fn, |
|
dkv_pad_fn, |
|
) = generate_qkv(q, *kv.unbind(dim=2), query_padding_mask, key_padding_mask, kvpacked=True) |
|
out_unpad, sm_lse, S_dmask = flash_attn_varlen_kvpacked_func( |
|
q_unpad, |
|
kv_unpad, |
|
cu_seqlens_q, |
|
cu_seqlens_k, |
|
max_seqlen_q, |
|
max_seqlen_k, |
|
dropout_p, |
|
causal=causal, |
|
window_size=window_size, |
|
softcap=softcap, |
|
alibi_slopes=alibi_slopes, |
|
deterministic=deterministic, |
|
return_attn_probs=True, |
|
) |
|
else: |
|
( |
|
q_unpad, |
|
k_unpad, |
|
v_unpad, |
|
cu_seqlens_q, |
|
cu_seqlens_k, |
|
max_seqlen_q, |
|
max_seqlen_k, |
|
q, |
|
k, |
|
v, |
|
output_pad_fn, |
|
dq_pad_fn, |
|
dk_pad_fn, |
|
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) |
|
out_unpad, sm_lse, S_dmask = flash_attn_varlen_func( |
|
q_unpad, |
|
k_unpad, |
|
v_unpad, |
|
cu_seqlens_q, |
|
cu_seqlens_k, |
|
max_seqlen_q, |
|
max_seqlen_k, |
|
dropout_p, |
|
causal=causal, |
|
window_size=window_size, |
|
softcap=softcap, |
|
alibi_slopes=alibi_slopes, |
|
deterministic=deterministic, |
|
return_attn_probs=True, |
|
) |
|
out = output_pad_fn(out_unpad) |
|
if dropout_p > 0.0: |
|
S_dmask_converted = convert_flash_attn_S_to_softmax( |
|
S_dmask, |
|
seqlen_q, |
|
seqlen_k, |
|
query_padding_mask, |
|
key_padding_mask, |
|
d, |
|
dropout_p > 0.0, |
|
causal=causal, |
|
window_size=window_size, |
|
) |
|
dropout_mask = S_dmask_converted >= 0 |
|
attn_unnorm = S_dmask_converted.abs() |
|
if kvpacked: |
|
kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k) |
|
k_rep, v_rep = kv_rep.unbind(dim=2) |
|
else: |
|
k_rep = repeat(k, "b s h d -> b s (h g) d", g=nheads // nheads_k) |
|
v_rep = repeat(v, "b s h d -> b s (h g) d", g=nheads // nheads_k) |
|
attn = normalize_flash_attn_S( |
|
attn_unnorm, |
|
q, |
|
k_rep, |
|
v_rep, |
|
query_padding_mask, |
|
key_padding_mask, |
|
attn_bias, |
|
dropout_p > 0.0, |
|
causal=causal, |
|
window_size=window_size, |
|
) |
|
dropout_fraction = get_dropout_fraction( |
|
dropout_mask, |
|
query_padding_mask, |
|
key_padding_mask, |
|
causal=causal, |
|
window_size=window_size, |
|
).item() |
|
print(f"Actual dropout fraction: {dropout_fraction}") |
|
else: |
|
dropout_mask = None |
|
|
|
if kvpacked: |
|
out_ref, attn_ref = attention_kvpacked_ref( |
|
q, |
|
kv, |
|
query_padding_mask, |
|
key_padding_mask, |
|
attn_bias, |
|
dropout_p, |
|
dropout_mask, |
|
causal=causal, |
|
window_size=window_size, |
|
softcap=softcap, |
|
) |
|
out_pt, attn_pt = attention_kvpacked_ref( |
|
q, |
|
kv, |
|
query_padding_mask, |
|
key_padding_mask, |
|
attn_bias, |
|
dropout_p, |
|
dropout_mask, |
|
causal=causal, |
|
window_size=window_size, |
|
softcap=softcap, |
|
upcast=False, |
|
reorder_ops=True, |
|
) |
|
else: |
|
out_ref, attn_ref = attention_ref( |
|
q, |
|
k, |
|
v, |
|
query_padding_mask, |
|
key_padding_mask, |
|
attn_bias, |
|
dropout_p, |
|
dropout_mask, |
|
causal=causal, |
|
window_size=window_size, |
|
softcap=softcap, |
|
) |
|
out_pt, attn_pt = attention_ref( |
|
q, |
|
k, |
|
v, |
|
query_padding_mask, |
|
key_padding_mask, |
|
attn_bias, |
|
dropout_p, |
|
dropout_mask, |
|
causal=causal, |
|
window_size=window_size, |
|
softcap=softcap, |
|
upcast=False, |
|
reorder_ops=True, |
|
) |
|
|
|
print(f"Output max diff: {(out - out_ref).abs().max().item()}") |
|
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") |
|
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") |
|
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") |
|
if dropout_p > 0.0: |
|
print(f"Attention max diff: {(attn - attn_ref).abs().max().item()}") |
|
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") |
|
|
|
g = torch.randn_like(out) |
|
if ((d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90)): |
|
if kvpacked: |
|
( |
|
dq_unpad, |
|
dkv_unpad, |
|
) = torch.autograd.grad(out, (q_unpad, kv_unpad), g) |
|
dk, dv = dkv_pad_fn(dkv_unpad).unbind(2) |
|
( |
|
dq_ref, |
|
dkv_ref, |
|
) = torch.autograd.grad(out_ref, (q, kv), g) |
|
dk_ref, dv_ref = dkv_ref.unbind(2) |
|
( |
|
dq_pt, |
|
dkv_pt, |
|
) = torch.autograd.grad(out_pt, (q, kv), g) |
|
dk_pt, dv_pt = dkv_pt.unbind(2) |
|
else: |
|
( |
|
dq_unpad, |
|
dk_unpad, |
|
dv_unpad, |
|
) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) |
|
dk = dk_pad_fn(dk_unpad) |
|
dv = dk_pad_fn(dv_unpad) |
|
( |
|
dq_ref, |
|
dk_ref, |
|
dv_ref, |
|
) = torch.autograd.grad(out_ref, (q, k, v), g) |
|
( |
|
dq_pt, |
|
dk_pt, |
|
dv_pt, |
|
) = torch.autograd.grad(out_pt, (q, k, v), g) |
|
dq = dq_pad_fn(dq_unpad) |
|
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") |
|
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") |
|
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") |
|
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") |
|
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") |
|
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") |
|
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") |
|
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") |
|
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") |
|
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") |
|
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") |
|
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") |
|
|
|
|
|
|
|
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() |
|
|
|
if dropout_p > 0.0: |
|
assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() |
|
|
|
if not alibi: |
|
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04) |
|
|
|
if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): |
|
assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() |
|
assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() |
|
assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() |
|
|
|
|
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) |
|
|
|
@pytest.mark.parametrize("local", [False, True]) |
|
|
|
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("swap_sq_sk", [False, True]) |
|
|
|
@pytest.mark.parametrize( |
|
"seqlen_q,seqlen_k", |
|
[ |
|
(1, 239), |
|
(3, 799), |
|
(127, 512), |
|
(127, 513), |
|
(113, 203), |
|
(128, 217), |
|
(113, 211), |
|
(108, 256), |
|
(256, 512), |
|
(1023, 1024), |
|
], |
|
) |
|
|
|
def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype): |
|
if ( |
|
max(seqlen_q, seqlen_k) >= 2048 |
|
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 |
|
): |
|
pytest.skip() |
|
if swap_sq_sk: |
|
seqlen_q, seqlen_k = seqlen_k, seqlen_q |
|
device = "cuda" |
|
causal = True |
|
|
|
torch.random.manual_seed(0) |
|
batch_size = 8 |
|
nheads = 9 |
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) |
|
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size) |
|
out_ref, attn_ref = attention_ref( |
|
q, k, v, None, None, None, 0.0, None, causal=causal, window_size=window_size |
|
) |
|
out_pt, attn_pt = attention_ref( |
|
q, |
|
k, |
|
v, |
|
None, |
|
None, |
|
None, |
|
0.0, |
|
None, |
|
causal=causal, |
|
window_size=window_size, |
|
upcast=False, |
|
reorder_ops=True, |
|
) |
|
|
|
print(f"Output max diff: {(out - out_ref).abs().max().item()}") |
|
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") |
|
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") |
|
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") |
|
|
|
g = torch.randn_like(out) |
|
do_o = (g.float() * out.float()).sum(-1) |
|
( |
|
dq, |
|
dk, |
|
dv, |
|
) = torch.autograd.grad(out, (q, k, v), g) |
|
( |
|
dq_ref, |
|
dk_ref, |
|
dv_ref, |
|
) = torch.autograd.grad(out_ref, (q, k, v), g) |
|
( |
|
dq_pt, |
|
dk_pt, |
|
dv_pt, |
|
) = torch.autograd.grad(out_pt, (q, k, v), g) |
|
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") |
|
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") |
|
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") |
|
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") |
|
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") |
|
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") |
|
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") |
|
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") |
|
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") |
|
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") |
|
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") |
|
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") |
|
|
|
|
|
|
|
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 |
|
|
|
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 |
|
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 |
|
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 |
|
|
|
|
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) |
|
|
|
@pytest.mark.parametrize("local", [False, True]) |
|
|
|
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("swap_sq_sk", [False, True]) |
|
|
|
@pytest.mark.parametrize( |
|
"seqlen_q,seqlen_k", |
|
[ |
|
(1, 239), |
|
(3, 799), |
|
(127, 512), |
|
(127, 513), |
|
(113, 203), |
|
(128, 217), |
|
(113, 211), |
|
(108, 256), |
|
(256, 512), |
|
(1023, 1024), |
|
], |
|
) |
|
|
|
@pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512]) |
|
|
|
def test_flash_attn_varlen_causal( |
|
seqlen_q, seqlen_k, swap_sq_sk, d, local, paged_kv_block_size, dtype |
|
): |
|
if ( |
|
max(seqlen_q, seqlen_k) >= 2048 |
|
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 |
|
): |
|
pytest.skip() |
|
if swap_sq_sk: |
|
seqlen_q, seqlen_k = seqlen_k, seqlen_q |
|
device = "cuda" |
|
causal = True |
|
|
|
torch.random.manual_seed(0) |
|
batch_size = 8 |
|
nheads = 9 |
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) |
|
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
|
|
if paged_kv_block_size is None: |
|
k = torch.randn( |
|
batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True |
|
) |
|
v = torch.randn( |
|
batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True |
|
) |
|
block_table = None |
|
else: |
|
k, v, block_table, k_cache_paged, v_cache_paged, num_blocks = _generate_block_kvcache( |
|
seqlen_k, paged_kv_block_size, batch_size, nheads, d, device, dtype |
|
) |
|
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") |
|
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") |
|
( |
|
q_unpad, |
|
k_unpad, |
|
v_unpad, |
|
cu_seqlens_q, |
|
cu_seqlens_k, |
|
max_seqlen_q, |
|
max_seqlen_k, |
|
q, |
|
k, |
|
v, |
|
output_pad_fn, |
|
dq_pad_fn, |
|
dk_pad_fn, |
|
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) |
|
out_unpad = flash_attn_varlen_func( |
|
q_unpad, |
|
k_unpad if paged_kv_block_size is None else k_cache_paged, |
|
v_unpad if paged_kv_block_size is None else v_cache_paged, |
|
cu_seqlens_q, |
|
cu_seqlens_k, |
|
max_seqlen_q, |
|
max_seqlen_k, |
|
0.0, |
|
causal=causal, |
|
window_size=window_size, |
|
block_table=block_table, |
|
) |
|
out = output_pad_fn(out_unpad) |
|
out_ref, attn_ref = attention_ref( |
|
q, |
|
k, |
|
v, |
|
query_padding_mask, |
|
key_padding_mask, |
|
None, |
|
0.0, |
|
None, |
|
causal=causal, |
|
window_size=window_size, |
|
) |
|
out_pt, attn_pt = attention_ref( |
|
q, |
|
k, |
|
v, |
|
query_padding_mask, |
|
key_padding_mask, |
|
None, |
|
0.0, |
|
None, |
|
causal=causal, |
|
window_size=window_size, |
|
upcast=False, |
|
reorder_ops=True, |
|
) |
|
|
|
print(f"Output max diff: {(out - out_ref).abs().max().item()}") |
|
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") |
|
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") |
|
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") |
|
|
|
g = torch.randn_like(out) |
|
do_o = (g.float() * out.float()).sum(-1) |
|
test_backward = block_table is None |
|
if test_backward: |
|
( |
|
dq_unpad, |
|
dk_unpad, |
|
dv_unpad, |
|
) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g) |
|
dq = dq_pad_fn(dq_unpad) |
|
dk = dk_pad_fn(dk_unpad) |
|
dv = dk_pad_fn(dv_unpad) |
|
( |
|
dq_ref, |
|
dk_ref, |
|
dv_ref, |
|
) = torch.autograd.grad(out_ref, (q, k, v), g) |
|
( |
|
dq_pt, |
|
dk_pt, |
|
dv_pt, |
|
) = torch.autograd.grad(out_pt, (q, k, v), g) |
|
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") |
|
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") |
|
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") |
|
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") |
|
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") |
|
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") |
|
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") |
|
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") |
|
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") |
|
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") |
|
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") |
|
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") |
|
|
|
|
|
|
|
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 |
|
|
|
if test_backward: |
|
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5 |
|
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5 |
|
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5 |
|
|
|
|
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) |
|
|
|
@pytest.mark.parametrize("deterministic", [False, True]) |
|
|
|
@pytest.mark.parametrize("alibi", [False, True]) |
|
|
|
@pytest.mark.parametrize("local", [False, True]) |
|
|
|
@pytest.mark.parametrize("causal", [False, True]) |
|
|
|
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("swap_sq_sk", [False, True]) |
|
|
|
@pytest.mark.parametrize( |
|
"seqlen_q,seqlen_k", |
|
[ |
|
(3, 1024), |
|
(1, 339), |
|
(64, 800), |
|
(3, 799), |
|
(64, 2048), |
|
(16, 20000), |
|
(16, 100000), |
|
(128, 128), |
|
(256, 256), |
|
], |
|
) |
|
|
|
def test_flash_attn_splitkv( |
|
seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, alibi, deterministic, dtype |
|
): |
|
if swap_sq_sk: |
|
seqlen_q, seqlen_k = seqlen_k, seqlen_q |
|
device = "cuda" |
|
|
|
torch.random.manual_seed(0) |
|
batch_size = 1 |
|
nheads = 12 |
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) |
|
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
if alibi: |
|
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 |
|
attn_bias = attn_bias_from_alibi_slopes(alibi_slopes, seqlen_q, seqlen_k, causal=causal) |
|
else: |
|
alibi_slopes, attn_bias = None, None |
|
out, lse, _ = flash_attn_func( |
|
q, |
|
k, |
|
v, |
|
0.0, |
|
causal=causal, |
|
window_size=window_size, |
|
alibi_slopes=alibi_slopes, |
|
deterministic=deterministic, |
|
return_attn_probs=True, |
|
) |
|
out_ref, attn_ref = attention_ref( |
|
q, k, v, None, None, attn_bias, 0.0, None, causal=causal, window_size=window_size |
|
) |
|
out_pt, attn_pt = attention_ref( |
|
q, |
|
k, |
|
v, |
|
None, |
|
None, |
|
attn_bias, |
|
0.0, |
|
None, |
|
causal=causal, |
|
window_size=window_size, |
|
upcast=False, |
|
reorder_ops=True, |
|
) |
|
|
|
print(f"Output max diff: {(out - out_ref).abs().max().item()}") |
|
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") |
|
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") |
|
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") |
|
|
|
g = torch.randn_like(out) |
|
do_o = (g.float() * out.float()).sum(-1) |
|
( |
|
dq, |
|
dk, |
|
dv, |
|
) = torch.autograd.grad(out, (q, k, v), g) |
|
( |
|
dq_ref, |
|
dk_ref, |
|
dv_ref, |
|
) = torch.autograd.grad(out_ref, (q, k, v), g) |
|
( |
|
dq_pt, |
|
dk_pt, |
|
dv_pt, |
|
) = torch.autograd.grad(out_pt, (q, k, v), g) |
|
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") |
|
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") |
|
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") |
|
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") |
|
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") |
|
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") |
|
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") |
|
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") |
|
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") |
|
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") |
|
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") |
|
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") |
|
|
|
|
|
|
|
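    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation (looser multiplier for the gradients when alibi is used).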
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5 |
|
|
|
mult = 2 if not alibi else 8 |
|
assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4 |
|
assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4 |
|
assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4 |
|
|
|
|
|
|
|
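# KV-cache decoding test: covers in-place cache appends (new_kv), rotary embeddings,
# paged KV caches via block tables, left-padded caches, MQA/GQA head layouts, and the
# split heuristic (num_splits=0 lets the kernel pick the number of splits).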
@pytest.mark.parametrize("dtype", [torch.float16]) |
|
@pytest.mark.parametrize("num_splits", [1, 0]) |
|
|
|
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) |
|
|
|
@pytest.mark.parametrize("new_kv", [False, True]) |
|
|
|
@pytest.mark.parametrize("alibi", [False, True]) |
|
|
|
@pytest.mark.parametrize("local", [False, True]) |
|
|
|
@pytest.mark.parametrize("causal", [False, True]) |
|
|
|
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) |
|
|
|
@pytest.mark.parametrize("rotary_interleaved", [False, True]) |
|
|
|
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) |
|
|
|
@pytest.mark.parametrize("paged_kv_block_size", [None, 256]) |
|
|
|
|
|
@pytest.mark.parametrize("has_leftpad", [False, True]) |
|
|
|
|
|
@pytest.mark.parametrize("has_batch_idx", [False]) |
|
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) |
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
"seqlen_q,seqlen_k", |
|
[ |
|
(1, 128), |
|
(1, 339), |
|
(3, 1024), |
|
(64, 800), |
|
(64, 256), |
|
(3, 799), |
|
(64, 2048), |
|
(16, 20000), |
|
(1, 128 * 1024), |
|
(16, 128 * 1024), |
|
(128, 128), |
|
], |
|
) |
|
|
|
def test_flash_attn_kvcache( |
|
seqlen_q, |
|
seqlen_k, |
|
d, |
|
has_batch_idx, |
|
has_leftpad, |
|
paged_kv_block_size, |
|
rotary_fraction, |
|
rotary_interleaved, |
|
seqlen_new_eq_seqlen_q, |
|
causal, |
|
local, |
|
alibi, |
|
new_kv, |
|
mha_type, |
|
num_splits, |
|
dtype, |
|
): |
|
if seqlen_q > seqlen_k and new_kv: |
|
pytest.skip() |
|
if not new_kv and rotary_fraction > 0.0: |
|
pytest.skip() |
|
if has_batch_idx and paged_kv_block_size is not None: |
|
pytest.skip() |
|
if has_leftpad and paged_kv_block_size is not None: |
|
pytest.skip() |
|
device = "cuda" |
|
|
|
torch.random.manual_seed(0) |
|
batch_size = 2 |
|
batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 |
|
nheads = 6 |
|
|
|
rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 |
|
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) |
|
assert nheads % nheads_k == 0 |
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) |
|
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) |
|
seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() |
|
if new_kv: |
|
k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) |
|
v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) |
|
else: |
|
k, v = None, None |
|
if paged_kv_block_size is None: |
|
k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) |
|
v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) |
|
block_table = None |
|
else: |
|
( |
|
k_cache, |
|
v_cache, |
|
block_table, |
|
k_cache_paged, |
|
v_cache_paged, |
|
num_blocks, |
|
) = _generate_block_kvcache( |
|
seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype |
|
) |
|
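    # Number of tokens already in the cache per batch element; when appending new KV the
    # upper bound leaves room for the appended tokens (seqlen_q is used with rotary +
    # causal/local so that cos/sin cover every query position).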
cache_seqlens = torch.randint( |
|
0 if new_kv else 1, |
|
|
|
( |
|
(seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) |
|
if new_kv |
|
else (seqlen_k + 1) |
|
), |
|
(batch_size,), |
|
dtype=torch.int32, |
|
device=device, |
|
) |
|
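    # cache_leftpad: per-sequence count of leading cache slots that are left-padding.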
if has_leftpad: |
|
cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) |
|
if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) |
|
for i in range(batch_size)]) |
|
else: |
|
cache_leftpad = None |
|
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") |
|
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") |
|
key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0) |
|
if has_leftpad: |
|
key_padding_mask = torch.logical_and( |
|
key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) |
|
) |
|
if has_batch_idx: |
|
cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ |
|
:batch_size |
|
] |
|
else: |
|
cache_batch_idx = None |
|
if alibi: |
|
alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 |
|
attn_bias = attn_bias_from_alibi_slopes( |
|
alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad |
|
) |
|
else: |
|
alibi_slopes, attn_bias = None, None |
|
|
|
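    # Build random rotary cos/sin tables and pre-rotate q and the new k for the reference,
    # mirroring the rotation flash_attn_with_kvcache applies when rotary_cos/rotary_sin are passed.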
if rotary_dim > 0: |
|
angle = ( |
|
torch.rand( |
|
seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size, |
|
rotary_dim // 2, |
|
device=device, |
|
) |
|
* 2 |
|
* math.pi |
|
) |
|
cos = torch.cos(angle).to(dtype=dtype) |
|
sin = torch.sin(angle).to(dtype=dtype) |
|
if causal or local: |
|
q_ro = apply_rotary_emb( |
|
q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved |
|
) |
|
else: |
|
q_ro = rearrange( |
|
apply_rotary_emb( |
|
rearrange(q, "b s h d -> b 1 (s h) d"), |
|
cos, |
|
sin, |
|
seqlen_offsets=cache_seqlens, |
|
interleaved=rotary_interleaved, |
|
), |
|
"b 1 (s h) d -> b s h d", |
|
s=seqlen_q, |
|
) |
|
|
|
k_ro = apply_rotary_emb( |
|
k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved |
|
) |
|
else: |
|
cos, sin = None, None |
|
q_ro, k_ro = q, k |
|
|
|
k_cache_ref = ( |
|
k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] |
|
).clone() |
|
v_cache_ref = ( |
|
v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] |
|
).clone() |
|
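    # Emulate the kernel's in-place cache append on the reference copies, then repeat the
    # KV heads so the reference sees one KV head per query head (MQA/GQA).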
if new_kv: |
|
update_mask = torch.logical_and( |
|
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new |
|
) |
|
k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") |
|
v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...") |
|
k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) |
|
v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) |
|
out = flash_attn_with_kvcache( |
|
q, |
|
k_cache if paged_kv_block_size is None else k_cache_paged, |
|
v_cache if paged_kv_block_size is None else v_cache_paged, |
|
k, |
|
v, |
|
rotary_cos=cos, |
|
rotary_sin=sin, |
|
cache_seqlens=cache_seqlens, |
|
cache_batch_idx=cache_batch_idx, |
|
cache_leftpad=cache_leftpad, |
|
block_table=block_table, |
|
causal=causal, |
|
window_size=window_size, |
|
rotary_interleaved=rotary_interleaved, |
|
alibi_slopes=alibi_slopes, |
|
num_splits=num_splits, |
|
) |
|
out_ref, _ = attention_ref( |
|
q_ro, |
|
k_cache_rep, |
|
v_cache_rep, |
|
None, |
|
key_padding_mask, |
|
attn_bias, |
|
0.0, |
|
None, |
|
causal=causal, |
|
window_size=window_size, |
|
key_leftpad=cache_leftpad, |
|
) |
|
out_pt, _ = attention_ref( |
|
q_ro, |
|
k_cache_rep, |
|
v_cache_rep, |
|
None, |
|
key_padding_mask, |
|
attn_bias, |
|
0.0, |
|
None, |
|
causal=causal, |
|
window_size=window_size, |
|
upcast=False, |
|
reorder_ops=True, |
|
key_leftpad=cache_leftpad, |
|
) |
|
print(f"Output max diff: {(out - out_ref).abs().max().item()}") |
|
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") |
|
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") |
|
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") |
|
|
|
|
|
|
|
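    # Verify the cache itself was updated correctly (for paged KV, gather the blocks back
    # through the block table first); v must match exactly, k with a tolerance since rotary
    # may be applied in different precision.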
if new_kv: |
|
if paged_kv_block_size is None: |
|
k_cache_select = ( |
|
k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] |
|
) |
|
v_cache_select = ( |
|
v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] |
|
) |
|
else: |
|
k_cache_select = rearrange( |
|
k_cache_paged[block_table.to(dtype=torch.long).flatten()], |
|
"(b nblocks) block_size ... -> b (nblocks block_size) ...", |
|
b=batch_size, |
|
)[:, :seqlen_k] |
|
v_cache_select = rearrange( |
|
v_cache_paged[block_table.to(dtype=torch.long).flatten()], |
|
"(b nblocks) block_size ... -> b (nblocks block_size) ...", |
|
b=batch_size, |
|
)[:, :seqlen_k] |
|
assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) |
|
assert torch.equal(v_cache_select, v_cache_ref) |
|
mult = 3 if not alibi else 5 |
|
assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 |
|
|
|
|
|
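# Helper for the paged-KV tests: allocate 3x more random blocks than strictly needed, build a
# random block table, and return the equivalent contiguous k_cache/v_cache views for reference.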
def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype): |
|
num_blocks = math.ceil(seqlen_k / paged_kv_block_size) * batch_size * 3 |
|
k_cache_paged = torch.randn( |
|
num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype |
|
) |
|
v_cache_paged = torch.randn( |
|
num_blocks, paged_kv_block_size, nheads_k, d, device=device, dtype=dtype |
|
) |
|
block_table = rearrange( |
|
torch.randperm(num_blocks, dtype=torch.int32, device=device), |
|
"(b nblocks) -> b nblocks", |
|
b=batch_size, |
|
) |
|
k_cache = rearrange( |
|
|
|
k_cache_paged[block_table.to(dtype=torch.long).flatten()], |
|
"(b nblocks) block_size ... -> b (nblocks block_size) ...", |
|
b=batch_size, |
|
)[:, :seqlen_k] |
|
v_cache = rearrange( |
|
v_cache_paged[block_table.to(dtype=torch.long).flatten()], |
|
"(b nblocks) block_size ... -> b (nblocks block_size) ...", |
|
b=batch_size, |
|
)[:, :seqlen_k] |
|
return k_cache, v_cache, block_table, k_cache_paged, v_cache_paged, num_blocks |
|
|
|
|
|
|
|
@pytest.mark.parametrize("dtype", [torch.float16]) |
|
@pytest.mark.parametrize("causal", [False, True]) |
|
|
|
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) |
|
|
|
|
|
|
|
@pytest.mark.parametrize( |
|
"seqlen_q,seqlen_k", |
|
[ |
|
(1, 239), |
|
(239, 1), |
|
(3, 799), |
|
(799, 3), |
|
(1024, 128), |
|
(97, 97), |
|
(128, 128), |
|
(200, 200), |
|
(256, 256), |
|
(257, 257), |
|
(384, 384), |
|
(512, 512), |
|
(768, 768), |
|
(1024, 1024), |
|
], |
|
) |
|
@pytest.mark.parametrize("dropout_p", [0.0, 0.17]) |
|
|
|
def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dtype): |
|
device = "cuda" |
|
|
|
torch.random.manual_seed(0) |
|
batch_size = 60 |
|
nheads = 4 |
|
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
torch.random.manual_seed(42) |
|
out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True) |
|
g = torch.randn_like(out0) |
|
if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): |
|
( |
|
dq0, |
|
dk0, |
|
dv0, |
|
) = torch.autograd.grad(out0, (q, k, v), g) |
|
|
|
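        # Tolerance: twice the round-off introduced by doing any arithmetic on dq0 (add then
        # subtract a constant) at this dtype.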
dq_atol = 2 * ((dq0 + 0.3 - 0.3) - dq0).abs().max().item() |
|
|
|
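    # Re-run the same seeded forward/backward many times; any mismatch indicates a race condition.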
for i in range(250): |
|
torch.random.manual_seed(42) |
|
out, lse, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True) |
|
assert torch.equal(out, out0) |
|
assert torch.equal(lse, lse0) |
|
|
|
if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90): |
|
( |
|
dq, |
|
dk, |
|
dv, |
|
) = torch.autograd.grad(out, (q, k, v), g) |
|
dq_equal = torch.allclose(dq, dq0, atol=dq_atol) |
|
if not dq_equal: |
|
print(f"Iter {i}, {dq_atol = }, dQ max diff: {(dq - dq0).abs().max().item()}") |
|
assert torch.equal(dv, dv0) |
|
assert torch.equal(dk, dk0) |
|
assert dq_equal |
|
|
|
|
|
@pytest.mark.parametrize("dtype", [torch.float16]) |
|
@pytest.mark.parametrize("causal", [False, True]) |
|
|
|
@pytest.mark.parametrize("d", [16, 32, 64]) |
|
|
|
@pytest.mark.parametrize("seqlen", [1, 2, 5, 17, 128]) |
|
|
|
def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype): |
|
"""We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, |
|
in the case where seqlen % 128 != 0. |
|
""" |
|
device = "cuda" |
|
|
|
torch.random.manual_seed(0) |
|
batch_size = 2 |
|
nheads = 5 |
|
q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5 |
|
k, v = [ |
|
torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3 |
|
for _ in range(2) |
|
] |
|
q.requires_grad_(True) |
|
k.requires_grad_(True) |
|
v.requires_grad_(True) |
|
out = flash_attn_func(q, k, v, causal=causal) |
|
g = torch.randn_like(out) |
|
out.backward(g) |
|
q_pt = q.detach().clone().requires_grad_(True) |
|
k_pt = k.detach().clone().requires_grad_(True) |
|
v_pt = v.detach().clone().requires_grad_(True) |
|
out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True) |
|
out_pt.backward(g) |
|
q_ref = q.detach().clone().requires_grad_(True) |
|
k_ref = k.detach().clone().requires_grad_(True) |
|
v_ref = v.detach().clone().requires_grad_(True) |
|
out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal) |
|
out_ref.backward(g) |
|
print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") |
|
print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}") |
|
print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}") |
|
print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}") |
|
print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}") |
|
print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}") |
|
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() |
|
assert (q.grad - q_ref.grad).abs().max().item() <= 5 * ( |
|
q_pt.grad - q_ref.grad |
|
).abs().max().item() + 1e-3 |
|
assert (k.grad - k_ref.grad).abs().max().item() <= 5 * ( |
|
k_pt.grad - k_ref.grad |
|
).abs().max().item() + 1e-3 |
|
assert (v.grad - v_ref.grad).abs().max().item() <= 5 * ( |
|
v_pt.grad - v_ref.grad |
|
).abs().max().item() + 1e-3 |
|
|
|
|
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) |
|
|
|
@pytest.mark.parametrize("causal", [False, True]) |
|
|
|
@pytest.mark.parametrize("d", [64, 128]) |
|
|
|
@pytest.mark.parametrize("seqlen", [97, 128, 200, 256]) |
|
|
|
def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype): |
|
"""We previously had a bug where we were using the wrong strides of dout, which shows up |
|
when dout is not contiguous. |
|
""" |
|
device = "cuda" |
|
|
|
torch.random.manual_seed(0) |
|
batch_size = 5 |
|
nheads = 2 |
|
q, k, v = [ |
|
torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda", requires_grad=True) |
|
for _ in range(3) |
|
] |
|
out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...") |
|
|
|
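    # Slicing with stride 2 makes g non-contiguous, exercising the dout stride handling.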
g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2] |
|
out.backward(g) |
|
q_pt = q.detach().clone().requires_grad_(True) |
|
k_pt = k.detach().clone().requires_grad_(True) |
|
v_pt = v.detach().clone().requires_grad_(True) |
|
out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True) |
|
out_pt = rearrange(out_pt, "b s ... -> s b ...") |
|
out_pt.backward(g) |
|
q_ref = q.detach().clone().requires_grad_(True) |
|
k_ref = k.detach().clone().requires_grad_(True) |
|
v_ref = v.detach().clone().requires_grad_(True) |
|
out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal) |
|
out_ref = rearrange(out_ref, "b s ... -> s b ...") |
|
out_ref.backward(g) |
|
print(f"dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") |
|
print(f"dK max diff: {(k.grad - k_ref.grad).abs().max().item()}") |
|
print(f"dV max diff: {(v.grad - v_ref.grad).abs().max().item()}") |
|
print(f"dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}") |
|
print(f"dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}") |
|
print(f"dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}") |
|
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() |
|
assert (q.grad - q_ref.grad).abs().max().item() <= 2 * ( |
|
q_pt.grad - q_ref.grad |
|
).abs().max().item() |
|
assert (k.grad - k_ref.grad).abs().max().item() <= 2 * ( |
|
k_pt.grad - k_ref.grad |
|
).abs().max().item() |
|
assert (v.grad - v_ref.grad).abs().max().item() <= 2 * ( |
|
v_pt.grad - v_ref.grad |
|
).abs().max().item() |
|
|
|
|
|
@pytest.mark.parametrize("dtype", [torch.float16]) |
|
@pytest.mark.parametrize("causal", [False, True]) |
|
|
|
@pytest.mark.parametrize("d", [16, 32, 64]) |
|
|
|
def test_flash_attn_bwd_varlen_overflow(d, causal, dtype): |
|
"""We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ, |
|
    in the case where seqlen % 128 != 0 or with variable-length (varlen) sequences.
|
""" |
|
device = "cuda" |
|
|
|
torch.random.manual_seed(0) |
|
nheads = 5 |
|
q_cuseqlen = torch.tensor([0, 76, 110, 256], device=device, dtype=torch.int32) |
|
k_cuseqlen = torch.tensor([0, 1, 2, 3], device=device, dtype=torch.int32) |
|
Mq = 256 |
|
Mk = 3 |
|
|
|
q = torch.randn([Mq, nheads, d], dtype=dtype, device=device) * 3 |
|
k, v = [torch.randn([Mk, nheads, d], dtype=dtype, device=device) * 3 for _ in range(2)] |
|
q.requires_grad_(True) |
|
k.requires_grad_(True) |
|
v.requires_grad_(True) |
|
|
|
out = flash_attn_varlen_func(q, k, v, q_cuseqlen, k_cuseqlen, Mq, Mk, causal=causal) |
|
g = torch.randn_like(out) |
|
out.backward(g) |
|
|
|
assert not q.grad.isnan().any() |
|
assert not k.grad.isnan().any() |
|
assert not v.grad.isnan().any() |
|
|
|
|
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) |
|
|
|
@pytest.mark.parametrize("local", [False, True]) |
|
|
|
@pytest.mark.parametrize("causal", [False, True]) |
|
|
|
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("swap_sq_sk", [False, True]) |
|
|
|
@pytest.mark.parametrize( |
|
"seqlen_q,seqlen_k", |
|
[ |
|
(1, 239), |
|
(3, 799), |
|
(127, 512), |
|
(127, 513), |
|
(113, 203), |
|
(128, 217), |
|
(113, 211), |
|
(108, 256), |
|
(256, 512), |
|
(1023, 1024), |
|
], |
|
) |
|
|
|
def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype): |
|
if ( |
|
max(seqlen_q, seqlen_k) >= 2048 |
|
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 |
|
): |
|
pytest.skip() |
|
if swap_sq_sk: |
|
seqlen_q, seqlen_k = seqlen_k, seqlen_q |
|
device = "cuda" |
|
|
|
torch.random.manual_seed(0) |
|
batch_size = 4 |
|
nheads = 9 |
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) |
|
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True) |
|
|
|
g = torch.randn_like(out) |
|
dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True) |
|
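    # With deterministic=True, repeated backward passes must produce bit-identical gradients.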
for _ in range(50): |
|
dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True) |
|
assert torch.equal(dv, dv0) |
|
assert torch.equal(dk, dk0) |
|
assert torch.equal(dq, dq0) |
|
|
|
|
|
@pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) |
|
|
|
@pytest.mark.parametrize("local", [False, True]) |
|
|
|
@pytest.mark.parametrize("causal", [False, True]) |
|
|
|
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("swap_sq_sk", [False, True]) |
|
|
|
@pytest.mark.parametrize( |
|
"seqlen_q,seqlen_k", |
|
[ |
|
(1, 239), |
|
(3, 799), |
|
(127, 512), |
|
(127, 513), |
|
(113, 203), |
|
(128, 217), |
|
(113, 211), |
|
(108, 256), |
|
(256, 512), |
|
(1023, 1024), |
|
], |
|
) |
|
|
|
def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, local, dtype): |
|
if ( |
|
max(seqlen_q, seqlen_k) >= 2048 |
|
and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 |
|
): |
|
pytest.skip() |
|
if swap_sq_sk: |
|
seqlen_q, seqlen_k = seqlen_k, seqlen_q |
|
device = "cuda" |
|
|
|
torch.random.manual_seed(0) |
|
batch_size = 2 |
|
nheads = 9 |
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) |
|
q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
k = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
v = torch.randn(batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True) |
|
query_padding_mask = generate_random_padding_mask(seqlen_q, batch_size, device, mode="random") |
|
key_padding_mask = generate_random_padding_mask(seqlen_k, batch_size, device, mode="random") |
|
( |
|
q_unpad, |
|
k_unpad, |
|
v_unpad, |
|
cu_seqlens_q, |
|
cu_seqlens_k, |
|
max_seqlen_q, |
|
max_seqlen_k, |
|
q, |
|
k, |
|
v, |
|
output_pad_fn, |
|
dq_pad_fn, |
|
dk_pad_fn, |
|
) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) |
|
out = flash_attn_varlen_func( |
|
q_unpad, |
|
k_unpad, |
|
v_unpad, |
|
cu_seqlens_q, |
|
cu_seqlens_k, |
|
max_seqlen_q, |
|
max_seqlen_k, |
|
0.0, |
|
causal=causal, |
|
window_size=window_size, |
|
deterministic=True, |
|
) |
|
|
|
g = torch.randn_like(out) |
|
dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) |
|
for _ in range(50): |
|
dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) |
|
assert torch.equal(dv, dv0) |
|
assert torch.equal(dk, dk0) |
|
assert torch.equal(dq, dq0) |
|
|