Update fa3 interface and add unit test (#9150)
This commit is contained in:
@@ -55,7 +55,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
" Tensor? scheduler_metadata,"
|
" Tensor? scheduler_metadata,"
|
||||||
" int num_splits,"
|
" int num_splits,"
|
||||||
" bool? pack_gqa,"
|
" bool? pack_gqa,"
|
||||||
" int sm_margin) -> Tensor[]");
|
" int sm_margin,"
|
||||||
|
" Tensor? sinks) -> Tensor[]");
|
||||||
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
|
m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -82,4 +82,5 @@ std::vector<at::Tensor> mha_fwd(
|
|||||||
std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
|
std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
|
||||||
int num_splits,
|
int num_splits,
|
||||||
std::optional<bool> pack_gqa_,
|
std::optional<bool> pack_gqa_,
|
||||||
int const sm_margin);
|
int const sm_margin,
|
||||||
|
std::optional<const at::Tensor>& sinks_);
|
||||||
|
|||||||
@@ -58,6 +58,7 @@ def flash_attn_with_kvcache(
|
|||||||
pack_gqa=None, # Can be tuned for speed
|
pack_gqa=None, # Can be tuned for speed
|
||||||
sm_margin=0, # Can be tuned if some SMs are used for communication
|
sm_margin=0, # Can be tuned if some SMs are used for communication
|
||||||
return_softmax_lse=False,
|
return_softmax_lse=False,
|
||||||
|
sinks=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
|
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
|
||||||
@@ -205,6 +206,7 @@ def flash_attn_with_kvcache(
|
|||||||
num_splits,
|
num_splits,
|
||||||
pack_gqa,
|
pack_gqa,
|
||||||
sm_margin,
|
sm_margin,
|
||||||
|
sinks,
|
||||||
)
|
)
|
||||||
# return (out, softmax_lse) if return_softmax_lse else out
|
# return (out, softmax_lse) if return_softmax_lse else out
|
||||||
return (out, softmax_lse, *rest) if return_softmax_lse else out
|
return (out, softmax_lse, *rest) if return_softmax_lse else out
|
||||||
@@ -232,6 +234,7 @@ def flash_attn_varlen_func(
|
|||||||
pack_gqa=None,
|
pack_gqa=None,
|
||||||
sm_margin=0,
|
sm_margin=0,
|
||||||
return_softmax_lse=False,
|
return_softmax_lse=False,
|
||||||
|
sinks=None,
|
||||||
):
|
):
|
||||||
if not is_fa3_supported():
|
if not is_fa3_supported():
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
@@ -277,6 +280,7 @@ def flash_attn_varlen_func(
|
|||||||
num_splits=num_splits,
|
num_splits=num_splits,
|
||||||
pack_gqa=pack_gqa,
|
pack_gqa=pack_gqa,
|
||||||
sm_margin=sm_margin,
|
sm_margin=sm_margin,
|
||||||
|
sinks=sinks,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (out, softmax_lse, *rest) if return_softmax_lse else out
|
return (out, softmax_lse, *rest) if return_softmax_lse else out
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py
|
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py
|
||||||
import itertools
|
import itertools
|
||||||
import math
|
import math
|
||||||
import os
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
@@ -45,12 +45,12 @@ DISABLE_BACKWARD = True
|
|||||||
# or torch.cuda.get_device_capability("cuda")[0] < 9
|
# or torch.cuda.get_device_capability("cuda")[0] < 9
|
||||||
# )
|
# )
|
||||||
|
|
||||||
DISABLE_SPLIT = True
|
DISABLE_SPLIT = False
|
||||||
DISABLE_PAGEDKV = True
|
DISABLE_PAGEDKV = True
|
||||||
DISABLE_APPENDKV = True
|
DISABLE_APPENDKV = False
|
||||||
DISABLE_LOCAL = True
|
DISABLE_LOCAL = False
|
||||||
DISABLE_SOFTCAP = True
|
DISABLE_SOFTCAP = True
|
||||||
DISABLE_PACKGQA = True
|
DISABLE_PACKGQA = False
|
||||||
DISABLE_FP16 = True
|
DISABLE_FP16 = True
|
||||||
DISABLE_FP8 = True
|
DISABLE_FP8 = True
|
||||||
|
|
||||||
@@ -199,6 +199,7 @@ def attention_ref(
|
|||||||
v_descale=None,
|
v_descale=None,
|
||||||
window_size=(-1, -1), # -1 means infinite window size
|
window_size=(-1, -1), # -1 means infinite window size
|
||||||
sink_token_length=0,
|
sink_token_length=0,
|
||||||
|
sinks: Optional[torch.Tensor] = None,
|
||||||
softcap=0.0,
|
softcap=0.0,
|
||||||
upcast=True,
|
upcast=True,
|
||||||
reorder_ops=False,
|
reorder_ops=False,
|
||||||
@@ -271,7 +272,18 @@ def attention_ref(
|
|||||||
scores.masked_fill_(local_mask, float("-inf"))
|
scores.masked_fill_(local_mask, float("-inf"))
|
||||||
if attn_bias is not None:
|
if attn_bias is not None:
|
||||||
scores = scores + attn_bias
|
scores = scores + attn_bias
|
||||||
|
if sinks is None:
|
||||||
attention = torch.softmax(scores, dim=-1).to(v.dtype)
|
attention = torch.softmax(scores, dim=-1).to(v.dtype)
|
||||||
|
else:
|
||||||
|
scores_fp32 = scores.to(torch.float32)
|
||||||
|
logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
|
||||||
|
sinks = rearrange(sinks, "h -> h 1 1")
|
||||||
|
logits_or_sinks_max = torch.maximum(sinks, logits_max)
|
||||||
|
unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
|
||||||
|
normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(
|
||||||
|
sinks - logits_or_sinks_max
|
||||||
|
)
|
||||||
|
attention = (unnormalized_scores / normalizer).to(v.dtype)
|
||||||
# We want to mask here so that the attention matrix doesn't have any NaNs
|
# We want to mask here so that the attention matrix doesn't have any NaNs
|
||||||
# Otherwise we'll get NaN in dV
|
# Otherwise we'll get NaN in dV
|
||||||
if query_padding_mask is not None:
|
if query_padding_mask is not None:
|
||||||
@@ -459,8 +471,10 @@ def generate_qkv(
|
|||||||
)
|
)
|
||||||
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
|
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
|
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
|
||||||
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||||
@pytest.mark.parametrize("mha_type", ["mha"])
|
# @pytest.mark.parametrize("mha_type", ["mha"])
|
||||||
|
@pytest.mark.parametrize("has_sink", [False, True])
|
||||||
|
# @pytest.mark.parametrize("has_sink", [False])
|
||||||
@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else []))
|
@pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else []))
|
||||||
# @pytest.mark.parametrize("new_kv", [True])
|
# @pytest.mark.parametrize("new_kv", [True])
|
||||||
# @pytest.mark.parametrize(
|
# @pytest.mark.parametrize(
|
||||||
@@ -540,6 +554,7 @@ def test_flash_attn_kvcache(
|
|||||||
new_kv,
|
new_kv,
|
||||||
mha_type,
|
mha_type,
|
||||||
dtype,
|
dtype,
|
||||||
|
has_sink,
|
||||||
):
|
):
|
||||||
from sgl_kernel.flash_attn import flash_attn_with_kvcache
|
from sgl_kernel.flash_attn import flash_attn_with_kvcache
|
||||||
|
|
||||||
@@ -565,6 +580,12 @@ def test_flash_attn_kvcache(
|
|||||||
assert nheads % nheads_k == 0
|
assert nheads % nheads_k == 0
|
||||||
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
|
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
|
||||||
dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
|
dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
|
||||||
|
|
||||||
|
if has_sink:
|
||||||
|
sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)
|
||||||
|
else:
|
||||||
|
sinks = None
|
||||||
|
|
||||||
if dtype == torch.float8_e4m3fn or not is_hopper():
|
if dtype == torch.float8_e4m3fn or not is_hopper():
|
||||||
# for fp8 and ampere arch, we not support v head dim != qk head dim
|
# for fp8 and ampere arch, we not support v head dim != qk head dim
|
||||||
dv_vals = [d]
|
dv_vals = [d]
|
||||||
@@ -820,6 +841,7 @@ def test_flash_attn_kvcache(
|
|||||||
qv=qv,
|
qv=qv,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
key_leftpad=cache_leftpad,
|
key_leftpad=cache_leftpad,
|
||||||
|
sinks=sinks,
|
||||||
)
|
)
|
||||||
out_pt, _ = attention_ref(
|
out_pt, _ = attention_ref(
|
||||||
q_ro,
|
q_ro,
|
||||||
@@ -834,6 +856,7 @@ def test_flash_attn_kvcache(
|
|||||||
reorder_ops=True,
|
reorder_ops=True,
|
||||||
key_leftpad=cache_leftpad,
|
key_leftpad=cache_leftpad,
|
||||||
intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
|
intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
|
||||||
|
sinks=sinks,
|
||||||
)
|
)
|
||||||
q = q.to(dtype)
|
q = q.to(dtype)
|
||||||
q_unpad = q_unpad.to(dtype) if varlen_q else None
|
q_unpad = q_unpad.to(dtype) if varlen_q else None
|
||||||
@@ -888,6 +911,7 @@ def test_flash_attn_kvcache(
|
|||||||
scheduler_metadata=scheduler_metadata,
|
scheduler_metadata=scheduler_metadata,
|
||||||
num_splits=num_splits,
|
num_splits=num_splits,
|
||||||
return_softmax_lse=True,
|
return_softmax_lse=True,
|
||||||
|
sinks=sinks,
|
||||||
)
|
)
|
||||||
if varlen_q:
|
if varlen_q:
|
||||||
out = output_pad_fn(out)
|
out = output_pad_fn(out)
|
||||||
@@ -1019,8 +1043,10 @@ def _generate_block_kvcache(
|
|||||||
)
|
)
|
||||||
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
|
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||||
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
|
# @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
|
||||||
# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
|
||||||
@pytest.mark.parametrize("mha_type", ["mha"])
|
# @pytest.mark.parametrize("mha_type", ["mha"])
|
||||||
|
@pytest.mark.parametrize("has_sink", [False, True])
|
||||||
|
# @pytest.mark.parametrize("has_sink", [False])
|
||||||
# @pytest.mark.parametrize("has_qv", [False, True])
|
# @pytest.mark.parametrize("has_qv", [False, True])
|
||||||
@pytest.mark.parametrize("has_qv", [False])
|
@pytest.mark.parametrize("has_qv", [False])
|
||||||
# @pytest.mark.parametrize("deterministic", [False, True])
|
# @pytest.mark.parametrize("deterministic", [False, True])
|
||||||
@@ -1078,6 +1104,7 @@ def test_flash_attn_varlen_output(
|
|||||||
has_qv,
|
has_qv,
|
||||||
mha_type,
|
mha_type,
|
||||||
dtype,
|
dtype,
|
||||||
|
has_sink,
|
||||||
):
|
):
|
||||||
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
@@ -1131,6 +1158,12 @@ def test_flash_attn_varlen_output(
|
|||||||
qv_ref = None
|
qv_ref = None
|
||||||
# Put window_size after QKV randn so that window_size changes from test to test
|
# Put window_size after QKV randn so that window_size changes from test to test
|
||||||
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
|
window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
|
||||||
|
|
||||||
|
if has_sink:
|
||||||
|
sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)
|
||||||
|
else:
|
||||||
|
sinks = None
|
||||||
|
|
||||||
if dtype == torch.float8_e4m3fn:
|
if dtype == torch.float8_e4m3fn:
|
||||||
q_descale, k_descale, v_descale = [
|
q_descale, k_descale, v_descale = [
|
||||||
torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
|
torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
|
||||||
@@ -1209,6 +1242,7 @@ def test_flash_attn_varlen_output(
|
|||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
softcap=softcap,
|
softcap=softcap,
|
||||||
|
sinks=sinks,
|
||||||
)
|
)
|
||||||
out_pt, attn_pt = attention_ref(
|
out_pt, attn_pt = attention_ref(
|
||||||
q_ref,
|
q_ref,
|
||||||
@@ -1226,6 +1260,7 @@ def test_flash_attn_varlen_output(
|
|||||||
upcast=False,
|
upcast=False,
|
||||||
reorder_ops=True,
|
reorder_ops=True,
|
||||||
intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
|
intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
|
||||||
|
sinks=sinks,
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
|
||||||
@@ -1258,6 +1293,7 @@ def test_flash_attn_varlen_output(
|
|||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
softcap=softcap,
|
softcap=softcap,
|
||||||
return_softmax_lse=True,
|
return_softmax_lse=True,
|
||||||
|
sinks=sinks,
|
||||||
)
|
)
|
||||||
out = output_pad_fn(out_unpad)
|
out = output_pad_fn(out_unpad)
|
||||||
if query_unused_mask is not None:
|
if query_unused_mask is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user