From ba5112ff691d791a9e38c6c71f59324a5fcb49d0 Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 2 Jan 2025 21:47:44 +0800 Subject: [PATCH] feat: support moe_align_block_size_triton (#2712) Co-authored-by: WANDY666 <1060304770@qq.com> --- .../layers/moe/fused_moe_triton/fused_moe.py | 202 +++++++++++++++--- 1 file changed, 177 insertions(+), 25 deletions(-) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 2a8080dd3..128881e8a 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -17,15 +17,21 @@ from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip -not_hip = False +is_hip_flag = False if not is_hip(): from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size - not_hip = True + is_hip_flag = False +else: + is_hip_flag = True logger = logging.getLogger(__name__) padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 +enable_moe_align_block_size_triton = bool( + int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0")) +) + @triton.jit def fused_moe_kernel( @@ -222,6 +228,139 @@ def fused_moe_kernel( tl.store(c_ptrs, accumulator, mask=c_mask) +def ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def moe_align_block_size_stage1( + topk_ids_ptr, + tokens_cnts_ptr, + num_experts: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + + start_idx = pid * tokens_per_thread + + off_c = (pid + 1) * num_experts + + for i in range(tokens_per_thread): + if start_idx + i < numel: + idx = tl.load(topk_ids_ptr + start_idx + i) + token_cnt = tl.load(tokens_cnts_ptr + off_c + idx) + tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1) + + +@triton.jit +def moe_align_block_size_stage2( + tokens_cnts_ptr, + num_experts: tl.constexpr, +): + pid = tl.program_id(0) + + last_cnt = 0 + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid) + last_cnt = last_cnt + token_cnt + tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt) + + +@triton.jit +def moe_align_block_size_stage3( + total_tokens_post_pad_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, +): + last_cumsum = 0 + off_cnt = num_experts * num_experts + for i in range(1, num_experts + 1): + token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1) + last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size + tl.store(cumsum_ptr + i, last_cumsum) + tl.store(total_tokens_post_pad_ptr, last_cumsum) + + +@triton.jit +def moe_align_block_size_stage4( + topk_ids_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + tokens_cnts_ptr, + cumsum_ptr, + num_experts: tl.constexpr, + block_size: tl.constexpr, + numel: tl.constexpr, + tokens_per_thread: tl.constexpr, +): + pid = tl.program_id(0) + start_idx = tl.load(cumsum_ptr + pid) + end_idx = tl.load(cumsum_ptr + pid + 1) + + for i in range(start_idx, end_idx, block_size): + tl.store(expert_ids_ptr + i // block_size, pid) + + start_idx = pid * tokens_per_thread + off_t = pid * num_experts + + for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)): + expert_id = tl.load(topk_ids_ptr + i) + token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id) + rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id) + tl.store(sorted_token_ids_ptr + rank_post_pad, i) + tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1) + + +def moe_align_block_size_triton( + topk_ids: torch.Tensor, + num_experts: int, + block_size: int, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor, +) -> None: + numel = topk_ids.numel() + grid = (num_experts,) + tokens_cnts = torch.zeros( + (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device + ) + cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device) + tokens_per_thread = ceil_div(numel, num_experts) + + moe_align_block_size_stage1[grid]( + topk_ids, + tokens_cnts, + num_experts, + numel, + tokens_per_thread, + ) + moe_align_block_size_stage2[grid]( + tokens_cnts, + num_experts, + ) + moe_align_block_size_stage3[(1,)]( + num_tokens_post_pad, + tokens_cnts, + cumsum, + num_experts, + block_size, + ) + moe_align_block_size_stage4[grid]( + topk_ids, + sorted_token_ids, + expert_ids, + tokens_cnts, + cumsum, + num_experts, + block_size, + numel, + tokens_per_thread, + ) + + def moe_align_block_size( topk_ids: torch.Tensor, block_size: int, num_experts: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -272,24 +411,36 @@ def moe_align_block_size( (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - if not_hip and num_experts >= 224: - token_cnts_buffer = torch.empty( - (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device - ) - cumsum_buffer = torch.empty( - num_experts + 1, dtype=torch.int32, device=topk_ids.device - ) + if num_experts >= 224: + if enable_moe_align_block_size_triton or is_hip_flag: + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + ) + else: + token_cnts_buffer = torch.empty( + (num_experts + 1) * num_experts, + dtype=torch.int32, + device=topk_ids.device, + ) + cumsum_buffer = torch.empty( + num_experts + 1, dtype=torch.int32, device=topk_ids.device + ) - sgl_moe_align_block_size( - topk_ids, - num_experts, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - token_cnts_buffer, - cumsum_buffer, - ) + sgl_moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + token_cnts_buffer, + cumsum_buffer, + ) else: ops.moe_align_block_size( topk_ids, @@ -854,17 +1005,18 @@ def fused_experts_impl( block_shape=block_shape, ) - if not_hip: + if is_hip_flag: + ops.moe_sum( + intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx], + ) + else: torch.sum( intermediate_cache3.view(*intermediate_cache3.shape), dim=1, out=out_hidden_states[begin_chunk_idx:end_chunk_idx], ) - else: - ops.moe_sum( - intermediate_cache3.view(*intermediate_cache3.shape), - out_hidden_states[begin_chunk_idx:end_chunk_idx], - ) + return out_hidden_states