Update fa3 interface and add unit test (#9150)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -45,12 +45,12 @@ DISABLE_BACKWARD = True
|
||||
# or torch.cuda.get_device_capability("cuda")[0] < 9
|
||||
# )
|
||||
|
||||
DISABLE_SPLIT = True
|
||||
DISABLE_SPLIT = False
|
||||
DISABLE_PAGEDKV = True
|
||||
DISABLE_APPENDKV = True
|
||||
DISABLE_LOCAL = True
|
||||
DISABLE_APPENDKV = False
|
||||
DISABLE_LOCAL = False
|
||||
DISABLE_SOFTCAP = True
|
||||
DISABLE_PACKGQA = True
|
||||
DISABLE_PACKGQA = False
|
||||
DISABLE_FP16 = True
|
||||
DISABLE_FP8 = True
|
||||
|
||||
@@ -199,6 +199,7 @@ def attention_ref(
|
||||
v_descale=None,
|
||||
window_size=(-1, -1), # -1 means infinite window size
|
||||
sink_token_length=0,
|
||||
sinks: Optional[torch.Tensor] = None,
|
||||
softcap=0.0,
|
||||
upcast=True,
|
||||
reorder_ops=False,
|
||||
@@ -271,7 +272,18 @@ def attention_ref(
|
||||
scores.masked_fill_(local_mask, float("-inf"))
|
||||
if attn_bias is not None:
|
||||
scores = scores + attn_bias
|
||||
attention = torch.softmax(scores, dim=-1).to(v.dtype)
|
||||
if sinks is None:
|
||||
attention = torch.softmax(scores, dim=-1).to(v.dtype)
|
||||
else:
|
||||
scores_fp32 = scores.to(torch.float32)
|
||||
logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
|
||||
sinks = rearrange(sinks, "h -> h 1 1")
|
||||
logits_or_sinks_max = torch.maximum(sinks, logits_max)
|
||||
unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
|
||||
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(
|
||||
sinks - logits_or_sinks_max
|
||||
)
|
||||
attention = (unnormalized_scores / normalizer).to(v.dtype)
|
||||
# We want to mask here so that the attention matrix doesn't have any NaNs
|
||||
# Otherwise we'll get NaN in dV
|
||||
if query_padding_mask is not None:
|
||||
@@ -459,8 +471,10 @@ def generate_qkv(
|
||||
)
|
||||
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
|
||||
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||
@pytest.mark.parametrize("mha_type", ["mha"])
|
||||
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||
# @pytest.mark.parametrize("mha_type", ["mha"])
|
||||
@pytest.mark.parametrize("has_sink", [False, True])
|
||||
# @pytest.mark.parametrize("has_sink", [False])
|
||||
@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else []))
|
||||
# @pytest.mark.parametrize("new_kv", [True])
|
||||
# @pytest.mark.parametrize(
|
||||
@@ -540,6 +554,7 @@ def test_flash_attn_kvcache(
|
||||
new_kv,
|
||||
mha_type,
|
||||
dtype,
|
||||
has_sink,
|
||||
):
|
||||
from sgl_kernel.flash_attn import flash_attn_with_kvcache
|
||||
|
||||
@@ -565,6 +580,12 @@ def test_flash_attn_kvcache(
|
||||
assert nheads % nheads_k == 0
|
||||
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
|
||||
dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
|
||||
|
||||
if has_sink:
|
||||
sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)
|
||||
else:
|
||||
sinks = None
|
||||
|
||||
if dtype == torch.float8_e4m3fn or not is_hopper():
|
||||
# for fp8 and ampere arch, we not support v head dim != qk head dim
|
||||
dv_vals = [d]
|
||||
@@ -820,6 +841,7 @@ def test_flash_attn_kvcache(
|
||||
qv=qv,
|
||||
window_size=window_size,
|
||||
key_leftpad=cache_leftpad,
|
||||
sinks=sinks,
|
||||
)
|
||||
out_pt, _ = attention_ref(
|
||||
q_ro,
|
||||
@@ -834,6 +856,7 @@ def test_flash_attn_kvcache(
|
||||
reorder_ops=True,
|
||||
key_leftpad=cache_leftpad,
|
||||
intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
|
||||
sinks=sinks,
|
||||
)
|
||||
q = q.to(dtype)
|
||||
q_unpad = q_unpad.to(dtype) if varlen_q else None
|
||||
@@ -888,6 +911,7 @@ def test_flash_attn_kvcache(
|
||||
scheduler_metadata=scheduler_metadata,
|
||||
num_splits=num_splits,
|
||||
return_softmax_lse=True,
|
||||
sinks=sinks,
|
||||
)
|
||||
if varlen_q:
|
||||
out = output_pad_fn(out)
|
||||
@@ -1019,8 +1043,10 @@ def _generate_block_kvcache(
|
||||
)
|
||||
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
|
||||
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||
@pytest.mark.parametrize("mha_type", ["mha"])
|
||||
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||
# @pytest.mark.parametrize("mha_type", ["mha"])
|
||||
@pytest.mark.parametrize("has_sink", [False, True])
|
||||
# @pytest.mark.parametrize("has_sink", [False])
|
||||
# @pytest.mark.parametrize("has_qv", [False, True])
|
||||
@pytest.mark.parametrize("has_qv", [False])
|
||||
# @pytest.mark.parametrize("deterministic", [False, True])
|
||||
@@ -1078,6 +1104,7 @@ def test_flash_attn_varlen_output(
|
||||
has_qv,
|
||||
mha_type,
|
||||
dtype,
|
||||
has_sink,
|
||||
):
|
||||
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
||||
|
||||
@@ -1131,6 +1158,12 @@ def test_flash_attn_varlen_output(
|
||||
qv_ref = None
|
||||
# Put window_size after QKV randn so that window_size changes from test to test
|
||||
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
|
||||
|
||||
if has_sink:
|
||||
sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)
|
||||
else:
|
||||
sinks = None
|
||||
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
q_descale, k_descale, v_descale = [
|
||||
torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
|
||||
@@ -1209,6 +1242,7 @@ def test_flash_attn_varlen_output(
|
||||
v_descale=v_descale,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
sinks=sinks,
|
||||
)
|
||||
out_pt, attn_pt = attention_ref(
|
||||
q_ref,
|
||||
@@ -1226,6 +1260,7 @@ def test_flash_attn_varlen_output(
|
||||
upcast=False,
|
||||
reorder_ops=True,
|
||||
intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
|
||||
sinks=sinks,
|
||||
)
|
||||
|
||||
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
||||
@@ -1258,6 +1293,7 @@ def test_flash_attn_varlen_output(
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
return_softmax_lse=True,
|
||||
sinks=sinks,
|
||||
)
|
||||
out = output_pad_fn(out_unpad)
|
||||
if query_unused_mask is not None:
|
||||
|
||||
Reference in New Issue
Block a user