[sgl-kernel] avoid per_token_quant_fp8.cu hardcode sm_count (#8738)

This commit is contained in:
Xiaoyu Zhang
2025-08-04 12:55:57 +08:00
committed by GitHub
parent f2d68ded6d
commit f57d2dc162

View File

@@ -173,9 +173,8 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
 TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim);
 cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-// Hard-code sm_count
-int sm_count = 132;
-constexpr int TOKENS_PER_CTA = 8;
+const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+const int TOKENS_PER_CTA = 8;
 const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
 const bool use_vec16 = (hidden_dim % 16 == 0);