Add test for flash_attn_varlen_func kernel (#5484)
@@ -296,6 +296,152 @@ def attention_ref(
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


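# Helper: unpads q/k/v according to the padding masks, builds the cu_seqlens /
# max_seqlen (and seqused_q / seqused_k) metadata consumed by the varlen kernel,
# and returns callbacks that re-pad outputs and gradients to (batch, seqlen, ...).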
def generate_qkv(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    kvpacked=False,
    qkvpacked=False,
    add_unused_qkv=False,
    query_unused_mask=None,
    key_unused_mask=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, d)
        k: (batch_size, seqlen_k, nheads_k, d)
        v: (batch_size, seqlen_k, nheads_k, d)
        query_padding_mask: (batch_size, seqlen), bool
        key_padding_mask: (batch_size, seqlen), bool
    """
    assert not (kvpacked and qkvpacked)
    batch_size, seqlen_q, nheads, d = q.shape
    _, seqlen_k, nheads_k, _ = k.shape
    assert k.shape == (batch_size, seqlen_k, nheads_k, d)
    assert v.shape == (batch_size, seqlen_k, nheads_k, d)
    if query_unused_mask is not None or key_unused_mask is not None:
        assert not kvpacked
        assert not qkvpacked

    if query_padding_mask is not None:
        q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
            q,
            query_padding_mask,
            query_unused_mask,
        )
        output_pad_fn = lambda output_unpad: pad_input(
            output_unpad, indices_q, batch_size, seqlen_q
        )
    else:
        q_unpad = rearrange(q, "b s h d -> (b s) h d")
        cu_seqlens_q = torch.arange(
            0,
            (batch_size + 1) * seqlen_q,
            step=seqlen_q,
            dtype=torch.int32,
            device=q_unpad.device,
        )
        seqused_q = None
        max_seqlen_q = seqlen_q
        output_pad_fn = lambda output_unpad: rearrange(
            output_unpad, "(b s) h d -> b s h d", b=batch_size
        )

    if key_padding_mask is not None:
        k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
            k, key_padding_mask, key_unused_mask
        )
        v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask)
    else:
        k_unpad = rearrange(k, "b s h d -> (b s) h d")
        v_unpad = rearrange(v, "b s h d -> (b s) h d")
        cu_seqlens_k = torch.arange(
            0,
            (batch_size + 1) * seqlen_k,
            step=seqlen_k,
            dtype=torch.int32,
            device=k_unpad.device,
        )
        seqused_k = None
        max_seqlen_k = seqlen_k

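    # Three return layouts: qkvpacked stacks q/k/v along a new dim, kvpacked stacks
    # k/v, and the default path returns q/k/v separately plus seqused_q/seqused_k.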
    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(
                dqkv_unpad, indices_q, batch_size, seqlen_q
            )
        else:
            dqkv_pad_fn = lambda dqkv_unpad: rearrange(
                dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            qkv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            max_seqlen_q,
            qkv.detach().requires_grad_(),
            output_pad_fn,
            dqkv_pad_fn,
        )
    elif kvpacked:
        kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
        kv = torch.stack([k, v], dim=2)
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dkv_pad_fn = lambda dkv_unpad: pad_input(
                dkv_unpad, indices_k, batch_size, seqlen_k
            )
        else:
            dkv_pad_fn = lambda dkv_unpad: rearrange(
                dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            kv_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            kv.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dkv_pad_fn,
        )
    else:
        dq_pad_fn = output_pad_fn
        if key_padding_mask is not None:
            dk_pad_fn = lambda dk_unpad: pad_input(
                dk_unpad, indices_k, batch_size, seqlen_k
            )
        else:
            dk_pad_fn = lambda dk_unpad: rearrange(
                dk_unpad, "(b s) h d -> b s h d", b=batch_size
            )
        return (
            q_unpad.detach().requires_grad_(),
            k_unpad.detach().requires_grad_(),
            v_unpad.detach().requires_grad_(),
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q.detach().requires_grad_(),
            k.detach().requires_grad_(),
            v.detach().requires_grad_(),
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        )


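# Usage sketch (it mirrors the call made in test_flash_attn_varlen_output below):
#   (q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k,
#    max_seqlen_q, max_seqlen_k, q, k, v, output_pad_fn, dq_pad_fn, dk_pad_fn
#   ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False)
#   out_unpad, lse, *rest = flash_attn_varlen_func(
#       q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k,
#       max_seqlen_q, max_seqlen_k, seqused_q=seqused_q, seqused_k=seqused_k,
#       causal=causal, return_softmax_lse=True,
#   )
#   out = output_pad_fn(out_unpad)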
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 and above",
@@ -855,5 +1001,320 @@ def _generate_block_kvcache(
    return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks


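# Correctness test for flash_attn_varlen_func: the kernel output (and, when the
# backward pass is enabled, dq/dk/dv) is compared against an fp32 reference and a
# same-precision PyTorch reference over varlen batches with random padding.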
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize(
    "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
)
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
@pytest.mark.parametrize("mha_type", ["mha"])
# @pytest.mark.parametrize("has_qv", [False, True])
@pytest.mark.parametrize("has_qv", [False])
# @pytest.mark.parametrize("deterministic", [False, True])
@pytest.mark.parametrize("deterministic", [False])
@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
# @pytest.mark.parametrize("softcap", [0.0])
@pytest.mark.parametrize("local", [False])
# @pytest.mark.parametrize("local", [False])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
@pytest.mark.parametrize("add_unused_qkv", [False, True])
# @pytest.mark.parametrize("add_unused_qkv", [True])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128])
# @pytest.mark.parametrize("d", COMPILED_HDIMS)
@pytest.mark.parametrize("d", [128])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (1, 1),
        (1, 3),
        (2, 1),
        (511, 1),
        (3, 513),
        (64, 128),
        (128, 128),
        (256, 256),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (307, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
def test_flash_attn_varlen_output(
    seqlen_q,
    seqlen_k,
    d,
    add_unused_qkv,
    causal,
    local,
    softcap,
    deterministic,
    has_qv,
    mha_type,
    dtype,
):
    from sgl_kernel.flash_attn import flash_attn_varlen_func

    device = "cuda"
    # set seed
    torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
    # batch_size = 40
    # nheads = 16
    batch_size = 9 if seqlen_q <= 2048 else 2
    nheads = 6
    # batch_size = 2
    # nheads = 1
    nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
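    # dv may differ from d for some head dims (head_dim_v != head_dim sweeps);
    # fp8 runs keep dv == d.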
    dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
    if dtype == torch.float8_e4m3fn:
        dv_vals = [d]
    for dv in dv_vals:
        q_ref = torch.randn(
            batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref
        )
        if softcap > 0.0:
            # Ensure the values of qk are at least within softcap range.
            q_ref = (q_ref * softcap / 4).detach().requires_grad_()
        q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
        k_ref = (
            torch.randn(
                batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
            )
            .to(dtype)
            .to(dtype_ref)
            .requires_grad_()
        )
        v_ref = (
            torch.randn(
                batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref
            )
            .to(dtype)
            .to(dtype_ref)
            .requires_grad_()
        )
        if has_qv:
            qv_ref = (
                torch.randn(
                    batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
                )
                .to(dtype)
                .to(dtype_ref)
            )
        else:
            qv_ref = None
        # Put window_size after QKV randn so that window_size changes from test to test
        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
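        # fp8 runs use random per-(batch, kv-head) descale factors for q/k/v;
        # other dtypes pass None so no descaling is applied.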
        if dtype == torch.float8_e4m3fn:
            q_descale, k_descale, v_descale = [
                torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
                * 2
                for _ in range(3)
            ]
        else:
            q_descale, k_descale, v_descale = None, None, None
        q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
        qv = qv_ref.detach() if has_qv else None
        query_padding_mask = generate_random_padding_mask(
            seqlen_q, batch_size, device, mode="random", zero_lengths=False
        )
        key_padding_mask = generate_random_padding_mask(
            seqlen_k, batch_size, device, mode="random", zero_lengths=True
        )

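        # Split each padding mask into an attended part and an "unused" part so the
        # seqused_q / seqused_k code path of the kernel is exercised as well.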
        def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
            if add_unused:
                another_mask = generate_random_padding_mask(max_seq_len, bs, device)
                attn_mask = torch.logical_and(padding_mask, another_mask)
                unused_mask = torch.logical_xor(
                    torch.logical_or(padding_mask, another_mask), attn_mask
                )
            else:
                attn_mask = padding_mask
                unused_mask = None
            return attn_mask, unused_mask

        query_padding_mask, query_unused_mask = _gen_unused_masks(
            query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
        )
        key_padding_mask, key_unused_mask = _gen_unused_masks(
            key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
        )

        (
            q_unpad,
            k_unpad,
            v_unpad,
            cu_seqlens_q,
            cu_seqlens_k,
            seqused_q,
            seqused_k,
            max_seqlen_q,
            max_seqlen_k,
            q,
            k,
            v,
            output_pad_fn,
            dq_pad_fn,
            dk_pad_fn,
        ) = generate_qkv(
            q,
            k,
            v,
            query_padding_mask,
            key_padding_mask,
            kvpacked=False,
            query_unused_mask=query_unused_mask,
            key_unused_mask=key_unused_mask,
        )
        q_unpad, k_unpad, v_unpad = [
            x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
        ]
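        # out_ref: upcast (fp32) reference; out_pt: same-precision PyTorch baseline
        # (upcast=False, reorder_ops=True) used below to scale the error budget.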
        out_ref, attn_ref = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            softcap=softcap,
        )
        out_pt, attn_pt = attention_ref(
            q_ref,
            k_ref,
            v_ref,
            query_padding_mask,
            key_padding_mask,
            causal=causal,
            qv=qv_ref,
            q_descale=q_descale,
            k_descale=k_descale,
            v_descale=v_descale,
            window_size=window_size,
            softcap=softcap,
            upcast=False,
            reorder_ops=True,
            intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
        )

        print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
        print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

        if query_unused_mask is not None:
            q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")

        # Numerical error if we just do any arithmetic on out_ref
        fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
        rtol = 2 if softcap == 0.0 else 3

        pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
        num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
        for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
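            # Kernel under test: varlen forward driven by cu_seqlens/seqused metadata;
            # returns the unpadded output together with the softmax LSE.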
            out_unpad, lse, *rest = flash_attn_varlen_func(
                q_unpad,
                k_unpad,
                v_unpad,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                seqused_q=seqused_q,
                seqused_k=seqused_k,
                causal=causal,
                q_descale=q_descale,
                k_descale=k_descale,
                v_descale=v_descale,
                window_size=window_size,
                softcap=softcap,
                return_softmax_lse=True,
            )
            out = output_pad_fn(out_unpad)
            if query_unused_mask is not None:
                out.masked_fill_(q_zero_masking, 0.0)
            print(f"Output max diff: {(out - out_ref).abs().max().item()}")
            print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")

            # Check that FlashAttention's numerical error is at most 3x the numerical error
            # of a Pytorch implementation.
            assert (out - out_ref).abs().max().item() <= rtol * (
                out_pt - out_ref
            ).abs().max().item() + fwd_atol

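        # Backward check (skipped for fp8 and qv inputs): gradients are taken through
        # the unpadded tensors and re-padded before comparison with the references.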
        if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv:
            g_unpad = torch.randn_like(out_unpad)
            do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
            dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
                out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad
            )
            dq = dq_pad_fn(dq_unpad)
            dk = dk_pad_fn(dk_unpad)
            dv = dk_pad_fn(dv_unpad)
            if key_unused_mask is not None:
                k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
                dk.masked_fill_(k_zero_masking, 0.0)
                dv.masked_fill_(k_zero_masking, 0.0)
            if query_unused_mask is not None:
                dq.masked_fill_(q_zero_masking, 0.0)
            # print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
            # assert (softmax_d - do_o).abs().max().item() <= 1e-5
            # assert dq_accum.abs().max().item() == 0.0
            g = output_pad_fn(g_unpad)

            # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
            dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)
            dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
            print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
            print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
            print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
            print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
            print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
            print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
            print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
            print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
            print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
            print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
            print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
            print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

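        # Gradient tolerances mirror the forward check: at most rtol x the PyTorch
        # baseline error plus a small absolute term (loosened slightly when softcap is on).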
        if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv:
            dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dq - dq_ref).abs().max().item() <= rtol * (
                dq_pt - dq_ref
            ).abs().max().item() + dq_atol
            dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dk - dk_ref).abs().max().item() <= rtol * (
                dk_pt - dk_ref
            ).abs().max().item() + dk_atol
            dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (
                0 if softcap == 0 else 3e-4
            )
            assert (dv - dv_ref).abs().max().item() <= rtol * (
                dv_pt - dv_ref
            ).abs().max().item() + dv_atol


if __name__ == "__main__":
    pytest.main([__file__])