From f57d2dc162279bf976950f8b91cf86599f1dde09 Mon Sep 17 00:00:00 2001
From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Date: Mon, 4 Aug 2025 12:55:57 +0800
Subject: [PATCH] [sgl-kernel] avoid per_token_quant_fp8.cu hardcode sm_count
 (#8738)

---
 sgl-kernel/csrc/gemm/per_token_quant_fp8.cu | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
index a3c60ad5b..c71022fd1 100644
--- a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
+++ b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu
@@ -173,9 +173,8 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
   TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim);
 
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  // Hard-code sm_count
-  int sm_count = 132;
-  constexpr int TOKENS_PER_CTA = 8;
+  const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+  const int TOKENS_PER_CTA = 8;
   const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
   const bool use_vec16 = (hidden_dim % 16 == 0);
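
Note (not part of the patch): below is a minimal standalone sketch of the same idea, querying the SM count of the current device at runtime through the plain CUDA runtime API rather than hard-coding it. The helper name query_sm_count, the example num_tokens value, and the printf output are illustrative assumptions; TOKENS_PER_CTA = 8 and the dispatch heuristic (num_tokens >= sm_count * 2 * TOKENS_PER_CTA) are taken from the patched code, which obtains the same value via ATen's at::cuda::getCurrentDeviceProperties()->multiProcessorCount.

// sm_count_check.cu -- standalone sketch, not part of the patch.
#include <cuda_runtime.h>
#include <cstdio>

// Query the number of streaming multiprocessors on the current CUDA device.
int query_sm_count() {
  int device = 0;
  cudaGetDevice(&device);
  int sm_count = 0;
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device);
  return sm_count;
}

int main() {
  const int sm_count = query_sm_count();
  const int TOKENS_PER_CTA = 8;   // value taken from the patched code
  const int num_tokens = 4096;    // example workload size (illustrative only)

  // Same style of heuristic as the patch: use the warp-level kernel only
  // when there are enough tokens to keep every SM occupied.
  const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);

  std::printf("sm_count=%d use_warp_kernel=%d\n", sm_count, use_warp_kernel);
  return 0;
}

Built with, e.g., nvcc -o sm_count_check sm_count_check.cu, this prints the actual SM count of whatever GPU it runs on. The removed hard-coded value of 132 matches the H100 SXM part, so the old heuristic would have been miscalibrated on devices with a different SM count.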