Apply fused sorted token ids padding (#8193)

Ke Bao
2025-07-21 11:19:48 +08:00
committed by GitHub
parent 429bb0efa2
commit c9e8613c97


@@ -752,14 +752,13 @@ def moe_align_block_size(
     sorted_ids = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
     )
-    sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
     if enable_moe_align_block_size_triton:
+        sorted_ids.fill_(topk_ids.numel())
         moe_align_block_size_triton(
             topk_ids,
             num_experts,
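
In the Triton path the alignment kernel leaves unused slots untouched, so sorted_ids is pre-filled with topk_ids.numel() as a padding sentinel before the launch; valid entries are flattened indices into topk_ids, so they fall in [0, topk_ids.numel()) and can never collide with the sentinel. A minimal sketch of how padded slots can be told apart downstream (the names pad_sentinel and is_padding are illustrative, not part of this diff):

pad_sentinel = topk_ids.numel()          # one past the largest valid flattened index
is_padding = sorted_ids == pad_sentinel  # True for slots the alignment left as padding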
@@ -778,6 +777,11 @@ def moe_align_block_size(
             device=topk_ids.device,
         )
+        # Threshold based on benchmark results
+        fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+        if not fuse_sorted_ids_padding:
+            sorted_ids.fill_(topk_ids.numel())
         sgl_moe_align_block_size(
             topk_ids,
             num_experts,
@@ -787,6 +791,7 @@ def moe_align_block_size(
             num_tokens_post_pad,
             token_cnts_buffer,
             cumsum_buffer,
+            fuse_sorted_ids_padding,
         )
     return sorted_ids, expert_ids, num_tokens_post_pad
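
With fuse_sorted_ids_padding enabled, the CUDA kernel writes the padding sentinel itself, saving the separate fill_ launch; per the comment in the diff, benchmarks showed the fused write only pays off for buffers of at most 4096 entries. Below is a standalone sketch of just the dispatch decision, assuming the usual max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) sizing; pick_padding_strategy is a hypothetical name, not part of this commit:

import torch

def pick_padding_strategy(sorted_ids: torch.Tensor, topk_ids: torch.Tensor) -> bool:
    # Mirrors the threshold above: small buffers fuse the sentinel write into
    # the alignment kernel, larger ones get a dedicated fill_ launch.
    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
    if not fuse_sorted_ids_padding:
        sorted_ids.fill_(topk_ids.numel())
    return fuse_sorted_ids_padding

# Example: 32 tokens with top-2 routing, 8 experts, block_size 64 gives
# max_num_tokens_padded = 64 + 8 * 63 = 568 <= 4096, so padding is fused.
topk_ids = torch.zeros((32, 2), dtype=torch.int32)
sorted_ids = torch.empty((568,), dtype=torch.int32)
assert pick_padding_strategy(sorted_ids, topk_ids)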