fix: only enable flash_attn test on sm80 sm90 (#7289)

This commit is contained in:
Yineng Zhang
2025-06-17 16:56:41 -07:00
committed by GitHub
parent fc554105f6
commit 0650e5176f
2 changed files with 19 additions and 4 deletions

View File

@@ -13,7 +13,7 @@ apply_rotary_emb = None
def is_hopper():
# Only Hopper supports different V headdim
return torch.cuda.get_device_properties(0).major >= 9
return torch.cuda.get_device_properties(0).major == 9
def is_fa3_supported(device=None) -> bool:
@@ -451,7 +451,7 @@ def generate_qkv(
@pytest.mark.skipif(
not is_fa3_supported(),
reason="flash_attn at sgl-kernel is only supported on sm90 and above",
reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize(
@@ -1009,6 +1009,10 @@ def _generate_block_kvcache(
return k_cache, v_cache, page_table, k_cache_paged, v_cache_paged, num_blocks
@pytest.mark.skipif(
not is_fa3_supported(),
reason="flash_attn at sgl-kernel is only supported on sm90 or sm80",
)
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@pytest.mark.parametrize(
"dtype", [torch.bfloat16] + ([torch.float8_e4m3fn] if not DISABLE_FP8 else [])