diff --git a/sgl-kernel/tests/test_flash_attention.py b/sgl-kernel/tests/test_flash_attention.py index 402dc81d7..ae061b1bf 100644 --- a/sgl-kernel/tests/test_flash_attention.py +++ b/sgl-kernel/tests/test_flash_attention.py @@ -296,6 +296,152 @@ def attention_ref( return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) +def generate_qkv( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + kvpacked=False, + qkvpacked=False, + add_unused_qkv=False, + query_unused_mask=None, + key_unused_mask=None, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) + if query_unused_mask is not None or key_unused_mask is not None: + assert not kvpacked + assert not qkvpacked + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input( + q, + query_padding_mask, + query_unused_mask, + ) + output_pad_fn = lambda output_unpad: pad_input( + output_unpad, indices_q, batch_size, seqlen_q + ) + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, + (batch_size + 1) * seqlen_q, + step=seqlen_q, + dtype=torch.int32, + device=q_unpad.device, + ) + seqused_q = None + max_seqlen_q = seqlen_q + output_pad_fn = lambda output_unpad: rearrange( + output_unpad, "(b s) h d -> b s h d", b=batch_size + ) + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input( + k, key_padding_mask, key_unused_mask + ) + v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, + (batch_size + 1) * seqlen_k, + step=seqlen_k, + dtype=torch.int32, + device=k_unpad.device, + ) + seqused_k = None + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + dqkv_pad_fn = lambda dqkv_unpad: pad_input( + dqkv_unpad, indices_q, batch_size, seqlen_q + ) + else: + dqkv_pad_fn = lambda dqkv_unpad: rearrange( + dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dkv_pad_fn = lambda dkv_unpad: pad_input( + dkv_unpad, indices_k, batch_size, seqlen_k + ) + else: + dkv_pad_fn = lambda dkv_unpad: rearrange( + dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + dk_pad_fn = lambda dk_unpad: pad_input( + dk_unpad, indices_k, batch_size, seqlen_k + ) + else: + dk_pad_fn = lambda dk_unpad: rearrange( + dk_unpad, "(b s) h d -> b s h d", b=batch_size + ) + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + @pytest.mark.skipif( not is_fa3_supported(), reason="flash_attn at sgl-kernel is only supported on sm90 and above", @@ -855,5 +1001,320 @@ def _generate_block_kvcache( return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks +# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) +@pytest.mark.parametrize( + "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else []) +) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn]) +# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +@pytest.mark.parametrize("mha_type", ["mha"]) +# @pytest.mark.parametrize("has_qv", [False, True]) +@pytest.mark.parametrize("has_qv", [False]) +# @pytest.mark.parametrize("deterministic", [False, True]) +@pytest.mark.parametrize("deterministic", [False]) +@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else [])) +# @pytest.mark.parametrize("softcap", [0.0]) +@pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [False]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [False]) +@pytest.mark.parametrize("add_unused_qkv", [False, True]) +# @pytest.mark.parametrize("add_unused_qkv", [True]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128]) +# @pytest.mark.parametrize("d", [64, 96, 128]) +# @pytest.mark.parametrize("d", COMPILED_HDIMS) +@pytest.mark.parametrize("d", [128]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (1, 1), + (1, 3), + (2, 1), + (511, 1), + (3, 513), + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (307, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +def test_flash_attn_varlen_output( + seqlen_q, + seqlen_k, + d, + add_unused_qkv, + causal, + local, + softcap, + deterministic, + has_qv, + mha_type, + dtype, +): + from sgl_kernel.flash_attn import flash_attn_varlen_func + + device = "cuda" + # set seed + torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local)) + # batch_size = 40 + # nheads = 16 + batch_size = 9 if seqlen_q <= 2048 else 2 + nheads = 6 + # batch_size = 2 + # nheads = 1 + nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype + dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d]) + if dtype == torch.float8_e4m3fn: + dv_vals = [d] + for dv in dv_vals: + q_ref = torch.randn( + batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref + ) + if softcap > 0.0: + # Ensure the values of qk are at least within softcap range. + q_ref = (q_ref * softcap / 4).detach().requires_grad_() + q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_() + k_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + v_ref = ( + torch.randn( + batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + .requires_grad_() + ) + if has_qv: + qv_ref = ( + torch.randn( + batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref + ) + .to(dtype) + .to(dtype_ref) + ) + else: + qv_ref = None + # Put window_size after QKV randn so that window_size changes from test to test + window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) + if dtype == torch.float8_e4m3fn: + q_descale, k_descale, v_descale = [ + torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) + * 2 + for _ in range(3) + ] + else: + q_descale, k_descale, v_descale = None, None, None + q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)] + qv = qv_ref.detach() if has_qv else None + query_padding_mask = generate_random_padding_mask( + seqlen_q, batch_size, device, mode="random", zero_lengths=False + ) + key_padding_mask = generate_random_padding_mask( + seqlen_k, batch_size, device, mode="random", zero_lengths=True + ) + + def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device): + if add_unused: + another_mask = generate_random_padding_mask(max_seq_len, bs, device) + attn_mask = torch.logical_and(padding_mask, another_mask) + unused_mask = torch.logical_xor( + torch.logical_or(padding_mask, another_mask), attn_mask + ) + else: + attn_mask = padding_mask + unused_mask = None + return attn_mask, unused_mask + + query_padding_mask, query_unused_mask = _gen_unused_masks( + query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device + ) + key_padding_mask, key_unused_mask = _gen_unused_masks( + key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device + ) + + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + seqused_q, + seqused_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv( + q, + k, + v, + query_padding_mask, + key_padding_mask, + kvpacked=False, + query_unused_mask=query_unused_mask, + key_unused_mask=key_unused_mask, + ) + q_unpad, k_unpad, v_unpad = [ + x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) + ] + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + softcap=softcap, + ) + out_pt, attn_pt = attention_ref( + q_ref, + k_ref, + v_ref, + query_padding_mask, + key_padding_mask, + causal=causal, + qv=qv_ref, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + softcap=softcap, + upcast=False, + reorder_ops=True, + intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None, + ) + + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + + if query_unused_mask is not None: + q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1") + + # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + rtol = 2 if softcap == 0.0 else 3 + + pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False] + num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1] + for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals): + out_unpad, lse, *rest = flash_attn_varlen_func( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + seqused_q=seqused_q, + seqused_k=seqused_k, + causal=causal, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + window_size=window_size, + softcap=softcap, + return_softmax_lse=True, + ) + out = output_pad_fn(out_unpad) + if query_unused_mask is not None: + out.masked_fill_(q_zero_masking, 0.0) + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + + # Check that FlashAttention's numerical error is at most 3x the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= rtol * ( + out_pt - out_ref + ).abs().max().item() + fwd_atol + + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: + g_unpad = torch.randn_like(out_unpad) + do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2) + dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad( + out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad + ) + dq = dq_pad_fn(dq_unpad) + dk = dk_pad_fn(dk_unpad) + dv = dk_pad_fn(dv_unpad) + if key_unused_mask is not None: + k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1") + dk.masked_fill_(k_zero_masking, 0.0) + dv.masked_fill_(k_zero_masking, 0.0) + if query_unused_mask is not None: + dq.masked_fill_(q_zero_masking, 0.0) + # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}") + # assert (softmax_d - do_o).abs().max().item() <= 1e-5 + # assert dq_accum.abs().max().item() == 0.0 + g = output_pad_fn(g_unpad) + + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g) + dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g) + print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv: + dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dq - dq_ref).abs().max().item() <= rtol * ( + dq_pt - dq_ref + ).abs().max().item() + dq_atol + dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dk - dk_ref).abs().max().item() <= rtol * ( + dk_pt - dk_ref + ).abs().max().item() + dk_atol + dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + ( + 0 if softcap == 0 else 3e-4 + ) + assert (dv - dv_ref).abs().max().item() <= rtol * ( + dv_pt - dv_ref + ).abs().max().item() + dv_atol + + if __name__ == "__main__": pytest.main([__file__])