diff --git a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu index a3c60ad5b..c71022fd1 100644 --- a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu +++ b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu @@ -173,9 +173,8 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch: TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim); cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - // Hard-code sm_count - int sm_count = 132; - constexpr int TOKENS_PER_CTA = 8; + const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + const int TOKENS_PER_CTA = 8; const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA); const bool use_vec16 = (hidden_dim % 16 == 0);