[Fix] fix fa3 build at cu118 (#5036)
@@ -10,7 +10,19 @@ from einops import rearrange, repeat

apply_rotary_emb = None

from sgl_kernel.flash_attn import flash_attn_with_kvcache


def is_fa3_supported(device=None) -> bool:
    # FA3 can fail when there is not enough shared memory for some shapes; currently
    # only 8.0 and 8.7 have enough shared memory for all shapes.
    # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
    # For now, sgl-kernel only builds FA3 for sm90a && cuda >= 12.4.
    return (
        (torch.cuda.get_device_capability(device)[0] >= 9)
        and (torch.version.cuda >= "12.4")
        # or torch.cuda.get_device_capability(device) == (8, 0)
        # or torch.cuda.get_device_capability(device) == (8, 7)
    )


DISABLE_BACKWARD = True
# For the CI test, we set these to True.
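As a quick sanity check against the gate above, here is a standalone sketch (not part of this commit) that reports what is_fa3_supported() would decide on the local machine; report_fa3_gate is a made-up name and the logic simply mirrors the capability and CUDA-version checks in the diff.

import torch


# Illustrative only: mirrors the is_fa3_supported() gate, assuming a CUDA build of PyTorch.
def report_fa3_gate() -> None:
    major, minor = torch.cuda.get_device_capability()
    cuda_version = torch.version.cuda  # a string such as "12.4" on CUDA builds
    supported = major >= 9 and cuda_version >= "12.4"
    print(f"sm{major}{minor}, CUDA {cuda_version} -> FA3 tests enabled: {supported}")


if __name__ == "__main__":
    if torch.cuda.is_available():
        report_fa3_gate()
    else:
        print("no CUDA device visible; the FA3 tests would be skipped")

Note that both the gate and this sketch compare torch.version.cuda as a string, which orders "12.4" through "12.9" correctly but would misorder a hypothetical "12.10".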
@@ -284,6 +296,10 @@ def attention_ref(

    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn at sgl-kernel is only supported on sm90 and above",
)
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize(
    "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
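The parametrize decorator above builds its dtype list conditionally, so fp8 cases disappear from the matrix when they are disabled for CI. A minimal sketch of the same pattern, assuming DISABLE_FP8 = True and a PyTorch recent enough to define torch.float8_e4m3fn:

import pytest
import torch

# Sketch only: DISABLE_FP8 is assumed here; in the test file it is a
# module-level CI flag like DISABLE_BACKWARD above.
DISABLE_FP8 = True

# Same pattern as the decorator above: the fp8 dtype is appended only when fp8
# testing is enabled, so the parametrize list is never empty.
DTYPES = [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])


@pytest.mark.parametrize("dtype", DTYPES)
def test_dtype_matrix_demo(dtype):
    assert dtype in DTYPES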
@@ -372,6 +388,8 @@ def test_flash_attn_kvcache(
    mha_type,
    dtype,
):
    from sgl_kernel.flash_attn import flash_attn_with_kvcache

    if page_size is not None and seqlen_k % page_size != 0:
        pytest.skip()
    if seqlen_q > seqlen_k and new_kv:
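The first skip in the hunk above reflects the paged-KV layout used by flash_attn_with_kvcache: with a paged cache the key/value sequence is stored in fixed-size pages, so seqlen_k has to split evenly into page_size chunks, and the test skips combinations where it does not. A small sketch; num_kv_pages is a made-up helper, not part of the test file.

# Sketch of the divisibility assumption behind the pytest.skip() above.
def num_kv_pages(seqlen_k: int, page_size: int) -> int:
    if seqlen_k % page_size != 0:
        raise ValueError("seqlen_k must be a multiple of page_size")
    return seqlen_k // page_size


assert num_kv_pages(512, 256) == 2  # valid combination, the test runs
# num_kv_pages(500, 256) would raise, mirroring the skipped parameter combinations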