Apply fused sorted token ids padding (#8193)
This commit is contained in:
@@ -752,14 +752,13 @@ def moe_align_block_size(
|
|||||||
sorted_ids = torch.empty(
|
sorted_ids = torch.empty(
|
||||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
sorted_ids.fill_(topk_ids.numel())
|
|
||||||
|
|
||||||
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
||||||
expert_ids = torch.empty(
|
expert_ids = torch.empty(
|
||||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||||
)
|
)
|
||||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||||
if enable_moe_align_block_size_triton:
|
if enable_moe_align_block_size_triton:
|
||||||
|
sorted_ids.fill_(topk_ids.numel())
|
||||||
moe_align_block_size_triton(
|
moe_align_block_size_triton(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
@@ -778,6 +777,11 @@ def moe_align_block_size(
|
|||||||
device=topk_ids.device,
|
device=topk_ids.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Threshold based on benchmark results
|
||||||
|
fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
|
||||||
|
if not fuse_sorted_ids_padding:
|
||||||
|
sorted_ids.fill_(topk_ids.numel())
|
||||||
|
|
||||||
sgl_moe_align_block_size(
|
sgl_moe_align_block_size(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
@@ -787,6 +791,7 @@ def moe_align_block_size(
|
|||||||
num_tokens_post_pad,
|
num_tokens_post_pad,
|
||||||
token_cnts_buffer,
|
token_cnts_buffer,
|
||||||
cumsum_buffer,
|
cumsum_buffer,
|
||||||
|
fuse_sorted_ids_padding,
|
||||||
)
|
)
|
||||||
return sorted_ids, expert_ids, num_tokens_post_pad
|
return sorted_ids, expert_ids, num_tokens_post_pad
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user