reduce torch.zeros overhead in moe align block size kernel (#6369)
@@ -30,6 +30,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     log_info_on_rank0,
+    next_power_of_2,
 )
 
 _is_hip = is_hip()
@@ -650,6 +651,61 @@ def moe_align_block_size_triton(
     )
 
 
+@triton.jit
+def init_sorted_ids_and_cumsum_buffer_kernel(
+    sorted_ids_ptr,
+    cumsum_buffer_ptr,
+    max_num_tokens_padded,
+    topk_ids_numel,
+    num_experts: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+
+    sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)
+
+    if pid < sorted_ids_blocks:
+        mask = offsets < max_num_tokens_padded
+        tl.store(
+            sorted_ids_ptr + offsets,
+            tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
+            mask=mask,
+        )
+    elif pid == sorted_ids_blocks:
+        offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
+        mask_e = offset_e < num_experts + 1
+        tl.store(
+            cumsum_buffer_ptr + offset_e,
+            tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
+            mask=mask_e,
+        )
+
+
+def init_sorted_ids_and_cumsum_buffer(
+    max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
+):
+    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
+    cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
+
+    BLOCK_SIZE = 1024
+    sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)
+    grid = (sorted_ids_blocks + 1,)
+
+    init_sorted_ids_and_cumsum_buffer_kernel[grid](
+        sorted_ids,
+        cumsum_buffer,
+        max_num_tokens_padded,
+        topk_ids_numel,
+        num_experts,
+        BLOCK_SIZE,
+        next_power_of_2(num_experts + 1),
+    )
+
+    return sorted_ids, cumsum_buffer
+
+
 def moe_align_block_size(
     topk_ids: torch.Tensor, block_size: int, num_experts: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
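
Note on the launch geometry above (not part of the commit; the sizes below are illustrative): the grid has sorted_ids_blocks + 1 programs, where the first sorted_ids_blocks programs fill sorted_ids with the topk_ids_numel sentinel in BLOCK_SIZE chunks and the one extra program zeroes the (num_experts + 1)-entry cumsum buffer, so both buffers are initialized in a single launch instead of a separate fill_ and torch.zeros.

# Illustrative sketch of the grid sizing (values assumed, not from the commit):
import triton
max_num_tokens_padded, BLOCK_SIZE = 4096, 1024
sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)  # 4 programs fill sorted_ids
grid = (sorted_ids_blocks + 1,)  # (5,): the last program zeroes the cumsum buffer
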
@@ -691,10 +747,9 @@ def moe_align_block_size(
         by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty(
-        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
+        max_num_tokens_padded, topk_ids.numel(), num_experts, topk_ids.device
     )
-    sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -715,9 +770,6 @@ def moe_align_block_size(
         dtype=torch.int32,
         device=topk_ids.device,
     )
-    cumsum_buffer = torch.empty(
-        num_experts + 1, dtype=torch.int32, device=topk_ids.device
-    )
 
     sgl_moe_align_block_size(
         topk_ids,
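
A minimal sketch of how the new helper can be sanity-checked against the eager initialization it replaces (not part of the commit; assumes the functions above are in scope, a CUDA device, and illustrative sizes):

import torch

max_num_tokens_padded, topk_ids_numel, num_experts = 4096, 2048, 8

# Fused path: one kernel launch initializes both buffers.
sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
    max_num_tokens_padded, topk_ids_numel, num_experts, device="cuda"
)

# Eager reference: the separate torch ops the kernel replaces.
ref_sorted_ids = torch.full(
    (max_num_tokens_padded,), topk_ids_numel, dtype=torch.int32, device="cuda"
)
ref_cumsum_buffer = torch.zeros((num_experts + 1,), dtype=torch.int32, device="cuda")

assert torch.equal(sorted_ids, ref_sorted_ids)
assert torch.equal(cumsum_buffer, ref_cumsum_buffer)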