diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index f7690cb86..b0f8d57ee 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -750,9 +750,11 @@ def moe_align_block_size(
         by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
-        max_num_tokens_padded, topk_ids.numel(), num_experts, topk_ids.device
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
+    sorted_ids.fill_(topk_ids.numel())
+
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -768,6 +770,9 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
+        cumsum_buffer = torch.empty(
+            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+        )
         token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
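
For context, here is a minimal standalone sketch of the allocation pattern this diff introduces. The buffer names and sizing expressions (`sorted_ids`, `cumsum_buffer`, `max_num_tokens_padded`, etc.) come from the diff itself; the concrete values for `num_experts`, `block_size`, and `topk_ids` are hypothetical, chosen only to make the snippet runnable:

```python
# Sketch of the buffer allocations in moe_align_block_size after this change.
# Assumed example values; the real caller supplies these.
import torch

num_experts = 8
block_size = 16
# 4 tokens, top-2 routing: one expert id per (token, slot) pair.
topk_ids = torch.randint(0, num_experts, (4, 2), dtype=torch.int32)

# Worst case: every expert's token count gets padded up to a block boundary.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)

sorted_ids = torch.empty(
    (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
)
# topk_ids.numel() is one past the largest valid flattened token index,
# so pre-filling with it marks every slot as padding until overwritten.
sorted_ids.fill_(topk_ids.numel())

# Equivalent to triton.cdiv(max_num_tokens_padded, block_size).
max_num_m_blocks = (max_num_tokens_padded + block_size - 1) // block_size
expert_ids = torch.empty(
    (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)

# One prefix-sum slot per expert, plus one leading slot.
cumsum_buffer = torch.empty(
    (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
)
```

Two points are visible in the diff itself: the `init_sorted_ids_and_cumsum_buffer` helper is replaced with direct `torch.empty` allocations (with an explicit `fill_` for the padding sentinel), and `cumsum_buffer` moves into the `else:` branch, so it is only allocated on the code path that actually uses it.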