Reduce torch.zeros overhead in the moe_align_block_size kernel (#6369)

This commit is contained in:
Xiaoyu Zhang
2025-06-07 17:47:36 +08:00
committed by GitHub
parent 2a413829f4
commit 8b5f83ed3b
2 changed files with 58 additions and 8 deletions

View File

@@ -197,8 +197,6 @@ void moe_align_block_size(
size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
cumsum_buffer.zero_();
align_kernel<<<1, threads, shared_mem_size, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),