[sgl-kernel] avoid per_token_quant_fp8.cu hardcode sm_count (#8738)
This commit is contained in:
@@ -173,9 +173,8 @@ void sgl_per_token_quant_fp8(torch::Tensor input, torch::Tensor output_q, torch:
|
||||
TORCH_CHECK(hidden_dim % 8 == 0, "Hidden dimension must be divisible by 8, but got ", hidden_dim);
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
// Hard-code sm_count
|
||||
int sm_count = 132;
|
||||
constexpr int TOKENS_PER_CTA = 8;
|
||||
const int sm_count = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
|
||||
const int TOKENS_PER_CTA = 8;
|
||||
const bool use_warp_kernel = (num_tokens >= sm_count * 2 * TOKENS_PER_CTA);
|
||||
const bool use_vec16 = (hidden_dim % 16 == 0);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user