From 9798e72baac103a07553f2587477e92459cda3a2 Mon Sep 17 00:00:00 2001
From: yinfan98 <1106310035@qq.com>
Date: Tue, 8 Apr 2025 12:35:14 +0800
Subject: [PATCH] [Misc] Use pytest.mark.skipif in sgl-kernel test (#5137)

---
 sgl-kernel/README.md                  | 11 ++++++++++-
 sgl-kernel/tests/test_fp4_gemm.py     | 13 ++++++++-----
 sgl-kernel/tests/test_fp4_quantize.py | 16 +++++++++++-----
 3 files changed, 29 insertions(+), 11 deletions(-)

diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md
index 73a7a0756..2afcfc3ce 100644
--- a/sgl-kernel/README.md
+++ b/sgl-kernel/README.md
@@ -158,10 +158,19 @@ python -m uv build --wheel -Cbuild-dir=build --color=always .
 
 ### Testing & Benchmarking
 
-1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests)
+1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests). If you need to skip a test, use `@pytest.mark.skipif`:
+
+```python
+@pytest.mark.skipif(
+    skip_condition, reason="Nvfp4 requires compute capability of 10 or above."
+)
+```
+
 2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)
 3. Run test suite
+
+
 
 ### Release new version
 
 Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel/version.py)
diff --git a/sgl-kernel/tests/test_fp4_gemm.py b/sgl-kernel/tests/test_fp4_gemm.py
index 5c092bd13..47401618b 100644
--- a/sgl-kernel/tests/test_fp4_gemm.py
+++ b/sgl-kernel/tests/test_fp4_gemm.py
@@ -2,11 +2,7 @@ import pytest
 import torch
 from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
 
-if torch.cuda.get_device_capability() < (10, 0):
-    pytest.skip(
-        reason="Nvfp4 Requires compute capability of 10 or above.",
-        allow_module_level=True,
-    )
+skip_condition = torch.cuda.get_device_capability() < (10, 0)
 
 DTYPES = [torch.float16, torch.bfloat16]
 # m, n, k
@@ -108,6 +104,9 @@ def get_ref_results(
     return torch.matmul(a_in_dtype, b_in_dtype.t())
 
 
+@pytest.mark.skipif(
+    skip_condition, reason="Nvfp4 requires compute capability of 10 or above."
+)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("shape", SHAPES)
 @torch.inference_mode()
@@ -149,3 +148,7 @@ def test_nvfp4_gemm(
     )
 
     torch.testing.assert_close(out, expected_out.to(dtype=dtype), atol=1e-1, rtol=1e-1)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/sgl-kernel/tests/test_fp4_quantize.py b/sgl-kernel/tests/test_fp4_quantize.py
index 6b2489314..dcf09e053 100644
--- a/sgl-kernel/tests/test_fp4_quantize.py
+++ b/sgl-kernel/tests/test_fp4_quantize.py
@@ -2,11 +2,7 @@ import pytest
 import torch
 from sgl_kernel import scaled_fp4_quant
 
-if torch.cuda.get_device_capability() < (10, 0):
-    pytest.skip(
-        reason="Nvfp4 Requires compute capability of 10 or above.",
-        allow_module_level=True,
-    )
+skip_condition = torch.cuda.get_device_capability() < (10, 0)
 
 DTYPES = [torch.float16, torch.bfloat16]
 SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
@@ -115,6 +111,9 @@ def recover_swizzled_scales(scale, m, n):
     return result[:m, :scale_n]
 
 
+@pytest.mark.skipif(
+    skip_condition, reason="Nvfp4 requires compute capability of 10 or above."
+)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("shape", SHAPES)
 @torch.inference_mode()
@@ -140,6 +139,9 @@ def test_quantize_to_fp4(
     torch.testing.assert_close(scale_ans, scale_ref)
 
 
+@pytest.mark.skipif(
+    skip_condition, reason="Nvfp4 requires compute capability of 10 or above."
+)
 @pytest.mark.parametrize("pad_shape", PAD_SHAPES)
 @torch.inference_mode()
 def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
@@ -162,3 +164,7 @@ def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
 
     torch.testing.assert_close(out_ans, out_ref)
     torch.testing.assert_close(scale_ans, scale_ref)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
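
For context, the pattern this patch standardizes on can be exercised in a self-contained test module. Below is a minimal sketch, not part of the patch: `test_dummy_op` is a hypothetical test, and the `torch.cuda.is_available()` guard is an addition (the patch calls `torch.cuda.get_device_capability()` unconditionally, which raises if no CUDA device is present):

```python
import pytest
import torch

# Evaluate the condition once at import time. Unlike module-level
# pytest.skip(..., allow_module_level=True), the skipif marker skips
# each decorated test individually, so unrelated tests in the same
# module still run and still show up in the report as skipped.
skip_condition = not torch.cuda.is_available() or torch.cuda.get_device_capability() < (
    10,
    0,
)


@pytest.mark.skipif(
    skip_condition, reason="Nvfp4 requires compute capability of 10 or above."
)
def test_dummy_op() -> None:  # hypothetical test, not from the patch
    x = torch.randn(16, 16, device="cuda", dtype=torch.bfloat16)
    torch.testing.assert_close(x, x)


if __name__ == "__main__":
    pytest.main([__file__])
```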