diff --git a/sgl-kernel/tests/test_flash_attention.py b/sgl-kernel/tests/test_flash_attention.py
index def092a34..0c7e854b9 100644
--- a/sgl-kernel/tests/test_flash_attention.py
+++ b/sgl-kernel/tests/test_flash_attention.py
@@ -13,7 +13,7 @@ apply_rotary_emb = None
 
 def is_hopper():
     # Only Hopper supports different V headdim
-    return torch.cuda.get_device_properties(0).major >= 9
+    return torch.cuda.get_device_properties(0).major == 9
 
 
 def is_fa3_supported(device=None) -> bool:
@@ -451,7 +451,7 @@ def generate_qkv(
 
 @pytest.mark.skipif(
     not is_fa3_supported(),
-    reason="flash_attn at sgl-kernel is only supported on sm90 and above",
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
 )
 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
 @pytest.mark.parametrize(
@@ -1009,6 +1009,10 @@ def _generate_block_kvcache(
     return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
 
 
+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
 @pytest.mark.parametrize(
     "dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])
diff --git a/sgl-kernel/tests/test_sparse_flash_attn.py b/sgl-kernel/tests/test_sparse_flash_attn.py
index 4aa0a7c19..28c64cb61 100644
--- a/sgl-kernel/tests/test_sparse_flash_attn.py
+++ b/sgl-kernel/tests/test_sparse_flash_attn.py
@@ -8,9 +8,8 @@ from sgl_kernel.sparse_flash_attn import (
     convert_vertical_slash_indexes,
     convert_vertical_slash_indexes_mergehead,
     sparse_attn_func,
-    sparse_attn_varlen_func,
 )
-from test_flash_attention import construct_local_mask
+from test_flash_attention import construct_local_mask, is_fa3_supported
 
 
 def ref_attn(
@@ -172,6 +171,10 @@ def ref_paged_attn(
     return torch.cat(outputs, dim=0)
 
 
+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 @pytest.mark.parametrize("batch_size", [1, 2])
 @pytest.mark.parametrize(
     "seq_lens",
@@ -257,6 +260,10 @@ def test_sparse_attention(
 
 # sparse attention utils
 # origin
+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 @pytest.mark.parametrize("causal", [True, False])
 def test_convert_vertical_slash_indexes(causal):
     # Prepare small, hand-checkable inputs
@@ -311,6 +318,10 @@ def test_convert_vertical_slash_indexes(causal):
 
 
 # mergehead
+@pytest.mark.skipif(
+    not is_fa3_supported(),
+    reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
+)
 @pytest.mark.parametrize("causal", [True, False])
 def test_convert_vertical_slash_indexes_mergehead(causal):
     # Prepare small, hand-checkable inputs for mergehead version