Reduce torch.zeros overhead in the moe_align_block_size kernel (#6369)
This commit is contained in:
@@ -197,8 +197,6 @@ void moe_align_block_size(
|
||||
size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
|
||||
size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
|
||||
|
||||
cumsum_buffer.zero_();
|
||||
|
||||
align_kernel<<<1, threads, shared_mem_size, stream>>>(
|
||||
topk_ids.data_ptr<scalar_t>(),
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
|
||||
Reference in New Issue
Block a user