import torch
import flash_attn

# make reproducible
torch.manual_seed(0)

def _attention_torch(query, key, value, *, backend):
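    """Reference attention computed with PyTorch SDPA.

    Inputs and outputs use the (batch, seq_len, num_heads, head_dim) layout that
    flash_attn expects; SDPA wants (batch, num_heads, seq_len, head_dim), hence
    the transposes on the way in and out.
    """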
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(backend):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    out = out.transpose(1, 2).contiguous()
    return out


def test_flash_attn():
"""Test standard flash attention with mha_fwd"""
print("===== Testing mha_fwd =====")
batch_size = 1
seq_len = 4224
num_attention_heads = 24
attention_head_dim = 128
shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)
print(f"Testing shape: {shape}")
print(f"Batch size: {batch_size}, Seq len: {seq_len}")
print(f"Num heads: {num_attention_heads}, Head dim: {attention_head_dim}")
    query = torch.randn(shape, device="cuda", dtype=torch.float16)
    key = torch.randn(shape, device="cuda", dtype=torch.float16)
    value = torch.randn(shape, device="cuda", dtype=torch.float16)

    # Get reference implementation using PyTorch SDPA
    golden_truth = _attention_torch(
        query, key, value, backend=torch.nn.attention.SDPBackend.MATH
    )
    print(f"Golden truth shape: {golden_truth.shape}")
    print(f"Query sum: {query.sum().item()}")

    # Test non-causal flash attention
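    # fwd returns (out, softmax_lse, p, rng_state); only `out` is compared
    # against the reference below.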
    out, softmax_lse, p, rng_state = flash_attn.fwd(
        q=query,
        k=key,
        v=value,
        is_causal=False,
    )
    print(f"Flash attention output shape: {out.shape}")
    print(f"Query sum after attention: {query.sum().item()}")

    # Compare outputs
    diff = (out - golden_truth).abs().max()
    print(f"Max absolute difference (non-causal): {diff.item()}")
    assert out.shape == shape
    assert diff < 1e-2, f"Difference too large: {diff.item()}"

    # Test causal attention
    print("\n--- Testing with causal=True ---")
    out_causal, _, _, _ = flash_attn.fwd(
        q=query,
        k=key,
        v=value,
        is_causal=True,
    )
    print(f"Causal attention output shape: {out_causal.shape}")
    assert out_causal.shape == shape

    # Compare causal vs non-causal (should be different)
    diff_causal = (out - out_causal).abs().max()
    print(f"Difference between causal and non-causal: {diff_causal.item()}")
    assert diff_causal > 1e-3, "Causal and non-causal should produce different results"
    print("✓ mha_fwd test passed!")


def test_mha_varlen_fwd():
"""Test variable-length sequences with mha_varlen_fwd"""
print("\n===== Testing mha_varlen_fwd =====")
# Create variable length sequences
# Batch with 3 sequences of lengths: 512, 1024, 256
seq_lens = [512, 1024, 256]
total_seq_len = sum(seq_lens)
num_attention_heads = 16
attention_head_dim = 64
# Create cumulative sequence lengths (required for varlen)
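    # cu_seqlens has batch_size + 1 entries ([0, 512, 1536, 1792] here);
    # sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensors.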
    cu_seqlens = torch.tensor(
        [0] + [sum(seq_lens[: i + 1]) for i in range(len(seq_lens))],
        device="cuda",
        dtype=torch.int32,
    )
    print(f"Sequence lengths: {seq_lens}")
    print(f"Cumulative sequence lengths: {cu_seqlens}")
    print(f"Total sequence length: {total_seq_len}")

    # Create packed tensors (all sequences concatenated)
    query = torch.randn(
        total_seq_len,
        num_attention_heads,
        attention_head_dim,
        device="cuda",
        dtype=torch.float16,
    )
    key = torch.randn(
        total_seq_len,
        num_attention_heads,
        attention_head_dim,
        device="cuda",
        dtype=torch.float16,
    )
    value = torch.randn(
        total_seq_len,
        num_attention_heads,
        attention_head_dim,
        device="cuda",
        dtype=torch.float16,
    )
    print(f"Query shape: {query.shape}")
    print(f"Key shape: {key.shape}")
    print(f"Value shape: {value.shape}")

    # Create reference truth by running attention on individual sequences
    # and concatenating the results
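    # In varlen mode each sequence only attends within itself, so running SDPA
    # on each sequence separately and concatenating the results is a valid reference.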
    golden_truth_parts = []
    for i, seq_len in enumerate(seq_lens):
        start_idx = cu_seqlens[i]
        end_idx = cu_seqlens[i + 1]

        # Extract individual sequence
        q_seq = query[start_idx:end_idx].unsqueeze(0)  # Add batch dimension
        k_seq = key[start_idx:end_idx].unsqueeze(0)
        v_seq = value[start_idx:end_idx].unsqueeze(0)

        # Run reference attention on this sequence
        golden_seq = _attention_torch(
            q_seq, k_seq, v_seq, backend=torch.nn.attention.SDPBackend.MATH
        )
        golden_truth_parts.append(golden_seq.squeeze(0))  # Remove batch dimension

    # Concatenate all sequences back together
    golden_truth = torch.cat(golden_truth_parts, dim=0)
    print(f"Golden truth shape: {golden_truth.shape}")

    # Run flash attention varlen
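    # varlen_fwd takes the packed (total_seq_len, num_heads, head_dim) tensors
    # plus cu_seqlens/max_seqlen to describe the per-sequence boundaries.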
    out, softmax_lse, p, rng_state = flash_attn.varlen_fwd(
        q=query,
        k=key,
        v=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seq_lens),
        max_seqlen_k=max(seq_lens),
        is_causal=False,
    )
    print(f"Flash attention varlen output shape: {out.shape}")
    print(f"Output should match input: {out.shape == query.shape}")

    # Compare with reference truth
    diff = (out - golden_truth).abs().max()
    print(f"Max absolute difference (non-causal): {diff.item()}")

    # Verify output shape
    assert out.shape == (total_seq_len, num_attention_heads, attention_head_dim)
    assert diff < 1e-2, f"Difference too large: {diff.item()}"

    # Test with causal attention
    print("\n--- Testing with causal=True ---")
    out_causal, _, _, _ = flash_attn.varlen_fwd(
        q=query,
        k=key,
        v=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seq_lens),
        max_seqlen_k=max(seq_lens),
        is_causal=True,
    )
    print(f"Causal attention output shape: {out_causal.shape}")
    assert out_causal.shape == (total_seq_len, num_attention_heads, attention_head_dim)

    # The causal and non-causal outputs should be different
    diff_causal = (out - out_causal).abs().max()
    print(f"Difference between causal and non-causal: {diff_causal.item()}")
    assert diff_causal > 1e-3, "Causal and non-causal should produce different results"
    print("✓ mha_varlen_fwd test passed!")


if __name__ == "__main__":
    test_flash_attn()
    test_mha_varlen_fwd()