diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py
index 6d6c432f8..30c9eb6a7 100644
--- a/python/sglang/srt/layers/moe/ep_moe/kernels.py
+++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py
@@ -17,52 +17,6 @@ if _is_cuda:
 
 logger = logging.getLogger(__name__)
 
-@triton.jit
-def compute_src2dst_triton_kernel(
-    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
-):
-    pid = tl.program_id(axis=0)
-    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = dst_id < num_toks
-    src_id = tl.load(reorder_ids + dst_id, mask=mask)
-    tl.store(src2dst + src_id, dst_id, mask=mask)
-
-
-@triton.jit
-def deepep_compute_src2dst_triton_kernel(
-    reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
-):
-    pid = tl.program_id(axis=0)
-    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = dst_id < num_toks
-    src_id = tl.load(reorder_ids + dst_id, mask=mask)
-    num_invalid = tl.load(num_minus_one)
-    tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)
-
-
-def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
-    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
-    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
-    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
-
-    # Find offet
-    expert_ids = torch.arange(
-        num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
-    )
-    torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)
-    num_minus_one = seg_indptr[0]
-    seg_indptr = seg_indptr - num_minus_one
-
-    BLOCK_SIZE = 512
-    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
-    deepep_compute_src2dst_triton_kernel[grid](
-        reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
-    )
-
-    reorder_topk_ids = reorder_topk_ids[num_minus_one:]
-    return reorder_topk_ids, src2dst, seg_indptr
-
-
 @triton.jit
 def deepep_permute_triton_kernel(
     input_ptr,
@@ -85,14 +39,13 @@ def deepep_permute_triton_kernel(
     for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
         offset = start_offset + tl.arange(0, BLOCK_SIZE)
         mask = offset < hidden_size
-        in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
+        in_data = tl.load(src_ptr + offset, mask=mask).to(OutDtype)
 
         for idx in range(topk):
            dst_idx = tl.load(src2dst_ptr + idx)
            if dst_idx >= 0:
                dst_ptr = gateup_input_ptr + dst_idx * hidden_size
-                out_data = (in_data).to(OutDtype)
-                tl.store(dst_ptr + offset, out_data, mask=mask)
+                tl.store(dst_ptr + offset, in_data, mask=mask)
 
 
 @triton.jit
@@ -128,6 +81,51 @@ def deepep_post_reorder_triton_kernel(
         tl.store(store_ptr + offset, sum_vec, mask=mask)
 
 
+@triton.jit
+def compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+@triton.jit
+def deepep_compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    num_invalid = tl.load(num_minus_one)
+    tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)
+
+
+def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)
+
+    # Find offet
+    expert_ids = torch.arange(
+        num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
+    )
+    torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)
+    num_minus_one = seg_indptr[0]
+    seg_indptr = seg_indptr - num_minus_one
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
+    deepep_compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
+    )
+    reorder_topk_ids = reorder_topk_ids[num_minus_one:]
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
 @triton.jit
 def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
     expert = tl.program_id(0)
diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py
index a9b443a75..f0595bfb1 100644
--- a/python/sglang/srt/layers/moe/ep_moe/layer.py
+++ b/python/sglang/srt/layers/moe/ep_moe/layer.py
@@ -831,19 +831,23 @@ class DeepEPMoE(EPMoE):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        tokens_per_expert: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
         forward_mode: ForwardMode,
     ):
         # Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
         if True:  # not forward_mode.is_decode():
-            return self.forward_normal(hidden_states, tokens_per_expert)
+            return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
         else:
-            return self.forward_deepgemm_masked(hidden_states, tokens_per_expert)
+            return self.forward_deepgemm_masked(
+                hidden_states, reorder_topk_ids, seg_indptr
+            )
 
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
-        tokens_per_expert: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
@@ -851,15 +855,7 @@ class DeepEPMoE(EPMoE):
             self.grouped_gemm_runner = GroupedGemmRunner(
                 hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
             )
-        seg_indptr_cur_rank = torch.cat(
-            [
-                torch.zeros(
-                    1, device=tokens_per_expert.device, dtype=tokens_per_expert.dtype
-                ),
-                torch.cumsum(tokens_per_expert, dim=0),
-            ]
-        )
-        reorder_topk_ids = torch.repeat_interleave(tokens_per_expert)
+
         if self.activation_scheme == "dynamic" and not self.use_block_quant:
             max_value = (
                 torch.max(hidden_states)
@@ -881,6 +877,7 @@ class DeepEPMoE(EPMoE):
             device=hidden_states.device,
             dtype=hidden_states.dtype,
         )
+
         if hidden_states.shape[0] > 0:
             gateup_output = self.grouped_gemm_runner(
                 a=hidden_states,
@@ -888,7 +885,7 @@ class DeepEPMoE(EPMoE):
                 c=gateup_output,
                 batch_size=self.num_experts_per_partition,
                 weight_column_major=True,
-                seg_indptr=seg_indptr_cur_rank,
+                seg_indptr=seg_indptr,
                 weight_indices=weight_indices_cur_rank,
                 use_fp8_w8a8=self.use_fp8_w8a8,
                 scale_a=self.w13_input_scale,
@@ -946,7 +943,7 @@ class DeepEPMoE(EPMoE):
                 c=down_output,
                 batch_size=self.num_experts_per_partition,
                 weight_column_major=True,
-                seg_indptr=seg_indptr_cur_rank,
+                seg_indptr=seg_indptr,
                 weight_indices=weight_indices_cur_rank,
                 use_fp8_w8a8=self.use_fp8_w8a8,
                 scale_a=self.w2_input_scale,
diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
index c91ccd633..6d8605f77 100644
--- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -12,7 +12,6 @@ import torch
 import torch.distributed as dist
 
 from sglang.srt.layers.moe.ep_moe.kernels import (
-    compute_src2dst_triton_kernel,
     deepep_permute_triton_kernel,
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
@@ -86,90 +85,6 @@ def get_buffer_low_latency(
     return _buffer_low_latency
 
 
-def permute(
-    tokens,
-    routing_map,
-    num_out_tokens: Optional[int] = None,
-    fused: bool = False,
-    drop_and_pad: bool = False,
-):
-    """
-    Copy from Megatron-Core moe for token permutation
-    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py
-    """
-
-    num_tokens, _ = tokens.shape
-    num_experts = routing_map.shape[1]
-    if drop_and_pad and not (num_out_tokens is None):
-        capacity = num_out_tokens // num_experts
-        assert not routing_map.requires_grad
-        routing_map = routing_map.to(dtype=torch.int8).T.contiguous()
-        sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[
-            :, :capacity
-        ].contiguous()
-        sorted_indices = sorted_indices.view(-1)
-    else:
-        routing_map = routing_map.bool().T.contiguous()
-        token_indices = (
-            torch.arange(num_tokens, device=routing_map.device)
-            .unsqueeze(0)
-            .expand(num_experts, -1)
-        )
-        sorted_indices = token_indices.masked_select(routing_map)
-    permuted_input = tokens.index_select(0, sorted_indices)
-
-    return permuted_input, sorted_indices
-
-
-def unpermute(
-    permuted_tokens: torch.Tensor,
-    sorted_indices: torch.Tensor,
-    restore_shape: torch.Size,
-    probs: torch.Tensor = None,
-    routing_map: torch.Tensor = None,
-    fused: bool = False,
-    drop_and_pad: bool = False,
-):
-    """
-    Copy from Megatron-Core moe for token unpermutation
-    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py
-    """
-
-    _, hidden = restore_shape
-
-    if probs is not None:
-        assert routing_map is not None, "Mask must be provided to permute the probs."
-        if drop_and_pad:
-            num_experts = routing_map.size(1)
-            num_permuted_tokens = sorted_indices.size(0)
-            capacity = num_permuted_tokens // num_experts
-            num_unpermuted_tokens = probs.size(0)
-
-            probs_T_1D = probs.T.contiguous().view(-1)
-
-            indices_dim0 = torch.arange(
-                num_experts, device=routing_map.device
-            ).unsqueeze(-1)
-            indices_dim1 = sorted_indices.view(num_experts, capacity)
-            indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1)
-
-            permuted_probs = probs_T_1D.index_select(0, indices_1D)
-        else:
-            permuted_probs = probs.T.contiguous().masked_select(
-                routing_map.T.contiguous()
-            )
-        permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1)
-
-    output_tokens = torch.zeros(
-        restore_shape, device=permuted_tokens.device, dtype=permuted_tokens.dtype
-    )
-    output_tokens.scatter_add_(
-        0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens
-    )
-
-    return output_tokens
-
-
 class DeepEPDispatcher:
     """
     Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
@@ -228,16 +143,13 @@ class DeepEPDispatcher:
 
     def deepep_permute(
         self,
-        topk_ids,
         hidden_states,
-        num_experts,
-        top_k,
-        use_fp8_w8a8,
-        use_block_quant,
-        fp8_dtype,
+        fp8_dtype=None,
+        use_fp8_w8a8=False,
+        use_block_quant=False,
     ):
         reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-            topk_ids, num_experts
+            self.topk_idx, self.num_experts
         )
         num_total_tokens = reorder_topk_ids.numel()
         gateup_input = torch.empty(
@@ -254,9 +166,9 @@ class DeepEPDispatcher:
             hidden_states,
             gateup_input,
             src2dst,
-            topk_ids,
+            self.topk_idx,
             None,
-            top_k,
+            self.router_topk,
             hidden_states.shape[1],
             BLOCK_SIZE=512,
         )
@@ -302,13 +214,21 @@ class DeepEPDispatcher:
                 )
             )
             self.recv_expert_count = recv_expert_count
-        tokens_per_expert = self.get_number_of_tokens_per_expert()
         self.handle = handle
         self.topk_idx = topk_idx
         self.topk_weights = topk_weights
         if hidden_states.shape[0] > 0:
-            hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states)
-        return hidden_states, topk_idx, topk_weights, tokens_per_expert
+            reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
+                hidden_states, fp8_dtype=hidden_states.dtype
+            )
+        else:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+        return hidden_states, reorder_topk_ids, seg_indptr
 
     def dispatch_normal(
         self,
@@ -427,10 +347,29 @@ class DeepEPDispatcher:
         # Todo: enable low latency combine
         if True:  # not forward_mode.is_decode():
             if hidden_states.shape[0] > 0:
-                hidden_states = self.get_restored_hidden_states_by_experts(
-                    hidden_states
+                num_tokens = self.src2dst.shape[0] // self.router_topk
+                output = torch.empty(
+                    (num_tokens, hidden_states.shape[1]),
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
                 )
-            hidden_states, event = self.combine_normal(hidden_states, self.handle)
+                deepep_post_reorder_triton_kernel[(num_tokens,)](
+                    hidden_states,
+                    output,
+                    self.src2dst,
+                    self.topk_idx,
+                    self.topk_weights,
+                    self.router_topk,
+                    hidden_states.shape[1],
+                    BLOCK_SIZE=512,
+                )
+            else:
+                output = torch.zeros(
+                    (0, hidden_states.shape[1]),
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
+            hidden_states, event = self.combine_normal(output, self.handle)
         else:
             hidden_states, event, hook = self.combine_low_latency(
                 hidden_states, self.topk_idx, self.topk_weights, self.handle
@@ -467,67 +406,3 @@ class DeepEPDispatcher:
         )
         # hook()
         return combined_hidden_states, event_overlap, hook
-
-    def _indices_to_multihot(self, indices, probs):
-        batch_size = indices.shape[0]
-        multihot_routing_map = torch.zeros(
-            (batch_size, self.num_local_experts),
-            dtype=torch.long,
-            device=indices.device,
-        )
-
-        multihot_probs = torch.zeros(
-            (batch_size, self.num_local_experts),
-            dtype=torch.float,
-            device=indices.device,
-        )
-
-        mask = indices != -1
-        valid_indices = indices[mask]
-        row_indices = torch.arange(batch_size, device=indices.device).repeat_interleave(
-            mask.sum(dim=1)
-        )
-        multihot_routing_map[row_indices, valid_indices] = 1
-        multihot_probs[row_indices, valid_indices] = probs[mask]
-        return multihot_routing_map.bool(), multihot_probs
-
-    def get_dispached_metadata(self) -> torch.Tensor:
-        return self.topk_idx, self.topk_weights
-
-    def get_number_of_tokens_per_expert(self) -> torch.Tensor:
-        """
-        Get the number of tokens per expert.
-        """
-        return self.tokens_per_expert
-
-    def get_permuted_hidden_states_by_experts(
-        self, hidden_states: torch.Tensor
-    ) -> torch.Tensor:
-        self.dispatched_routing_map, self.topk_weights = self._indices_to_multihot(
-            self.topk_idx, self.topk_weights
-        )
-        self.hidden_shape_before_permute = hidden_states.shape
-        hidden_states, self.reversed_mapping_for_combine = permute(
-            hidden_states,
-            self.dispatched_routing_map,
-            num_out_tokens=self.tokens_per_expert.sum(),
-            fused=self.permute_fusion,
-        )
-        return hidden_states
-
-    def get_restored_hidden_states_by_experts(
-        self, hidden_states: torch.Tensor
-    ) -> torch.Tensor:
-        input_dtype = hidden_states.dtype
-        assert (
-            self.topk_weights.dtype == torch.float32
-        ), "DeepEP only supports float32 probs"
-        hidden_states = unpermute(
-            hidden_states,
-            self.reversed_mapping_for_combine,
-            restore_shape=self.hidden_shape_before_permute,
-            routing_map=self.dispatched_routing_map,
-            probs=self.topk_weights,
-            fused=self.permute_fusion,
-        )
-        return hidden_states.to(input_dtype)
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index ffcc9a955..c62dacec9 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -294,7 +294,7 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.correction_bias,
         )
         if self.tp_size > 1:
-            recv_hidden_states, topk_idx, topk_weights, tokens_per_expert = (
+            recv_hidden_states, reorder_topk_ids, seg_indptr = (
                 self.deepep_dispatcher.dispatch(
                     hidden_states,
                     topk_idx,
@@ -306,7 +306,8 @@ class DeepseekV2MoE(nn.Module):
         final_hidden_states = (
             self.experts(
                 hidden_states=recv_hidden_states,
-                tokens_per_expert=tokens_per_expert,
+                reorder_topk_ids=reorder_topk_ids,
+                seg_indptr=seg_indptr,
                 forward_mode=forward_mode,
             )
             * self.routed_scaling_factor
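
The sketch below (plain PyTorch, not part of the patch) illustrates what the relocated deepep_run_moe_deep_preprocess computes and how its three outputs flow through the new dispatch()/forward() interface: reorder_topk_ids lists the surviving expert assignments in expert-sorted order, seg_indptr gives per-expert segment offsets for the grouped GEMM, and src2dst maps each original (token, top-k) slot to its row in the permuted gateup input. The toy topk_ids values and the scatter-style src2dst computation are illustrative assumptions; the patched code launches deepep_compute_src2dst_triton_kernel on the GPU instead.

# Minimal pure-PyTorch sketch of the preprocess step (illustration only).
import torch

num_experts = 4
# topk_ids as seen after DeepEP dispatch; -1 marks slots carrying no token for this rank.
topk_ids = torch.tensor([[0, 2], [-1, 1], [3, 3]]).view(-1)

# Stable sort groups entries by expert id, with all -1 slots first.
reorder_topk_ids, reorder_ids = torch.sort(topk_ids, stable=True)

# seg_indptr[e] = first position of expert e in the sorted order; subtracting the
# number of -1 entries re-bases valid tokens to start at offset 0.
expert_ids = torch.arange(num_experts + 1, dtype=reorder_topk_ids.dtype)
seg_indptr = torch.searchsorted(reorder_topk_ids, expert_ids)
num_minus_one = seg_indptr[0]
seg_indptr = seg_indptr - num_minus_one

# src2dst[src] = destination row in the permuted gateup_input layout, shifted down
# by the number of invalid (-1) slots, mirroring deepep_compute_src2dst_triton_kernel.
src2dst = torch.empty_like(topk_ids)
src2dst[reorder_ids] = torch.arange(topk_ids.numel()) - num_minus_one

reorder_topk_ids = reorder_topk_ids[num_minus_one:]
print(reorder_topk_ids)  # tensor([0, 1, 2, 3, 3])
print(seg_indptr)        # tensor([0, 1, 2, 3, 5])
print(src2dst)           # destination slot per original (token, k) pair; -1 for dropped slots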