fix: only enable flash_attn test on sm80 sm90 (#7289)
@@ -8,9 +8,8 @@ from sgl_kernel.sparse_flash_attn import (
     convert_vertical_slash_indexes,
     convert_vertical_slash_indexes_mergehead,
     sparse_attn_func,
     sparse_attn_varlen_func,
 )
-from test_flash_attention import construct_local_mask
+from test_flash_attention import construct_local_mask, is_fa3_supported


 def ref_attn(
@@ -172,6 +171,10 @@ def ref_paged_attn(
     return torch.cat(outputs, dim=0)


+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 @pytest.mark.parametrize("batch_size", [1, 2])
 @pytest.mark.parametrize(
     "seq_lens",
@@ -257,6 +260,10 @@ def test_sparse_attention(

 # sparse attention utils
 # origin
+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 @pytest.mark.parametrize("causal", [True, False])
 def test_convert_vertical_slash_indexes(causal):
     # Prepare small, hand-checkable inputs
@@ -311,6 +318,10 @@ def test_convert_vertical_slash_indexes(causal):


 # mergehead
+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 @pytest.mark.parametrize("causal", [True, False])
 def test_convert_vertical_slash_indexes_mergehead(causal):
     # Prepare small, hand-checkable inputs for mergehead version
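
The gate itself, is_fa3_supported, is imported from test_flash_attention, and its body is not part of this diff. Below is a minimal sketch of such a check, assuming it only needs to test the CUDA compute capability major version (8 for sm80/Ampere, 9 for sm90/Hopper); the helper name is_fa3_supported_sketch and its body are hypothetical, not the actual implementation.

    import torch

    def is_fa3_supported_sketch(device=None) -> bool:
        # Hypothetical stand-in for test_flash_attention.is_fa3_supported:
        # the flash_attn path in sgl-kernel targets sm80 and sm90 only,
        # so report False everywhere else, including CPU-only machines.
        if not torch.cuda.is_available():
            return False
        major, _minor = torch.cuda.get_device_capability(device)
        return major in (8, 9)

Because the skipif condition is evaluated once when pytest collects the module, stacking it above the parametrize decorators skips every generated test case up front on unsupported hardware instead of failing inside the kernel launch.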