@@ -1,7 +1,7 @@
 # Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py
 import itertools
 import math
 import os
+from typing import Optional
 
 import pytest
 import torch
@@ -45,12 +45,12 @@ DISABLE_BACKWARD = True
 # or torch.cuda.get_device_capability("cuda")[0] < 9
 # )
 
-DISABLE_SPLIT = True
+DISABLE_SPLIT = False
 DISABLE_PAGEDKV = True
-DISABLE_APPENDKV = True
-DISABLE_LOCAL = True
+DISABLE_APPENDKV = False
+DISABLE_LOCAL = False
 DISABLE_SOFTCAP = True
-DISABLE_PACKGQA = True
+DISABLE_PACKGQA = False
 DISABLE_FP16 = True
 DISABLE_FP8 = True
 
@@ -199,6 +199,7 @@ def attention_ref(
     v_descale=None,
     window_size=(-1, -1),  # -1 means infinite window size
     sink_token_length=0,
+    sinks: Optional[torch.Tensor] = None,
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
@@ -271,7 +272,18 @@ def attention_ref(
         scores.masked_fill_(local_mask, float("-inf"))
     if attn_bias is not None:
         scores = scores + attn_bias
-    attention = torch.softmax(scores, dim=-1).to(v.dtype)
+    if sinks is None:
+        attention = torch.softmax(scores, dim=-1).to(v.dtype)
+    else:
+        scores_fp32 = scores.to(torch.float32)
+        logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
+        sinks = rearrange(sinks, "h -> h 1 1")
+        logits_or_sinks_max = torch.maximum(sinks, logits_max)
+        unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
+        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(
+            sinks - logits_or_sinks_max
+        )
+        attention = (unnormalized_scores / normalizer).to(v.dtype)
     # We want to mask here so that the attention matrix doesn't have any NaNs
     # Otherwise we'll get NaN in dV
     if query_padding_mask is not None:
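For reference, here is a minimal standalone sketch (not part of the patch) of the sink-softmax normalization that the reference implementation above computes: the per-head sink logit joins the running max and the softmax denominator but contributes no value row. Tensor shapes and variable names are illustrative only.

import torch

batch, heads, q_len, k_len = 2, 4, 3, 5
scores = torch.randn(batch, heads, q_len, k_len)  # raw attention logits
sinks = torch.randn(heads)                        # one sink logit per head (random here, as in the test)

sinks_b = sinks.reshape(1, heads, 1, 1)           # broadcast to the score layout
m = torch.maximum(scores.amax(dim=-1, keepdim=True), sinks_b)
exp_scores = torch.exp(scores - m)
denom = exp_scores.sum(dim=-1, keepdim=True) + torch.exp(sinks_b - m)
attention = exp_scores / denom                    # each row sums to less than 1; the rest goes to the sink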
@@ -459,8 +471,10 @@ def generate_qkv(
     )
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
-# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-@pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("has_sink", [False, True])
+# @pytest.mark.parametrize("has_sink", [False])
 @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else []))
 # @pytest.mark.parametrize("new_kv", [True])
 # @pytest.mark.parametrize(
@@ -540,6 +554,7 @@ def test_flash_attn_kvcache(
     new_kv,
     mha_type,
     dtype,
+    has_sink,
 ):
     from sgl_kernel.flash_attn import flash_attn_with_kvcache
 
@@ -565,6 +580,12 @@ def test_flash_attn_kvcache(
     assert nheads % nheads_k == 0
     dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
     dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
+
+    if has_sink:
+        sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)
+    else:
+        sinks = None
+
     if dtype == torch.float8_e4m3fn or not is_hopper():
         # for fp8 and Ampere arch, we do not support v head dim != qk head dim
         dv_vals = [d]
@@ -820,6 +841,7 @@ def test_flash_attn_kvcache(
             qv=qv,
             window_size=window_size,
             key_leftpad=cache_leftpad,
+            sinks=sinks,
         )
         out_pt, _ = attention_ref(
             q_ro,
@@ -834,6 +856,7 @@ def test_flash_attn_kvcache(
             reorder_ops=True,
             key_leftpad=cache_leftpad,
             intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
+            sinks=sinks,
         )
         q = q.to(dtype)
         q_unpad = q_unpad.to(dtype) if varlen_q else None
@@ -888,6 +911,7 @@ def test_flash_attn_kvcache(
            scheduler_metadata=scheduler_metadata,
            num_splits=num_splits,
            return_softmax_lse=True,
+           sinks=sinks,
        )
        if varlen_q:
            out = output_pad_fn(out)
@@ -1019,8 +1043,10 @@ def _generate_block_kvcache(
     )
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
-# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-@pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("has_sink", [False, True])
+# @pytest.mark.parametrize("has_sink", [False])
 # @pytest.mark.parametrize("has_qv", [False, True])
 @pytest.mark.parametrize("has_qv", [False])
 # @pytest.mark.parametrize("deterministic", [False, True])
@@ -1078,6 +1104,7 @@ def test_flash_attn_varlen_output(
     has_qv,
     mha_type,
     dtype,
+    has_sink,
 ):
     from sgl_kernel.flash_attn import flash_attn_varlen_func
 
@@ -1131,6 +1158,12 @@ def test_flash_attn_varlen_output(
         qv_ref = None
     # Put window_size after QKV randn so that window_size changes from test to test
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+
+    if has_sink:
+        sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)
+    else:
+        sinks = None
+
     if dtype == torch.float8_e4m3fn:
         q_descale, k_descale, v_descale = [
             torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
@@ -1209,6 +1242,7 @@ def test_flash_attn_varlen_output(
         v_descale=v_descale,
         window_size=window_size,
         softcap=softcap,
+        sinks=sinks,
     )
     out_pt, attn_pt = attention_ref(
         q_ref,
@@ -1226,6 +1260,7 @@ def test_flash_attn_varlen_output(
         upcast=False,
         reorder_ops=True,
         intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
+        sinks=sinks,
     )
 
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
@@ -1258,6 +1293,7 @@ def test_flash_attn_varlen_output(
         window_size=window_size,
         softcap=softcap,
         return_softmax_lse=True,
+        sinks=sinks,
     )
     out = output_pad_fn(out_unpad)
     if query_unused_mask is not None: