[Bug] Fix incorrect assertion in FA4 and add UT. (#11182)
This commit is contained in:
@@ -161,10 +161,7 @@ def flash_attn_with_kvcache(
|
|||||||
k is None and v is None
|
k is None and v is None
|
||||||
), "FA4 does not support updating KV cache in-place."
|
), "FA4 does not support updating KV cache in-place."
|
||||||
assert (
|
assert (
|
||||||
rotary_cos is None
|
rotary_cos is None and rotary_sin is None and rotary_seqlens is None
|
||||||
and rotary_sin is None
|
|
||||||
and rotary_interleaved is None
|
|
||||||
and rotary_seqlens is None
|
|
||||||
), "FA4 does not support rotary embedding."
|
), "FA4 does not support rotary embedding."
|
||||||
assert (
|
assert (
|
||||||
cache_batch_idx is None and cache_leftpad is None
|
cache_batch_idx is None and cache_leftpad is None
|
||||||
|
|||||||
@@ -10,10 +10,11 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
from utils import is_hopper
|
from utils import is_hopper
|
||||||
|
|
||||||
flash_attn_varlen_func = partial(flash_attn_varlen_func, ver=4)
|
flash_attn_varlen_func = partial(flash_attn_varlen_func, ver=4)
|
||||||
|
flash_attn_with_kvcache = partial(flash_attn_with_kvcache, ver=4)
|
||||||
|
|
||||||
|
|
||||||
def unpad_input(hidden_states, attention_mask, unused_mask=None):
|
def unpad_input(hidden_states, attention_mask, unused_mask=None):
|
||||||
@@ -873,5 +874,578 @@ def test_flash_attn_varlen_output(
|
|||||||
).abs().max().item() + dv_atol
|
).abs().max().item() + dv_atol
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
is_hopper(),
|
||||||
|
reason="skip on hopper",
|
||||||
|
)
|
||||||
|
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
|
||||||
|
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
|
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
|
||||||
|
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||||
|
# @pytest.mark.parametrize("mha_type", ["mha"])
|
||||||
|
@pytest.mark.parametrize("has_learnable_sink", [False, True])
|
||||||
|
# @pytest.mark.parametrize("has_learnable_sink", [False])
|
||||||
|
# @pytest.mark.parametrize("new_kv", [False, True])
|
||||||
|
@pytest.mark.parametrize("new_kv", [False])
|
||||||
|
@pytest.mark.parametrize("local", [False, True])
|
||||||
|
# @pytest.mark.parametrize("local", [False])
|
||||||
|
# @pytest.mark.parametrize("causal", [False, True])
|
||||||
|
@pytest.mark.parametrize("causal", [True])
|
||||||
|
# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False])
|
||||||
|
@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [False])
|
||||||
|
# @pytest.mark.parametrize("has_rotary_seqlens", [False, True])
|
||||||
|
@pytest.mark.parametrize("has_rotary_seqlens", [False])
|
||||||
|
# @pytest.mark.parametrize("rotary_interleaved", [False, True])
|
||||||
|
@pytest.mark.parametrize("rotary_interleaved", [True])
|
||||||
|
# @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
|
||||||
|
@pytest.mark.parametrize("rotary_fraction", [0.0])
|
||||||
|
# @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128]))
|
||||||
|
@pytest.mark.parametrize("page_size", [None, 128])
|
||||||
|
# @pytest.mark.parametrize("page_size", [128])
|
||||||
|
# @pytest.mark.parametrize("has_leftpad", [False, True])
|
||||||
|
@pytest.mark.parametrize("has_leftpad", [False])
|
||||||
|
# @pytest.mark.parametrize("has_batch_idx", [False, True])
|
||||||
|
@pytest.mark.parametrize("has_batch_idx", [False])
|
||||||
|
# @pytest.mark.parametrize("varlen_q", [False, True])
|
||||||
|
@pytest.mark.parametrize("varlen_q", [False])
|
||||||
|
# @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
|
||||||
|
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
|
||||||
|
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
|
||||||
|
# @pytest.mark.parametrize('d', [56, 80])
|
||||||
|
# @pytest.mark.parametrize("d", [128])
|
||||||
|
@pytest.mark.parametrize("d", [64])
|
||||||
|
# @pytest.mark.parametrize("d", [192])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"seqlen_q,seqlen_k",
|
||||||
|
[
|
||||||
|
(1, 128),
|
||||||
|
(1, 339),
|
||||||
|
(3, 1024),
|
||||||
|
(64, 800),
|
||||||
|
(64, 256),
|
||||||
|
(3, 799),
|
||||||
|
(64, 2048),
|
||||||
|
(16, 20000),
|
||||||
|
# # (1, 128 * 1024),
|
||||||
|
# # (16, 128 * 1024),
|
||||||
|
# (128, 128),
|
||||||
|
# (256, 512), # To test appending KV with more than 1 block
|
||||||
|
# (2048, 3577), # Enough tile to test persistent scheduler
|
||||||
|
],
|
||||||
|
)
|
||||||
|
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
|
||||||
|
def test_flash_attn_kvcache(
|
||||||
|
seqlen_q,
|
||||||
|
seqlen_k,
|
||||||
|
d,
|
||||||
|
varlen_q,
|
||||||
|
has_batch_idx,
|
||||||
|
has_leftpad,
|
||||||
|
page_size,
|
||||||
|
rotary_fraction,
|
||||||
|
rotary_interleaved,
|
||||||
|
has_rotary_seqlens,
|
||||||
|
seqlen_new_eq_seqlen_q,
|
||||||
|
causal,
|
||||||
|
local,
|
||||||
|
new_kv,
|
||||||
|
has_learnable_sink,
|
||||||
|
mha_type,
|
||||||
|
dtype,
|
||||||
|
):
|
||||||
|
if page_size is not None and seqlen_k % page_size != 0:
|
||||||
|
pytest.skip()
|
||||||
|
if seqlen_q > seqlen_k and new_kv:
|
||||||
|
pytest.skip()
|
||||||
|
if not new_kv and rotary_fraction > 0.0:
|
||||||
|
pytest.skip()
|
||||||
|
if rotary_fraction == 0.0 and has_rotary_seqlens:
|
||||||
|
pytest.skip()
|
||||||
|
device = "cuda"
|
||||||
|
# set seed
|
||||||
|
torch.random.manual_seed(0)
|
||||||
|
batch_size = 5
|
||||||
|
# batch_size = 1
|
||||||
|
batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
|
||||||
|
nheads = 6
|
||||||
|
# nheads = 1
|
||||||
|
# rotary_dim must be a multiple of 16, and must be <= d
|
||||||
|
rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16
|
||||||
|
nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
|
||||||
|
assert nheads % nheads_k == 0
|
||||||
|
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
|
||||||
|
# dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
|
||||||
|
dv_vals = [d]
|
||||||
|
if dtype == torch.float8_e4m3fn:
|
||||||
|
dv_vals = [d]
|
||||||
|
# attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0] if (causal or local) else [0]
|
||||||
|
attention_chunk_vals = [0]
|
||||||
|
for dv, attention_chunk in itertools.product(dv_vals, attention_chunk_vals):
|
||||||
|
# has_qv = d == 64 and dv >= 256
|
||||||
|
has_qv = False
|
||||||
|
q = (
|
||||||
|
torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype_ref)
|
||||||
|
.to(dtype)
|
||||||
|
.to(dtype_ref)
|
||||||
|
)
|
||||||
|
if has_qv:
|
||||||
|
qv = (
|
||||||
|
torch.randn(
|
||||||
|
batch_size, seqlen_q, nheads, dv, device=device, dtype=dtype_ref
|
||||||
|
)
|
||||||
|
.to(dtype)
|
||||||
|
.to(dtype_ref)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
qv = None
|
||||||
|
if varlen_q:
|
||||||
|
query_padding_mask = generate_random_padding_mask(
|
||||||
|
seqlen_q, batch_size, device, mode="random"
|
||||||
|
)
|
||||||
|
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q, *rest = unpad_input(
|
||||||
|
q, query_padding_mask
|
||||||
|
)
|
||||||
|
output_pad_fn = lambda output_unpad: pad_input(
|
||||||
|
output_unpad, indices_q, batch_size, seqlen_q
|
||||||
|
)
|
||||||
|
qv_unpad = (
|
||||||
|
rearrange(qv, "b s ... -> (b s) ...")[indices_q] if has_qv else None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query_padding_mask = None
|
||||||
|
q_unpad = q
|
||||||
|
qv_unpad = qv
|
||||||
|
cu_seqlens_q, max_seqlen_q = None, None
|
||||||
|
# Put window_size after QKV randn so that window_size changes from test to test
|
||||||
|
window_size = (
|
||||||
|
(None, None) if not local else torch.randint(0, seqlen_k, (2,)).tolist()
|
||||||
|
)
|
||||||
|
if has_learnable_sink:
|
||||||
|
learnable_sink = torch.randn(nheads, dtype=torch.bfloat16, device=device)
|
||||||
|
else:
|
||||||
|
learnable_sink = None
|
||||||
|
|
||||||
|
seqlen_new = (
|
||||||
|
seqlen_q
|
||||||
|
if seqlen_new_eq_seqlen_q
|
||||||
|
else torch.randint(1, seqlen_q + 1, (1,)).item()
|
||||||
|
)
|
||||||
|
cu_seqlens_k_new = None
|
||||||
|
key_new_padding_mask = None
|
||||||
|
if new_kv:
|
||||||
|
k = (
|
||||||
|
torch.randn(
|
||||||
|
batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype_ref
|
||||||
|
)
|
||||||
|
.to(dtype)
|
||||||
|
.to(dtype_ref)
|
||||||
|
)
|
||||||
|
v = (
|
||||||
|
torch.randn(
|
||||||
|
batch_size, seqlen_new, nheads_k, dv, device=device, dtype=dtype_ref
|
||||||
|
)
|
||||||
|
.to(dtype)
|
||||||
|
.to(dtype_ref)
|
||||||
|
)
|
||||||
|
if varlen_q: # k & v are also varlen
|
||||||
|
key_new_padding_mask = generate_random_padding_mask(
|
||||||
|
seqlen_new, batch_size, device, mode="random"
|
||||||
|
)
|
||||||
|
k_unpad, indices_k, cu_seqlens_k_new, *rest = unpad_input(
|
||||||
|
k, key_new_padding_mask
|
||||||
|
)
|
||||||
|
v_unpad, *rest = unpad_input(v, key_new_padding_mask)
|
||||||
|
else:
|
||||||
|
k_unpad, v_unpad = k, v
|
||||||
|
else:
|
||||||
|
k, v, k_unpad, v_unpad = None, None, None, None
|
||||||
|
if page_size is None:
|
||||||
|
k_cache = (
|
||||||
|
torch.randn(
|
||||||
|
batch_size_cache,
|
||||||
|
seqlen_k,
|
||||||
|
nheads_k,
|
||||||
|
d,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype_ref,
|
||||||
|
)
|
||||||
|
.to(dtype)
|
||||||
|
.to(dtype_ref)
|
||||||
|
)
|
||||||
|
v_cache = (
|
||||||
|
torch.randn(
|
||||||
|
batch_size_cache,
|
||||||
|
seqlen_k,
|
||||||
|
nheads_k,
|
||||||
|
dv,
|
||||||
|
device=device,
|
||||||
|
dtype=dtype_ref,
|
||||||
|
)
|
||||||
|
.to(dtype)
|
||||||
|
.to(dtype_ref)
|
||||||
|
)
|
||||||
|
page_table = None
|
||||||
|
else:
|
||||||
|
(
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
page_table,
|
||||||
|
k_cache_paged,
|
||||||
|
v_cache_paged,
|
||||||
|
num_blocks,
|
||||||
|
) = _generate_block_kvcache(
|
||||||
|
seqlen_k,
|
||||||
|
page_size,
|
||||||
|
batch_size_cache,
|
||||||
|
nheads_k,
|
||||||
|
d,
|
||||||
|
dv,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
dtype_ref,
|
||||||
|
)
|
||||||
|
cache_seqlens = torch.randint(
|
||||||
|
0 if new_kv else 1,
|
||||||
|
# If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
|
||||||
|
(
|
||||||
|
(
|
||||||
|
seqlen_k
|
||||||
|
- (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new)
|
||||||
|
+ 1
|
||||||
|
)
|
||||||
|
if new_kv
|
||||||
|
else (seqlen_k + 1)
|
||||||
|
),
|
||||||
|
(batch_size,),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
if has_leftpad:
|
||||||
|
cache_leftpad = torch.cat(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
torch.randint(
|
||||||
|
0,
|
||||||
|
cache_seqlens[i].item(),
|
||||||
|
(1,),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
if cache_seqlens[i].item() > 0
|
||||||
|
else torch.zeros(1, dtype=torch.int32, device=device)
|
||||||
|
)
|
||||||
|
for i in range(batch_size)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cache_leftpad = None
|
||||||
|
if has_batch_idx:
|
||||||
|
cache_batch_idx = torch.randperm(
|
||||||
|
batch_size_cache, dtype=torch.int32, device=device
|
||||||
|
)[:batch_size]
|
||||||
|
else:
|
||||||
|
cache_batch_idx = None
|
||||||
|
arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s")
|
||||||
|
cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1")
|
||||||
|
if not new_kv:
|
||||||
|
key_padding_mask = arange < cache_seqlens_expanded
|
||||||
|
else:
|
||||||
|
k_new_seqlens = (
|
||||||
|
key_new_padding_mask.sum(-1, keepdims=True) if varlen_q else seqlen_new
|
||||||
|
)
|
||||||
|
key_padding_mask = arange < cache_seqlens_expanded + k_new_seqlens
|
||||||
|
if has_leftpad:
|
||||||
|
key_padding_mask = torch.logical_and(
|
||||||
|
key_padding_mask,
|
||||||
|
arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k),
|
||||||
|
)
|
||||||
|
# cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device)
|
||||||
|
rotary_seqlens = cache_seqlens if not has_rotary_seqlens else cache_seqlens // 2
|
||||||
|
if rotary_dim > 0:
|
||||||
|
angle = (
|
||||||
|
torch.rand(
|
||||||
|
seqlen_k if page_size is None else num_blocks * page_size,
|
||||||
|
rotary_dim // 2,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
* 2
|
||||||
|
* math.pi
|
||||||
|
)
|
||||||
|
cos = torch.cos(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
|
||||||
|
sin = torch.sin(angle).to(dtype=dtype_ref).to(dtype).to(dtype_ref)
|
||||||
|
if causal or local:
|
||||||
|
q_ro = apply_rotary_emb(
|
||||||
|
q,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
seqlen_offsets=rotary_seqlens,
|
||||||
|
interleaved=rotary_interleaved,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
q_ro = rearrange(
|
||||||
|
apply_rotary_emb(
|
||||||
|
rearrange(q, "b s h d -> b 1 (s h) d"),
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
seqlen_offsets=rotary_seqlens,
|
||||||
|
interleaved=rotary_interleaved,
|
||||||
|
),
|
||||||
|
"b 1 (s h) d -> b s h d",
|
||||||
|
s=seqlen_q,
|
||||||
|
)
|
||||||
|
# q_ro = q
|
||||||
|
k_ro = apply_rotary_emb(
|
||||||
|
k,
|
||||||
|
cos,
|
||||||
|
sin,
|
||||||
|
seqlen_offsets=rotary_seqlens,
|
||||||
|
interleaved=rotary_interleaved,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cos, sin = None, None
|
||||||
|
q_ro, k_ro = q, k
|
||||||
|
# k_cache[:, 64:] = -1
|
||||||
|
k_cache_ref = (
|
||||||
|
k_cache if not has_batch_idx else k_cache[cache_batch_idx]
|
||||||
|
).clone()
|
||||||
|
v_cache_ref = (
|
||||||
|
v_cache if not has_batch_idx else v_cache[cache_batch_idx]
|
||||||
|
).clone()
|
||||||
|
if new_kv:
|
||||||
|
update_mask = torch.logical_and(
|
||||||
|
cache_seqlens_expanded <= arange,
|
||||||
|
arange < cache_seqlens_expanded + k_new_seqlens,
|
||||||
|
)
|
||||||
|
k_to_update = rearrange(k_ro, "b s ... -> (b s) ...")
|
||||||
|
v_to_update = rearrange(v, "b s ... -> (b s) ...")
|
||||||
|
if varlen_q:
|
||||||
|
k_to_update = k_to_update[indices_k]
|
||||||
|
v_to_update = v_to_update[indices_k]
|
||||||
|
k_cache_ref[update_mask] = k_to_update
|
||||||
|
v_cache_ref[update_mask] = v_to_update
|
||||||
|
k_cache_rep = repeat(
|
||||||
|
k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k
|
||||||
|
)
|
||||||
|
v_cache_rep = repeat(
|
||||||
|
v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k
|
||||||
|
)
|
||||||
|
out_ref, _ = attention_ref(
|
||||||
|
q_ro,
|
||||||
|
k_cache_rep,
|
||||||
|
v_cache_rep,
|
||||||
|
query_padding_mask,
|
||||||
|
key_padding_mask,
|
||||||
|
causal=causal,
|
||||||
|
qv=qv,
|
||||||
|
window_size=window_size,
|
||||||
|
learnable_sink=learnable_sink,
|
||||||
|
attention_chunk=attention_chunk,
|
||||||
|
key_leftpad=cache_leftpad,
|
||||||
|
)
|
||||||
|
out_pt, _ = attention_ref(
|
||||||
|
q_ro,
|
||||||
|
k_cache_rep,
|
||||||
|
v_cache_rep,
|
||||||
|
query_padding_mask,
|
||||||
|
key_padding_mask,
|
||||||
|
causal=causal,
|
||||||
|
qv=qv,
|
||||||
|
window_size=window_size,
|
||||||
|
learnable_sink=learnable_sink,
|
||||||
|
attention_chunk=attention_chunk,
|
||||||
|
upcast=False,
|
||||||
|
reorder_ops=True,
|
||||||
|
key_leftpad=cache_leftpad,
|
||||||
|
intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
|
||||||
|
)
|
||||||
|
q = q.to(dtype)
|
||||||
|
q_unpad = q_unpad.to(dtype) if varlen_q else None
|
||||||
|
k_cache = k_cache.to(dtype)
|
||||||
|
v_cache = v_cache.to(dtype)
|
||||||
|
k_cache_paged = k_cache_paged.to(dtype) if page_size is not None else None
|
||||||
|
v_cache_paged = v_cache_paged.to(dtype) if page_size is not None else None
|
||||||
|
k = k.to(dtype) if k is not None else None
|
||||||
|
v = v.to(dtype) if v is not None else None
|
||||||
|
k_unpad = k_unpad.to(dtype) if k_unpad is not None else None
|
||||||
|
v_unpad = v_unpad.to(dtype) if v_unpad is not None else None
|
||||||
|
qv = qv.to(dtype) if qv is not None else None
|
||||||
|
qv_unpad = qv_unpad.to(dtype) if (varlen_q and qv is not None) else None
|
||||||
|
cos = cos.to(dtype) if cos is not None else None
|
||||||
|
sin = sin.to(dtype) if sin is not None else None
|
||||||
|
k_cache_saved = k_cache.clone() if page_size is None else k_cache_paged.clone()
|
||||||
|
v_cache_saved = v_cache.clone() if page_size is None else v_cache_paged.clone()
|
||||||
|
# num_splits_vals = [1, 0]
|
||||||
|
num_splits_vals = [1]
|
||||||
|
# precompute_metadata_vals = [False, True]
|
||||||
|
precompute_metadata_vals = [False]
|
||||||
|
for num_splits, precompute_metadata in itertools.product(
|
||||||
|
num_splits_vals, precompute_metadata_vals
|
||||||
|
):
|
||||||
|
# if precompute_metadata:
|
||||||
|
# scheduler_metadata = get_scheduler_metadata(
|
||||||
|
# batch_size, max_seqlen_q if varlen_q else seqlen_q, seqlen_k, nheads, nheads_k, d,
|
||||||
|
# cache_seqlens, q.dtype, headdim_v=dv, cu_seqlens_q=cu_seqlens_q,
|
||||||
|
# cu_seqlens_k_new=cu_seqlens_k_new, cache_leftpad=cache_leftpad,
|
||||||
|
# max_seqlen_k_new=seqlen_new, page_size=page_size,
|
||||||
|
# causal=causal, window_size=window_size, attention_chunk=attention_chunk,
|
||||||
|
# num_splits=num_splits
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# scheduler_metadata = None
|
||||||
|
scheduler_metadata = None
|
||||||
|
# Repeat to test metadata reuse
|
||||||
|
for _ in range(1 if not precompute_metadata else 2):
|
||||||
|
if page_size is None:
|
||||||
|
k_cache.copy_(k_cache_saved)
|
||||||
|
v_cache.copy_(v_cache_saved)
|
||||||
|
else:
|
||||||
|
k_cache_paged.copy_(k_cache_saved)
|
||||||
|
v_cache_paged.copy_(v_cache_saved)
|
||||||
|
# out, lse, *rest = flash_attn_with_kvcache(
|
||||||
|
out, lse, *rest = flash_attn_with_kvcache(
|
||||||
|
q if not varlen_q else q_unpad,
|
||||||
|
k_cache if page_size is None else k_cache_paged,
|
||||||
|
v_cache if page_size is None else v_cache_paged,
|
||||||
|
# k if not new_kv or not varlen_q else k_unpad,
|
||||||
|
# v if not new_kv or not varlen_q else v_unpad,
|
||||||
|
# qv=qv if not varlen_q else qv_unpad,
|
||||||
|
# rotary_cos=cos,
|
||||||
|
# rotary_sin=sin,
|
||||||
|
cache_seqlens=cache_seqlens,
|
||||||
|
# cache_batch_idx=cache_batch_idx,
|
||||||
|
# cache_leftpad=cache_leftpad,
|
||||||
|
page_table=page_table,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
# cu_seqlens_k_new=cu_seqlens_k_new,
|
||||||
|
# rotary_seqlens=rotary_seqlens,
|
||||||
|
causal=causal,
|
||||||
|
window_size=window_size,
|
||||||
|
sinks=learnable_sink,
|
||||||
|
# attention_chunk=attention_chunk,
|
||||||
|
# rotary_interleaved=rotary_interleaved,
|
||||||
|
# scheduler_metadata=scheduler_metadata,
|
||||||
|
# num_splits=num_splits,
|
||||||
|
return_softmax_lse=True,
|
||||||
|
)
|
||||||
|
if varlen_q:
|
||||||
|
out = output_pad_fn(out)
|
||||||
|
# out = flash_attn_with_kvcache(
|
||||||
|
# q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size
|
||||||
|
# )
|
||||||
|
# out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size)
|
||||||
|
# qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref)
|
||||||
|
# m = qk.amax(-1, keepdim=True)
|
||||||
|
# s_tmp = torch.exp((qk - m) / math.sqrt(d))
|
||||||
|
# o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref)
|
||||||
|
# lse_ref = torch.logsumexp(qk / math.sqrt(d), -1)
|
||||||
|
# probs = torch.softmax(qk, dim=-1)
|
||||||
|
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||||
|
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||||
|
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
||||||
|
print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
|
||||||
|
# breakpoint()
|
||||||
|
|
||||||
|
# Check that FlashAttention's numerical error is at most twice the numerical error
|
||||||
|
# of a Pytorch implementation.
|
||||||
|
if new_kv:
|
||||||
|
if page_size is None:
|
||||||
|
k_cache_select = (
|
||||||
|
k_cache.to(dtype_ref)
|
||||||
|
if not has_batch_idx
|
||||||
|
else k_cache.to(dtype_ref)[cache_batch_idx]
|
||||||
|
)
|
||||||
|
v_cache_select = (
|
||||||
|
v_cache.to(dtype_ref)
|
||||||
|
if not has_batch_idx
|
||||||
|
else v_cache.to(dtype_ref)[cache_batch_idx]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
k_cache_select = rearrange(
|
||||||
|
k_cache_paged.to(dtype_ref)[
|
||||||
|
(
|
||||||
|
page_table
|
||||||
|
if not has_batch_idx
|
||||||
|
else page_table[cache_batch_idx]
|
||||||
|
).flatten()
|
||||||
|
],
|
||||||
|
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
|
||||||
|
b=batch_size,
|
||||||
|
)[:, :seqlen_k].to(dtype_ref)
|
||||||
|
v_cache_select = rearrange(
|
||||||
|
v_cache_paged.to(dtype_ref)[
|
||||||
|
(
|
||||||
|
page_table
|
||||||
|
if not has_batch_idx
|
||||||
|
else page_table[cache_batch_idx]
|
||||||
|
).flatten()
|
||||||
|
],
|
||||||
|
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
|
||||||
|
b=batch_size,
|
||||||
|
)[:, :seqlen_k].to(dtype_ref)
|
||||||
|
k_cache_ref = k_cache_ref.to(dtype).to(dtype_ref)
|
||||||
|
v_cache_ref = v_cache_ref.to(dtype).to(dtype_ref)
|
||||||
|
if dtype is not torch.float8_e4m3fn:
|
||||||
|
assert torch.equal(v_cache_select, v_cache_ref)
|
||||||
|
else:
|
||||||
|
assert torch.allclose(
|
||||||
|
v_cache_select, v_cache_ref, rtol=1e-3, atol=1e-3
|
||||||
|
)
|
||||||
|
# breakpoint()
|
||||||
|
# if rotary_dim == 0 and dtype is not torch.float8_e4m3fn:
|
||||||
|
if rotary_dim == 0:
|
||||||
|
assert torch.equal(k_cache_select, k_cache_ref)
|
||||||
|
else:
|
||||||
|
# if not torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3):
|
||||||
|
# breakpoint()
|
||||||
|
if dtype is not torch.float8_e4m3fn:
|
||||||
|
assert torch.allclose(
|
||||||
|
k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert torch.allclose(
|
||||||
|
k_cache_select, k_cache_ref, rtol=1e-1, atol=1e-1
|
||||||
|
)
|
||||||
|
mult = 4 if dtype == torch.float8_e4m3fn else 2
|
||||||
|
assert (out - out_ref).abs().max().item() <= mult * (
|
||||||
|
out_pt - out_ref
|
||||||
|
).abs().max().item() + 1e-5
|
||||||
|
mult_mean = 3 if dtype == torch.float8_e4m3fn else 1.5
|
||||||
|
assert (out - out_ref).abs().mean().item() <= mult_mean * (
|
||||||
|
out_pt - out_ref
|
||||||
|
).abs().mean().item()
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_block_kvcache(
|
||||||
|
seqlen_k, page_size, batch_size, nheads_k, d, dv, device, dtype, dtype_ref
|
||||||
|
):
|
||||||
|
num_blocks = math.ceil(seqlen_k / page_size) * batch_size * 3
|
||||||
|
k_cache_paged = (
|
||||||
|
torch.randn(num_blocks, page_size, nheads_k, d, device=device, dtype=dtype_ref)
|
||||||
|
.to(dtype)
|
||||||
|
.to(dtype_ref)
|
||||||
|
)
|
||||||
|
v_cache_paged = (
|
||||||
|
torch.randn(num_blocks, page_size, nheads_k, dv, device=device, dtype=dtype_ref)
|
||||||
|
.to(dtype)
|
||||||
|
.to(dtype_ref)
|
||||||
|
)
|
||||||
|
page_table = rearrange(
|
||||||
|
torch.randperm(num_blocks, dtype=torch.int32, device=device),
|
||||||
|
"(b nblocks) -> b nblocks",
|
||||||
|
b=batch_size,
|
||||||
|
)
|
||||||
|
k_cache = rearrange(
|
||||||
|
k_cache_paged[page_table.flatten()],
|
||||||
|
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
|
||||||
|
b=batch_size,
|
||||||
|
)[:, :seqlen_k]
|
||||||
|
v_cache = rearrange(
|
||||||
|
v_cache_paged[page_table.flatten()],
|
||||||
|
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
|
||||||
|
b=batch_size,
|
||||||
|
)[:, :seqlen_k]
|
||||||
|
return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
|||||||
Reference in New Issue
Block a user