[Fix] fix fa3 build at cu118 (#5036)

Author: yinfan98
Date: 2025-04-04 02:52:35 +08:00
Committed by: GitHub
Parent: 8e10fec9a8
Commit: b8b6008f47

8 changed files with 288 additions and 142 deletions


@@ -10,7 +10,19 @@ from einops import rearrange, repeat
apply_rotary_emb = None
from sgl_kernel.flash_attn import flash_attn_with_kvcache
def is_fa3_supported(device=None) -> bool:
    # FA3 can fail for some shapes when there is not enough shared memory;
    # currently only SM 8.0 and SM 8.7 have enough shared memory for all shapes.
    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
    # For now, sgl-kernel only builds FA3 for sm90a with CUDA >= 12.4.
    # Note: torch.version.cuda is a version string, so this is a lexicographic comparison.
    return (
        (torch.cuda.get_device_capability(device)[0] >= 9)
        and (torch.version.cuda >= "12.4")
        # or torch.cuda.get_device_capability(device) == (8, 0)
        # or torch.cuda.get_device_capability(device) == (8, 7)
    )
DISABLE_BACKWARD = True
# For CI tests, we set these flags to True.
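
Side note, not part of this commit: the version check above compares CUDA versions as strings, which happens to order "11.8" before "12.4" but would misorder a hypothetical "12.10". A more robust gate parses the version into an integer tuple; a minimal sketch using only public torch APIs (the helper name _cuda_at_least is invented here):

import torch

def _cuda_at_least(major: int, minor: int) -> bool:
    # torch.version.cuda is None on CPU-only builds, else a string like "12.4".
    if torch.version.cuda is None:
        return False
    ver = tuple(int(x) for x in torch.version.cuda.split(".")[:2])
    return ver >= (major, minor)

Calling _cuda_at_least(12, 4) in place of the string comparison stays correct for any future CUDA minor version.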
@@ -284,6 +296,10 @@ def attention_ref(
return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn in sgl-kernel is only supported on sm90 and above",
)
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize(
    "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
)
@@ -372,6 +388,8 @@ def test_flash_attn_kvcache(
    mha_type,
    dtype,
):
    from sgl_kernel.flash_attn import flash_attn_with_kvcache

    if page_size is not None and seqlen_k % page_size != 0:
        pytest.skip("seqlen_k must be divisible by page_size")
    if seqlen_q > seqlen_k and new_kv: