From 4c22ebe2e8ec185db939c339ec3e7e884dc77a45 Mon Sep 17 00:00:00 2001 From: hlu1 <14827759+hlu1@users.noreply.github.com> Date: Sat, 6 Sep 2025 01:35:18 -0700 Subject: [PATCH] Disable kernel cutlass_mla_decode on SM103 (#10058) Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com> --- sgl-kernel/csrc/attention/cutlass_mla_kernel.cu | 5 +++++ sgl-kernel/tests/test_cutlass_mla.py | 5 +++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu b/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu index 6f4d46577..a41779c1b 100644 --- a/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu +++ b/sgl-kernel/csrc/attention/cutlass_mla_kernel.cu @@ -26,6 +26,7 @@ limitations under the License. #include "cutlass_sm100_mla/device/sm100_mla.hpp" #include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp" +#include "utils.h" // clang-format off #if !defined(CUDA_VERSION) || CUDA_VERSION < 12040 @@ -217,6 +218,10 @@ void cutlass_mla_decode( torch::Tensor const& workspace, double sm_scale, int64_t num_kv_splits) { + auto sm_version = getSMVersion(); + // On SM103a, half of the accuracy tests are failing. + TORCH_CHECK(sm_version == 100, "cutlass_mla_decode is only supported on compute capability 10.0, but found sm version ", sm_version); + auto in_dtype = q_nope.dtype(); at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device()); diff --git a/sgl-kernel/tests/test_cutlass_mla.py b/sgl-kernel/tests/test_cutlass_mla.py index 0f1829b5d..71de8327a 100644 --- a/sgl-kernel/tests/test_cutlass_mla.py +++ b/sgl-kernel/tests/test_cutlass_mla.py @@ -4,9 +4,10 @@ import torch.nn.functional as F from sgl_kernel import cutlass_mla_decode, cutlass_mla_get_workspace_size from torch import Tensor -if torch.cuda.get_device_capability() < (10, 0): +# Disable tests on SM103 until the accuracy issues are fixed. +if torch.cuda.get_device_capability() != (10, 0): pytest.skip( - reason="Cutlass MLA Requires compute capability of 10 or above.", + reason="Cutlass MLA Requires compute capability of 10.", allow_module_level=True, )