From a7d825fccc378e5876bac48b462035a7fedf667e Mon Sep 17 00:00:00 2001
From: hlu1 <14827759+hlu1@users.noreply.github.com>
Date: Thu, 28 Aug 2025 20:00:32 -0700
Subject: [PATCH] Skip some tests on Blackwell (#9777)

Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
---
 sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py | 9 +++++++++
 sgl-kernel/tests/test_int8_gemm.py | 5 +++++
 sgl-kernel/tests/utils.py | 9 +++++++++
 3 files changed, 23 insertions(+)
 create mode 100644 sgl-kernel/tests/utils.py

diff --git a/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py b/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py
index 3cdd62edd..506f8301a 100644
--- a/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py
+++ b/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py
@@ -1,6 +1,7 @@
 import pytest
 import torch
 from sgl_kernel import cutlass_w4a8_moe_mm
+from utils import is_hopper
 
 
 def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
@@ -38,6 +39,10 @@ def pack_interleave(num_experts, ref_weight, ref_scale):
     return w_q, w_scale
 
 
+@pytest.mark.skipif(
+    not is_hopper(),
+    reason="cutlass_w4a8_moe_mm is only supported on sm90",
+)
 @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16])
 def test_int4_fp8_grouped_gemm_single_expert(batch_size):
     # Test parameters
@@ -127,6 +132,10 @@ def test_int4_fp8_grouped_gemm_single_expert(batch_size):
         raise
 
 
+@pytest.mark.skipif(
+    not is_hopper(),
+    reason="cutlass_w4a8_moe_mm is only supported on sm90",
+)
 @pytest.mark.parametrize("batch_size", [2, 4, 8, 16])
 @pytest.mark.parametrize("k", [512, 1024])
 @pytest.mark.parametrize("n", [1024, 2048])
diff --git a/sgl-kernel/tests/test_int8_gemm.py b/sgl-kernel/tests/test_int8_gemm.py
index 4d506faed..80f32cd02 100644
--- a/sgl-kernel/tests/test_int8_gemm.py
+++ b/sgl-kernel/tests/test_int8_gemm.py
@@ -1,6 +1,7 @@
 import pytest
 import torch
 from sgl_kernel import int8_scaled_mm
+from utils import is_sm10x
 
 
 def to_int8(tensor: torch.Tensor) -> torch.Tensor:
@@ -30,6 +31,10 @@ def _test_accuracy_once(M, N, K, with_bias, out_dtype, device):
     torch.testing.assert_close(o, o1)
 
 
+@pytest.mark.skipif(
+    is_sm10x(),
+    reason="int8_scaled_mm is only supported on sm90 and lower",
+)
 @pytest.mark.parametrize("M", [1, 16, 32, 64, 128, 512, 1024, 4096, 8192])
 @pytest.mark.parametrize("N", [16, 128, 512, 1024, 4096, 8192, 16384])
 @pytest.mark.parametrize("K", [512, 1024, 4096, 8192, 16384])
diff --git a/sgl-kernel/tests/utils.py b/sgl-kernel/tests/utils.py
new file mode 100644
index 000000000..8fa9a2234
--- /dev/null
+++ b/sgl-kernel/tests/utils.py
@@ -0,0 +1,9 @@
+import torch
+
+
+def is_sm10x():
+    return torch.cuda.get_device_capability() >= (10, 0)
+
+
+def is_hopper():
+    return torch.cuda.get_device_capability() == (9, 0)