diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
index a39d6d5d3..246606746 100644
--- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -752,14 +752,13 @@ def moe_align_block_size(
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
-    sorted_ids.fill_(topk_ids.numel())
-
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
     if enable_moe_align_block_size_triton:
+        sorted_ids.fill_(topk_ids.numel())
         moe_align_block_size_triton(
             topk_ids,
             num_experts,
@@ -778,6 +777,11 @@ def moe_align_block_size(
             device=topk_ids.device,
         )
 
+        # Threshold based on benchmark results
+        fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+        if not fuse_sorted_ids_padding:
+            sorted_ids.fill_(topk_ids.numel())
+
         sgl_moe_align_block_size(
             topk_ids,
             num_experts,
@@ -787,6 +791,7 @@ def moe_align_block_size(
             num_tokens_post_pad,
             token_cnts_buffer,
             cumsum_buffer,
+            fuse_sorted_ids_padding,
         )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
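
Note on the change: `sorted_ids` must be pre-filled with the sentinel value `topk_ids.numel()` so that the padding slots appended to each expert's block range hold an out-of-range token index that the downstream MoE GEMM masks out. Previously this fill ran as a separate kernel launch on every call. The patch fuses it into the `sgl_moe_align_block_size` CUDA kernel (via the new `fuse_sorted_ids_padding` argument) whenever `sorted_ids` has at most 4096 elements, where the benchmarks behind the threshold showed the fused write is cheaper than an extra launch; larger tensors keep the eager `fill_`, and the Triton path always fills eagerly before `moe_align_block_size_triton`.

The sketch below is a minimal eager-mode reference of the alignment semantics the kernels implement, included only to illustrate why the sentinel fill is required. `moe_align_block_size_reference` is a hypothetical name and the per-expert loop is an assumed layout, not sglang's actual kernel.

import torch


def moe_align_block_size_reference(
    topk_ids: torch.Tensor, num_experts: int, block_size: int
):
    num_tokens = topk_ids.numel()
    flat = topk_ids.flatten()
    # Worst case: every expert's token count is rounded up to a full block.
    max_num_tokens_padded = num_tokens + num_experts * (block_size - 1)
    max_num_m_blocks = -(-max_num_tokens_padded // block_size)  # ceil division
    # Sentinel fill: unused slots hold num_tokens, an out-of-range token index
    # treated as padding downstream. This is the write the patch fuses into
    # the CUDA kernel for small tensors.
    sorted_ids = torch.full((max_num_tokens_padded,), num_tokens, dtype=torch.int32)
    expert_ids = torch.zeros((max_num_m_blocks,), dtype=torch.int32)
    pos = 0
    for expert in range(num_experts):
        token_idx = (flat == expert).nonzero(as_tuple=True)[0].to(torch.int32)
        padded = -(-token_idx.numel() // block_size) * block_size
        sorted_ids[pos : pos + token_idx.numel()] = token_idx
        expert_ids[pos // block_size : (pos + padded) // block_size] = expert
        pos += padded  # the gap up to pos keeps the sentinel value
    num_tokens_post_pad = torch.tensor([pos], dtype=torch.int32)
    return sorted_ids, expert_ids, num_tokens_post_pad

For example, topk_ids = [[0, 1], [1, 2]] with num_experts=3 and block_size=4 yields sorted_ids = [0, 4, 4, 4, 1, 2, 4, 4, 3, 4, 4, 4, 4], where every 4 marks a padding slot. Skipping the upfront fill is only safe because the fused kernel writes the same sentinel into exactly those slots.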