reduce torch.zeros overhead in moe align block size kernel (#6369)
This commit is contained in:
@@ -30,6 +30,7 @@ from sglang.srt.utils import (
|
|||||||
is_cuda,
|
is_cuda,
|
||||||
is_hip,
|
is_hip,
|
||||||
log_info_on_rank0,
|
log_info_on_rank0,
|
||||||
|
next_power_of_2,
|
||||||
)
|
)
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
@@ -650,6 +651,61 @@ def moe_align_block_size_triton(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def init_sorted_ids_and_cumsum_buffer_kernel(
|
||||||
|
sorted_ids_ptr,
|
||||||
|
cumsum_buffer_ptr,
|
||||||
|
max_num_tokens_padded,
|
||||||
|
topk_ids_numel,
|
||||||
|
num_experts: tl.constexpr,
|
||||||
|
BLOCK_SIZE: tl.constexpr,
|
||||||
|
ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
|
||||||
|
):
|
||||||
|
pid = tl.program_id(0)
|
||||||
|
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
|
|
||||||
|
sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)
|
||||||
|
|
||||||
|
if pid < sorted_ids_blocks:
|
||||||
|
mask = offsets < max_num_tokens_padded
|
||||||
|
tl.store(
|
||||||
|
sorted_ids_ptr + offsets,
|
||||||
|
tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
|
||||||
|
mask=mask,
|
||||||
|
)
|
||||||
|
elif pid == sorted_ids_blocks:
|
||||||
|
offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
|
||||||
|
mask_e = offset_e < num_experts + 1
|
||||||
|
tl.store(
|
||||||
|
cumsum_buffer_ptr + offset_e,
|
||||||
|
tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
|
||||||
|
mask=mask_e,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def init_sorted_ids_and_cumsum_buffer(
|
||||||
|
max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
|
||||||
|
):
|
||||||
|
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
|
||||||
|
cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
BLOCK_SIZE = 1024
|
||||||
|
sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)
|
||||||
|
grid = (sorted_ids_blocks + 1,)
|
||||||
|
|
||||||
|
init_sorted_ids_and_cumsum_buffer_kernel[grid](
|
||||||
|
sorted_ids,
|
||||||
|
cumsum_buffer,
|
||||||
|
max_num_tokens_padded,
|
||||||
|
topk_ids_numel,
|
||||||
|
num_experts,
|
||||||
|
BLOCK_SIZE,
|
||||||
|
next_power_of_2(num_experts + 1),
|
||||||
|
)
|
||||||
|
|
||||||
|
return sorted_ids, cumsum_buffer
|
||||||
|
|
||||||
|
|
||||||
def moe_align_block_size(
|
def moe_align_block_size(
|
||||||
topk_ids: torch.Tensor, block_size: int, num_experts: int
|
topk_ids: torch.Tensor, block_size: int, num_experts: int
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
@@ -691,10 +747,9 @@ def moe_align_block_size(
|
|||||||
by block_size for proper block matrix operations.
|
by block_size for proper block matrix operations.
|
||||||
"""
|
"""
|
||||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||||
sorted_ids = torch.empty(
|
sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
|
||||||
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
|
max_num_tokens_padded, topk_ids.numel(), num_experts, topk_ids.device
|
||||||
)
|
)
|
||||||
sorted_ids.fill_(topk_ids.numel())
|
|
||||||
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
|
||||||
expert_ids = torch.empty(
|
expert_ids = torch.empty(
|
||||||
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
|
||||||
@@ -715,9 +770,6 @@ def moe_align_block_size(
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=topk_ids.device,
|
device=topk_ids.device,
|
||||||
)
|
)
|
||||||
cumsum_buffer = torch.empty(
|
|
||||||
num_experts + 1, dtype=torch.int32, device=topk_ids.device
|
|
||||||
)
|
|
||||||
|
|
||||||
sgl_moe_align_block_size(
|
sgl_moe_align_block_size(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
|
|||||||
@@ -197,8 +197,6 @@ void moe_align_block_size(
|
|||||||
size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
|
size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
|
||||||
size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
|
size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
|
||||||
|
|
||||||
cumsum_buffer.zero_();
|
|
||||||
|
|
||||||
align_kernel<<<1, threads, shared_mem_size, stream>>>(
|
align_kernel<<<1, threads, shared_mem_size, stream>>>(
|
||||||
topk_ids.data_ptr<scalar_t>(),
|
topk_ids.data_ptr<scalar_t>(),
|
||||||
sorted_token_ids.data_ptr<int32_t>(),
|
sorted_token_ids.data_ptr<int32_t>(),
|
||||||
|
|||||||
Reference in New Issue
Block a user