optimize moe_align_kernel cuda (#3347)

This commit is contained in:
Xiaoyu Zhang
2025-02-07 00:53:46 +08:00
committed by GitHub
parent adeee15204
commit cdae77b03d
3 changed files with 29 additions and 21 deletions

View File

@@ -417,12 +417,12 @@ def moe_align_block_size(
num_tokens_post_pad,
)
else:
token_cnts_buffer = torch.empty(
token_cnts_buffer = torch.zeros(
(num_experts + 1) * num_experts,
dtype=torch.int32,
device=topk_ids.device,
)
cumsum_buffer = torch.empty(
cumsum_buffer = torch.zeros(
num_experts + 1, dtype=torch.int32, device=topk_ids.device
)