Add test for flash_attn_varlen_func kernel (#5484)
This commit is contained in:
@@ -296,6 +296,152 @@ def attention_ref(
|
|||||||
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
|
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_qkv(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
query_padding_mask=None,
|
||||||
|
key_padding_mask=None,
|
||||||
|
kvpacked=False,
|
||||||
|
qkvpacked=False,
|
||||||
|
add_unused_qkv=False,
|
||||||
|
query_unused_mask=None,
|
||||||
|
key_unused_mask=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Arguments:
|
||||||
|
q: (batch_size, seqlen_q, nheads, d)
|
||||||
|
k: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
v: (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
query_padding_mask: (batch_size, seqlen), bool
|
||||||
|
key_padding_mask: (batch_size, seqlen), bool
|
||||||
|
"""
|
||||||
|
assert not (kvpacked and qkvpacked)
|
||||||
|
batch_size, seqlen_q, nheads, d = q.shape
|
||||||
|
_, seqlen_k, nheads_k, _ = k.shape
|
||||||
|
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||||
|
if query_unused_mask is not None or key_unused_mask is not None:
|
||||||
|
assert not kvpacked
|
||||||
|
assert not qkvpacked
|
||||||
|
|
||||||
|
if query_padding_mask is not None:
|
||||||
|
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, seqused_q = unpad_input(
|
||||||
|
q,
|
||||||
|
query_padding_mask,
|
||||||
|
query_unused_mask,
|
||||||
|
)
|
||||||
|
output_pad_fn = lambda output_unpad: pad_input(
|
||||||
|
output_unpad, indices_q, batch_size, seqlen_q
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_q = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_q,
|
||||||
|
step=seqlen_q,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=q_unpad.device,
|
||||||
|
)
|
||||||
|
seqused_q = None
|
||||||
|
max_seqlen_q = seqlen_q
|
||||||
|
output_pad_fn = lambda output_unpad: rearrange(
|
||||||
|
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
k_unpad, indices_k, cu_seqlens_k, max_seqlen_k, seqused_k = unpad_input(
|
||||||
|
k, key_padding_mask, key_unused_mask
|
||||||
|
)
|
||||||
|
v_unpad, _, _, _, _ = unpad_input(v, key_padding_mask, key_unused_mask)
|
||||||
|
else:
|
||||||
|
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||||
|
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||||
|
cu_seqlens_k = torch.arange(
|
||||||
|
0,
|
||||||
|
(batch_size + 1) * seqlen_k,
|
||||||
|
step=seqlen_k,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=k_unpad.device,
|
||||||
|
)
|
||||||
|
seqused_k = None
|
||||||
|
max_seqlen_k = seqlen_k
|
||||||
|
|
||||||
|
if qkvpacked:
|
||||||
|
assert (query_padding_mask == key_padding_mask).all()
|
||||||
|
assert nheads == nheads_k
|
||||||
|
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||||
|
qkv = torch.stack([q, k, v], dim=2)
|
||||||
|
if query_padding_mask is not None:
|
||||||
|
dqkv_pad_fn = lambda dqkv_unpad: pad_input(
|
||||||
|
dqkv_unpad, indices_q, batch_size, seqlen_q
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
dqkv_pad_fn = lambda dqkv_unpad: rearrange(
|
||||||
|
dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
qkv_unpad.detach().requires_grad_(),
|
||||||
|
cu_seqlens_q,
|
||||||
|
max_seqlen_q,
|
||||||
|
qkv.detach().requires_grad_(),
|
||||||
|
output_pad_fn,
|
||||||
|
dqkv_pad_fn,
|
||||||
|
)
|
||||||
|
elif kvpacked:
|
||||||
|
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||||
|
kv = torch.stack([k, v], dim=2)
|
||||||
|
dq_pad_fn = output_pad_fn
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
dkv_pad_fn = lambda dkv_unpad: pad_input(
|
||||||
|
dkv_unpad, indices_k, batch_size, seqlen_k
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
dkv_pad_fn = lambda dkv_unpad: rearrange(
|
||||||
|
dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
q_unpad.detach().requires_grad_(),
|
||||||
|
kv_unpad.detach().requires_grad_(),
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q.detach().requires_grad_(),
|
||||||
|
kv.detach().requires_grad_(),
|
||||||
|
output_pad_fn,
|
||||||
|
dq_pad_fn,
|
||||||
|
dkv_pad_fn,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
dq_pad_fn = output_pad_fn
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
dk_pad_fn = lambda dk_unpad: pad_input(
|
||||||
|
dk_unpad, indices_k, batch_size, seqlen_k
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
dk_pad_fn = lambda dk_unpad: rearrange(
|
||||||
|
dk_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
q_unpad.detach().requires_grad_(),
|
||||||
|
k_unpad.detach().requires_grad_(),
|
||||||
|
v_unpad.detach().requires_grad_(),
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
seqused_q,
|
||||||
|
seqused_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q.detach().requires_grad_(),
|
||||||
|
k.detach().requires_grad_(),
|
||||||
|
v.detach().requires_grad_(),
|
||||||
|
output_pad_fn,
|
||||||
|
dq_pad_fn,
|
||||||
|
dk_pad_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not is_fa3_supported(),
|
not is_fa3_supported(),
|
||||||
reason="flash_attn at sgl-kernel is only supported on sm90 and above",
|
reason="flash_attn at sgl-kernel is only supported on sm90 and above",
|
||||||
@@ -855,5 +1001,320 @@ def _generate_block_kvcache(
|
|||||||
return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
|
return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
|
||||||
|
|
||||||
|
|
||||||
|
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
|
||||||
|
)
|
||||||
|
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
|
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
|
||||||
|
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||||
|
@pytest.mark.parametrize("mha_type", ["mha"])
|
||||||
|
# @pytest.mark.parametrize("has_qv", [False, True])
|
||||||
|
@pytest.mark.parametrize("has_qv", [False])
|
||||||
|
# @pytest.mark.parametrize("deterministic", [False, True])
|
||||||
|
@pytest.mark.parametrize("deterministic", [False])
|
||||||
|
@pytest.mark.parametrize("softcap", [0.0] + ([15.0] if not DISABLE_SOFTCAP else []))
|
||||||
|
# @pytest.mark.parametrize("softcap", [0.0])
|
||||||
|
@pytest.mark.parametrize("local", [False])
|
||||||
|
# @pytest.mark.parametrize("local", [False])
|
||||||
|
@pytest.mark.parametrize("causal", [False, True])
|
||||||
|
# @pytest.mark.parametrize("causal", [False])
|
||||||
|
@pytest.mark.parametrize("add_unused_qkv", [False, True])
|
||||||
|
# @pytest.mark.parametrize("add_unused_qkv", [True])
|
||||||
|
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
|
||||||
|
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
|
||||||
|
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
|
||||||
|
# @pytest.mark.parametrize('d', [56, 80])
|
||||||
|
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
|
||||||
|
# @pytest.mark.parametrize("d", [64, 96, 128])
|
||||||
|
# @pytest.mark.parametrize("d", COMPILED_HDIMS)
|
||||||
|
@pytest.mark.parametrize("d", [128])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"seqlen_q,seqlen_k",
|
||||||
|
[
|
||||||
|
(1, 1),
|
||||||
|
(1, 3),
|
||||||
|
(2, 1),
|
||||||
|
(511, 1),
|
||||||
|
(3, 513),
|
||||||
|
(64, 128),
|
||||||
|
(128, 128),
|
||||||
|
(256, 256),
|
||||||
|
(113, 203),
|
||||||
|
(128, 217),
|
||||||
|
(113, 211),
|
||||||
|
(108, 256),
|
||||||
|
(256, 512),
|
||||||
|
(307, 256),
|
||||||
|
(640, 128),
|
||||||
|
(512, 256),
|
||||||
|
(1024, 1024),
|
||||||
|
(1023, 1024),
|
||||||
|
(1024, 1023),
|
||||||
|
(2048, 2048),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_flash_attn_varlen_output(
|
||||||
|
seqlen_q,
|
||||||
|
seqlen_k,
|
||||||
|
d,
|
||||||
|
add_unused_qkv,
|
||||||
|
causal,
|
||||||
|
local,
|
||||||
|
softcap,
|
||||||
|
deterministic,
|
||||||
|
has_qv,
|
||||||
|
mha_type,
|
||||||
|
dtype,
|
||||||
|
):
|
||||||
|
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
|
device = "cuda"
|
||||||
|
# set seed
|
||||||
|
torch.random.manual_seed(seqlen_q + seqlen_k + d + int(causal) * 2 + int(local))
|
||||||
|
# batch_size = 40
|
||||||
|
# nheads = 16
|
||||||
|
batch_size = 9 if seqlen_q <= 2048 else 2
|
||||||
|
nheads = 6
|
||||||
|
# batch_size = 2
|
||||||
|
# nheads = 1
|
||||||
|
nheads_kv = nheads if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
|
||||||
|
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
|
||||||
|
dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
|
||||||
|
if dtype == torch.float8_e4m3fn:
|
||||||
|
dv_vals = [d]
|
||||||
|
for dv in dv_vals:
|
||||||
|
q_ref = torch.randn(
|
||||||
|
batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref
|
||||||
|
)
|
||||||
|
if softcap > 0.0:
|
||||||
|
# Ensure the values of qk are at least within softcap range.
|
||||||
|
q_ref = (q_ref * softcap / 4).detach().requires_grad_()
|
||||||
|
q_ref = q_ref.to(dtype).to(dtype_ref).requires_grad_()
|
||||||
|
k_ref = (
|
||||||
|
torch.randn(
|
||||||
|
batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype_ref
|
||||||
|
)
|
||||||
|
.to(dtype)
|
||||||
|
.to(dtype_ref)
|
||||||
|
.requires_grad_()
|
||||||
|
)
|
||||||
|
v_ref = (
|
||||||
|
torch.randn(
|
||||||
|
batch_size, seqlen_k, nheads_kv, dv, device=device, dtype=dtype_ref
|
||||||
|
)
|
||||||
|
.to(dtype)
|
||||||
|
.to(dtype_ref)
|
||||||
|
.requires_grad_()
|
||||||
|
)
|
||||||
|
if has_qv:
|
||||||
|
qv_ref = (
|
||||||
|
torch.randn(
|
||||||
|
batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
|
||||||
|
)
|
||||||
|
.to(dtype)
|
||||||
|
.to(dtype_ref)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
qv_ref = None
|
||||||
|
# Put window_size after QKV randn so that window_size changes from test to test
|
||||||
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
|
||||||
|
if dtype == torch.float8_e4m3fn:
|
||||||
|
q_descale, k_descale, v_descale = [
|
||||||
|
torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
|
||||||
|
* 2
|
||||||
|
for _ in range(3)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
q_descale, k_descale, v_descale = None, None, None
|
||||||
|
q, k, v = [x.detach().requires_grad_() for x in (q_ref, k_ref, v_ref)]
|
||||||
|
qv = qv_ref.detach() if has_qv else None
|
||||||
|
query_padding_mask = generate_random_padding_mask(
|
||||||
|
seqlen_q, batch_size, device, mode="random", zero_lengths=False
|
||||||
|
)
|
||||||
|
key_padding_mask = generate_random_padding_mask(
|
||||||
|
seqlen_k, batch_size, device, mode="random", zero_lengths=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
|
||||||
|
if add_unused:
|
||||||
|
another_mask = generate_random_padding_mask(max_seq_len, bs, device)
|
||||||
|
attn_mask = torch.logical_and(padding_mask, another_mask)
|
||||||
|
unused_mask = torch.logical_xor(
|
||||||
|
torch.logical_or(padding_mask, another_mask), attn_mask
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_mask = padding_mask
|
||||||
|
unused_mask = None
|
||||||
|
return attn_mask, unused_mask
|
||||||
|
|
||||||
|
query_padding_mask, query_unused_mask = _gen_unused_masks(
|
||||||
|
query_padding_mask, add_unused_qkv, seqlen_q, batch_size, q.device
|
||||||
|
)
|
||||||
|
key_padding_mask, key_unused_mask = _gen_unused_masks(
|
||||||
|
key_padding_mask, add_unused_qkv, seqlen_k, batch_size, k.device
|
||||||
|
)
|
||||||
|
|
||||||
|
(
|
||||||
|
q_unpad,
|
||||||
|
k_unpad,
|
||||||
|
v_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
seqused_q,
|
||||||
|
seqused_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
output_pad_fn,
|
||||||
|
dq_pad_fn,
|
||||||
|
dk_pad_fn,
|
||||||
|
) = generate_qkv(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
query_padding_mask,
|
||||||
|
key_padding_mask,
|
||||||
|
kvpacked=False,
|
||||||
|
query_unused_mask=query_unused_mask,
|
||||||
|
key_unused_mask=key_unused_mask,
|
||||||
|
)
|
||||||
|
q_unpad, k_unpad, v_unpad = [
|
||||||
|
x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad)
|
||||||
|
]
|
||||||
|
out_ref, attn_ref = attention_ref(
|
||||||
|
q_ref,
|
||||||
|
k_ref,
|
||||||
|
v_ref,
|
||||||
|
query_padding_mask,
|
||||||
|
key_padding_mask,
|
||||||
|
causal=causal,
|
||||||
|
qv=qv_ref,
|
||||||
|
q_descale=q_descale,
|
||||||
|
k_descale=k_descale,
|
||||||
|
v_descale=v_descale,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=softcap,
|
||||||
|
)
|
||||||
|
out_pt, attn_pt = attention_ref(
|
||||||
|
q_ref,
|
||||||
|
k_ref,
|
||||||
|
v_ref,
|
||||||
|
query_padding_mask,
|
||||||
|
key_padding_mask,
|
||||||
|
causal=causal,
|
||||||
|
qv=qv_ref,
|
||||||
|
q_descale=q_descale,
|
||||||
|
k_descale=k_descale,
|
||||||
|
v_descale=v_descale,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=softcap,
|
||||||
|
upcast=False,
|
||||||
|
reorder_ops=True,
|
||||||
|
intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
||||||
|
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
|
||||||
|
|
||||||
|
if query_unused_mask is not None:
|
||||||
|
q_zero_masking = rearrange(query_unused_mask, "b s -> b s 1 1")
|
||||||
|
|
||||||
|
# Numerical error if we just do any arithmetic on out_ref
|
||||||
|
fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
|
||||||
|
rtol = 2 if softcap == 0.0 else 3
|
||||||
|
|
||||||
|
pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
|
||||||
|
num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
|
||||||
|
for pack_gqa, num_splits in itertools.product(pack_gqa_vals, num_splits_vals):
|
||||||
|
out_unpad, lse, *rest = flash_attn_varlen_func(
|
||||||
|
q_unpad,
|
||||||
|
k_unpad,
|
||||||
|
v_unpad,
|
||||||
|
cu_seqlens_q,
|
||||||
|
cu_seqlens_k,
|
||||||
|
max_seqlen_q,
|
||||||
|
max_seqlen_k,
|
||||||
|
seqused_q=seqused_q,
|
||||||
|
seqused_k=seqused_k,
|
||||||
|
causal=causal,
|
||||||
|
q_descale=q_descale,
|
||||||
|
k_descale=k_descale,
|
||||||
|
v_descale=v_descale,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=softcap,
|
||||||
|
return_softmax_lse=True,
|
||||||
|
)
|
||||||
|
out = output_pad_fn(out_unpad)
|
||||||
|
if query_unused_mask is not None:
|
||||||
|
out.masked_fill_(q_zero_masking, 0.0)
|
||||||
|
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||||
|
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||||
|
|
||||||
|
# Check that FlashAttention's numerical error is at most 3x the numerical error
|
||||||
|
# of a Pytorch implementation.
|
||||||
|
assert (out - out_ref).abs().max().item() <= rtol * (
|
||||||
|
out_pt - out_ref
|
||||||
|
).abs().max().item() + fwd_atol
|
||||||
|
|
||||||
|
if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv:
|
||||||
|
g_unpad = torch.randn_like(out_unpad)
|
||||||
|
do_o = ((g_unpad.float() * out_unpad.float()).sum(-1)).transpose(-1, -2)
|
||||||
|
dq_unpad, dk_unpad, dv_unpad = torch.autograd.grad(
|
||||||
|
out_unpad, (q_unpad, k_unpad, v_unpad), g_unpad
|
||||||
|
)
|
||||||
|
dq = dq_pad_fn(dq_unpad)
|
||||||
|
dk = dk_pad_fn(dk_unpad)
|
||||||
|
dv = dk_pad_fn(dv_unpad)
|
||||||
|
if key_unused_mask is not None:
|
||||||
|
k_zero_masking = rearrange(key_unused_mask, "b s -> b s 1 1")
|
||||||
|
dk.masked_fill_(k_zero_masking, 0.0)
|
||||||
|
dv.masked_fill_(k_zero_masking, 0.0)
|
||||||
|
if query_unused_mask is not None:
|
||||||
|
dq.masked_fill_(q_zero_masking, 0.0)
|
||||||
|
# print(f"dO_O max diff: {(softmax_d - do_o).abs().max().item()}")
|
||||||
|
# assert (softmax_d - do_o).abs().max().item() <= 1e-5
|
||||||
|
# assert dq_accum.abs().max().item() == 0.0
|
||||||
|
g = output_pad_fn(g_unpad)
|
||||||
|
|
||||||
|
# dq, dk, dv = torch.autograd.grad(out, (q, k, v), g)
|
||||||
|
dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q_ref, k_ref, v_ref), g)
|
||||||
|
dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q_ref, k_ref, v_ref), g)
|
||||||
|
print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
|
||||||
|
print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
|
||||||
|
print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
|
||||||
|
print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
|
||||||
|
print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
|
||||||
|
print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
|
||||||
|
print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
|
||||||
|
print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
|
||||||
|
print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
|
||||||
|
print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
|
||||||
|
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
|
||||||
|
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")
|
||||||
|
|
||||||
|
if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not has_qv:
|
||||||
|
dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (
|
||||||
|
0 if softcap == 0 else 3e-4
|
||||||
|
)
|
||||||
|
assert (dq - dq_ref).abs().max().item() <= rtol * (
|
||||||
|
dq_pt - dq_ref
|
||||||
|
).abs().max().item() + dq_atol
|
||||||
|
dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (
|
||||||
|
0 if softcap == 0 else 3e-4
|
||||||
|
)
|
||||||
|
assert (dk - dk_ref).abs().max().item() <= rtol * (
|
||||||
|
dk_pt - dk_ref
|
||||||
|
).abs().max().item() + dk_atol
|
||||||
|
dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (
|
||||||
|
0 if softcap == 0 else 3e-4
|
||||||
|
)
|
||||||
|
assert (dv - dv_ref).abs().max().item() <= rtol * (
|
||||||
|
dv_pt - dv_ref
|
||||||
|
).abs().max().item() + dv_atol
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
|||||||
Reference in New Issue
Block a user