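"""Smoke tests for the flash_attn forward kernels.

Checks the dense forward path (fwd) and the variable-length forward path
(varlen_fwd) against PyTorch's math SDPA backend as a reference.
"""
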
import torch
import flash_attn

# make reproducible
torch.manual_seed(0)


def _attention_torch(query, key, value, *, backend):
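    """Reference attention via PyTorch SDPA.

    Inputs arrive as (batch, seq_len, num_heads, head_dim); SDPA expects
    (batch, num_heads, seq_len, head_dim), so transpose in and back out.
    """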
    query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
    with torch.nn.attention.sdpa_kernel(backend):
        out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
    out = out.transpose(1, 2).contiguous()
    return out


def test_flash_attn():
    """Test standard flash attention with mha_fwd"""
    print("===== Testing mha_fwd =====")

    batch_size = 1
    seq_len = 4224
    num_attention_heads = 24
    attention_head_dim = 128

    shape = (batch_size, seq_len, num_attention_heads, attention_head_dim)

    print(f"Testing shape: {shape}")
    print(f"Batch size: {batch_size}, Seq len: {seq_len}")
    print(f"Num heads: {num_attention_heads}, Head dim: {attention_head_dim}")

    query = torch.randn(shape, device="cuda", dtype=torch.float16)
    key = torch.randn(shape, device="cuda", dtype=torch.float16)
    value = torch.randn(shape, device="cuda", dtype=torch.float16)

    # Compute the reference output using PyTorch SDPA (math backend)
    golden_truth = _attention_torch(
        query, key, value, backend=torch.nn.attention.SDPBackend.MATH
    )

    print(f"Golden truth shape: {golden_truth.shape}")
    print(f"Query sum: {query.sum().item()}")

    # Test non-causal flash attention
    out, softmax_lse, p, rng_state = flash_attn.fwd(
        q=query,
        k=key,
        v=value,
        is_causal=False,
    )
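    # out is the attention result; softmax_lse holds the per-row log-sum-exp of the
    # attention scores. p and rng_state relate to dropout / returned attention
    # probabilities and are not used in this test.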

    print(f"Flash attention output shape: {out.shape}")
    print(f"Query sum after attention: {query.sum().item()}")

    # Compare outputs
    diff = (out - golden_truth).abs().max()
    print(f"Max absolute difference (non-causal): {diff.item()}")

    assert out.shape == shape
    assert diff < 1e-2, f"Difference too large: {diff.item()}"

    # Test causal attention
    print("\n--- Testing with causal=True ---")
    out_causal, _, _, _ = flash_attn.fwd(
        q=query,
        k=key,
        v=value,
        is_causal=True,
    )

    print(f"Causal attention output shape: {out_causal.shape}")
    assert out_causal.shape == shape

    # Compare causal vs non-causal (should be different)
    diff_causal = (out - out_causal).abs().max()
    print(f"Difference between causal and non-causal: {diff_causal.item()}")
    assert diff_causal > 1e-3, "Causal and non-causal should produce different results"

    print("✓ mha_fwd test passed!")


def test_mha_varlen_fwd():
    """Test variable-length sequences with mha_varlen_fwd"""
    print("\n===== Testing mha_varlen_fwd =====")

    # Create variable length sequences
    # Batch with 3 sequences of lengths: 512, 1024, 256
    seq_lens = [512, 1024, 256]
    total_seq_len = sum(seq_lens)
    num_attention_heads = 16
    attention_head_dim = 64

    # Create cumulative sequence lengths (required for varlen)
    cu_seqlens = torch.tensor(
        [0] + [sum(seq_lens[: i + 1]) for i in range(len(seq_lens))],
        device="cuda",
        dtype=torch.int32,
    )
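    # For seq_lens = [512, 1024, 256] this gives cu_seqlens = [0, 512, 1536, 1792]:
    # each entry marks where a sequence starts in the packed (total_seq_len, ...) tensors.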

    print(f"Sequence lengths: {seq_lens}")
    print(f"Cumulative sequence lengths: {cu_seqlens}")
    print(f"Total sequence length: {total_seq_len}")

    # Create packed tensors (all sequences concatenated)
    query = torch.randn(
        total_seq_len,
        num_attention_heads,
        attention_head_dim,
        device="cuda",
        dtype=torch.float16,
    )
    key = torch.randn(
        total_seq_len,
        num_attention_heads,
        attention_head_dim,
        device="cuda",
        dtype=torch.float16,
    )
    value = torch.randn(
        total_seq_len,
        num_attention_heads,
        attention_head_dim,
        device="cuda",
        dtype=torch.float16,
    )

    print(f"Query shape: {query.shape}")
    print(f"Key shape: {key.shape}")
    print(f"Value shape: {value.shape}")

    # Build the golden truth by running attention on each sequence individually
    # and concatenating the results
    golden_truth_parts = []
    for i, seq_len in enumerate(seq_lens):
        start_idx = cu_seqlens[i]
        end_idx = cu_seqlens[i + 1]

        # Extract individual sequence
        q_seq = query[start_idx:end_idx].unsqueeze(0)  # Add batch dimension
        k_seq = key[start_idx:end_idx].unsqueeze(0)
        v_seq = value[start_idx:end_idx].unsqueeze(0)

        # Run reference attention on this sequence
        golden_seq = _attention_torch(
            q_seq, k_seq, v_seq, backend=torch.nn.attention.SDPBackend.MATH
        )
        golden_truth_parts.append(golden_seq.squeeze(0))  # Remove batch dimension

    # Concatenate all sequences back together
    golden_truth = torch.cat(golden_truth_parts, dim=0)
    print(f"Golden truth shape: {golden_truth.shape}")

    # Run flash attention varlen
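    # max_seqlen_q/max_seqlen_k are the longest individual sequence lengths in the
    # packed batch (here max(seq_lens) = 1024).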
    out, softmax_lse, p, rng_state = flash_attn.varlen_fwd(
        q=query,
        k=key,
        v=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seq_lens),
        max_seqlen_k=max(seq_lens),
        is_causal=False,
    )

    print(f"Flash attention varlen output shape: {out.shape}")
    print(f"Output should match input: {out.shape == query.shape}")

    # Compare with the golden truth
    diff = (out - golden_truth).abs().max()
    print(f"Max absolute difference (non-causal): {diff.item()}")

    # Verify output shape
    assert out.shape == (total_seq_len, num_attention_heads, attention_head_dim)
    assert diff < 1e-2, f"Difference too large: {diff.item()}"

    # Test with causal attention
    print("\n--- Testing with causal=True ---")
    out_causal, _, _, _ = flash_attn.varlen_fwd(
        q=query,
        k=key,
        v=value,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max(seq_lens),
        max_seqlen_k=max(seq_lens),
        is_causal=True,
    )

    print(f"Causal attention output shape: {out_causal.shape}")
    assert out_causal.shape == (total_seq_len, num_attention_heads, attention_head_dim)

    # The causal and non-causal outputs should be different
    diff_causal = (out - out_causal).abs().max()
    print(f"Difference between causal and non-causal: {diff_causal.item()}")
    assert diff_causal > 1e-3, "Causal and non-causal should produce different results"

    print("✓ mha_varlen_fwd test passed!")


if __name__ == "__main__":
    test_flash_attn()
    test_mha_varlen_fwd()