fix: only enable flash_attn test on sm80 sm90 (#7289)
This commit is contained in:
@@ -13,7 +13,7 @@ apply_rotary_emb = None
|
|||||||
|
|
||||||
def is_hopper():
|
def is_hopper():
|
||||||
# Only Hopper supports different V headdim
|
# Only Hopper supports different V headdim
|
||||||
return torch.cuda.get_device_properties(0).major >= 9
|
return torch.cuda.get_device_properties(0).major == 9
|
||||||
|
|
||||||
|
|
||||||
def is_fa3_supported(device=None) -> bool:
|
def is_fa3_supported(device=None) -> bool:
|
||||||
@@ -451,7 +451,7 @@ def generate_qkv(
|
|||||||
|
|
||||||
@pytest.mark.skipif(
|
@pytest.mark.skipif(
|
||||||
not is_fa3_supported(),
|
not is_fa3_supported(),
|
||||||
reason="flash_attn at sgl-kernel is only supported on sm90 and above",
|
reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
|
||||||
)
|
)
|
||||||
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
|
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -1009,6 +1009,10 @@ def _generate_block_kvcache(
|
|||||||
return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
|
return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not is_fa3_supported(),
|
||||||
|
reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
|
||||||
|
)
|
||||||
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
|
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
|
"dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
|
||||||
|
|||||||
@@ -8,9 +8,8 @@ from sgl_kernel.sparse_flash_attn import (
|
|||||||
convert_vertical_slash_indexes,
|
convert_vertical_slash_indexes,
|
||||||
convert_vertical_slash_indexes_mergehead,
|
convert_vertical_slash_indexes_mergehead,
|
||||||
sparse_attn_func,
|
sparse_attn_func,
|
||||||
sparse_attn_varlen_func,
|
|
||||||
)
|
)
|
||||||
from test_flash_attention import construct_local_mask
|
from test_flash_attention import construct_local_mask, is_fa3_supported
|
||||||
|
|
||||||
|
|
||||||
def ref_attn(
|
def ref_attn(
|
||||||
@@ -172,6 +171,10 @@ def ref_paged_attn(
|
|||||||
return torch.cat(outputs, dim=0)
|
return torch.cat(outputs, dim=0)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not is_fa3_supported(),
|
||||||
|
reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("batch_size", [1, 2])
|
@pytest.mark.parametrize("batch_size", [1, 2])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"seq_lens",
|
"seq_lens",
|
||||||
@@ -257,6 +260,10 @@ def test_sparse_attention(
|
|||||||
|
|
||||||
# sparse attention utils
|
# sparse attention utils
|
||||||
# origin
|
# origin
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not is_fa3_supported(),
|
||||||
|
reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("causal", [True, False])
|
@pytest.mark.parametrize("causal", [True, False])
|
||||||
def test_convert_vertical_slash_indexes(causal):
|
def test_convert_vertical_slash_indexes(causal):
|
||||||
# Prepare small, hand-checkable inputs
|
# Prepare small, hand-checkable inputs
|
||||||
@@ -311,6 +318,10 @@ def test_convert_vertical_slash_indexes(causal):
|
|||||||
|
|
||||||
|
|
||||||
# mergehead
|
# mergehead
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not is_fa3_supported(),
|
||||||
|
reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
|
||||||
|
)
|
||||||
@pytest.mark.parametrize("causal", [True, False])
|
@pytest.mark.parametrize("causal", [True, False])
|
||||||
def test_convert_vertical_slash_indexes_mergehead(causal):
|
def test_convert_vertical_slash_indexes_mergehead(causal):
|
||||||
# Prepare small, hand-checkable inputs for mergehead version
|
# Prepare small, hand-checkable inputs for mergehead version
|
||||||
|
|||||||
Reference in New Issue
Block a user