From febe21ce031d6a32fbaadf044c4653f31fcaddad Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Fri, 4 Apr 2025 15:24:18 +0800 Subject: [PATCH] Small refactor DeepEPDispatcher into subclasses (#4994) --- .../srt/layers/moe/ep_moe/token_dispatcher.py | 465 +++++++++++------- python/sglang/srt/models/deepseek_v2.py | 47 +- 2 files changed, 315 insertions(+), 197 deletions(-) diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 2a2909816..a2bbbf3bf 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -23,7 +23,7 @@ _buffer_normal = None _buffer_low_latency = None -def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int): +def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int): """ Copy from DeepEP example usage in model inference prefilling. https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling @@ -53,7 +53,7 @@ def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int): return _buffer_normal -def get_buffer_low_latency( +def _get_buffer_low_latency( group: dist.ProcessGroup, num_max_dispatch_tokens_per_rank: int, hidden: int, @@ -85,24 +85,16 @@ def get_buffer_low_latency( return _buffer_low_latency -class DeepEPDispatcher: - """ - Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher - https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py - """ - +class _DeepEPDispatcherImplBase: def __init__( self, group: torch.distributed.ProcessGroup, router_topk: int, - permute_fusion: bool = False, - num_experts: int = None, - num_local_experts: int = None, - hidden_size: int = None, - params_dtype: torch.dtype = None, - deepep_mode: DeepEPMode = DeepEPMode.auto, - async_finish: bool = False, - return_recv_hook: bool = False, + permute_fusion: bool, + num_experts: int, + num_local_experts: int, + hidden_size: int, + params_dtype: torch.dtype, ): if not use_deepep: raise ImportError( @@ -119,115 +111,71 @@ class DeepEPDispatcher: self.params_dtype = params_dtype self.params_bytes = 2 - self.deepep_mode = deepep_mode self.handle = None - if self.deepep_mode.enable_normal(): - self.buffer_normal = get_buffer_normal( - self.group, self.hidden_size * self.params_bytes - ) - self.async_finish = async_finish - self.src2dst = None - if self.deepep_mode.enable_low_latency(): - """ - num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256 - https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding - """ - # TODO(ch-wan): allow users to set this value - self.num_max_dispatch_tokens_per_rank = 128 - self.buffer_low_latency = get_buffer_low_latency( - self.group, - self.num_max_dispatch_tokens_per_rank, - self.hidden_size, - self.num_experts, - ) - self.return_recv_hook = return_recv_hook - - def deepep_permute( - self, - hidden_states: torch.Tensor, - topk_idx: torch.Tensor, - fp8_dtype: Optional[torch.dtype] = None, - use_fp8_w8a8: bool = False, - use_block_quant: bool = False, - ): - reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess( - topk_idx, self.num_experts - ) - num_total_tokens = reorder_topk_ids.numel() - gateup_input = torch.empty( - (int(num_total_tokens), hidden_states.shape[1]), - device=hidden_states.device, - dtype=( - fp8_dtype - if (use_fp8_w8a8 and not 
use_block_quant) - else hidden_states.dtype - ), - ) - # PreReorder - deepep_permute_triton_kernel[(hidden_states.shape[0],)]( - hidden_states, - gateup_input, - self.src2dst, - topk_idx, - None, - self.router_topk, - hidden_states.shape[1], - BLOCK_SIZE=512, - ) - return reorder_topk_ids, seg_indptr, gateup_input - def dispatch( self, hidden_states: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, num_experts: int, - num_max_dispatch_tokens_per_rank: int = 128, - forward_mode: ForwardMode = None, - ) -> Tuple: - topk_idx = topk_idx.to(torch.int64) - reorder_topk_ids = torch.empty( - (0,), device=hidden_states.device, dtype=torch.int64 - ) - seg_indptr = torch.zeros( - (num_experts + 1,), device=hidden_states.device, dtype=torch.int64 - ) - masked_m = torch.empty( - (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64 - ) - expected_m = 0 + num_max_dispatch_tokens_per_rank: int, + ): + raise NotImplementedError - resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) - if resolved_deepep_mode == DeepEPMode.normal: - ( - hidden_states, - topk_idx, - topk_weights, - event, - ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts) - event.current_stream_wait() if self.async_finish else () - if hidden_states.shape[0] > 0: - reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute( - hidden_states, topk_idx, fp8_dtype=hidden_states.dtype - ) - elif resolved_deepep_mode == DeepEPMode.low_latency: - expected_m = ( - hidden_states.shape[0] - * self.buffer_low_latency.group_size - * topk_idx.shape[1] - + num_experts - ) // num_experts - hidden_states, masked_m, event, hook = self.dispatch_low_latency( - hidden_states, - topk_idx, - num_max_dispatch_tokens_per_rank, - num_experts, - use_fp8=True, + def combine( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + ) -> torch.Tensor: + raise NotImplementedError + + +class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase): + def __init__(self, async_finish: bool, **kwargs): + super().__init__(**kwargs) + + self.buffer_normal = _get_buffer_normal( + self.group, self.hidden_size * self.params_bytes + ) + self.async_finish = async_finish + self.src2dst = None + + def dispatch( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + num_max_dispatch_tokens_per_rank: int, + ): + topk_idx = topk_idx.to(torch.int64) + ( + hidden_states, + topk_idx, + topk_weights, + event, + ) = self._dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts) + event.current_stream_wait() if self.async_finish else () + if hidden_states.shape[0] > 0: + reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute( + hidden_states, topk_idx, fp8_dtype=hidden_states.dtype ) - hook() if self.return_recv_hook else event.current_stream_wait() else: - raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") + reorder_topk_ids = torch.empty( + (0,), device=hidden_states.device, dtype=torch.int64 + ) + seg_indptr = torch.zeros( + (num_experts + 1,), device=hidden_states.device, dtype=torch.int64 + ) + + # TODO + # masked_m = torch.empty( + # (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64 + # ) + # expected_m = 0 + masked_m = expected_m = None return ( hidden_states, @@ -239,7 +187,7 @@ class DeepEPDispatcher: expected_m, ) - def dispatch_normal( + def _dispatch_normal( self, x: torch.Tensor, topk_idx: torch.Tensor, @@ -292,7 +240,156 @@ class 
DeepEPDispatcher: event, ) - def dispatch_low_latency( + def _deepep_permute( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + fp8_dtype: Optional[torch.dtype] = None, + use_fp8_w8a8: bool = False, + use_block_quant: bool = False, + ): + """ + Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py + """ + + reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess( + topk_idx, self.num_experts + ) + num_total_tokens = reorder_topk_ids.numel() + gateup_input = torch.empty( + (int(num_total_tokens), hidden_states.shape[1]), + device=hidden_states.device, + dtype=( + fp8_dtype + if (use_fp8_w8a8 and not use_block_quant) + else hidden_states.dtype + ), + ) + # PreReorder + deepep_permute_triton_kernel[(hidden_states.shape[0],)]( + hidden_states, + gateup_input, + self.src2dst, + topk_idx, + None, + self.router_topk, + hidden_states.shape[1], + BLOCK_SIZE=512, + ) + return reorder_topk_ids, seg_indptr, gateup_input + + def combine( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + ) -> torch.Tensor: + if hidden_states.shape[0] > 0: + num_tokens = self.src2dst.shape[0] // self.router_topk + output = torch.empty( + (num_tokens, hidden_states.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + deepep_post_reorder_triton_kernel[(num_tokens,)]( + hidden_states, + output, + self.src2dst, + topk_idx, + topk_weights, + self.router_topk, + hidden_states.shape[1], + BLOCK_SIZE=512, + ) + else: + output = torch.zeros( + (0, hidden_states.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + hidden_states, event = self._combine_normal( + output, + ) + event.current_stream_wait() if self.async_finish else () + + return hidden_states + + def _combine_normal(self, x: torch.Tensor): + previous_event = Buffer.capture() if self.async_finish else None + + combined_x, _, event = self.buffer_normal.combine( + x, + self.handle, + async_finish=self.async_finish, + previous_event=previous_event, + allocate_on_comm_stream=previous_event is not None, + ) + return combined_x, event + + +class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase): + def __init__(self, return_recv_hook: bool, **kwargs): + super().__init__(**kwargs) + + """ + num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256 + https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding + """ + # TODO(ch-wan): allow users to set this value + self.num_max_dispatch_tokens_per_rank = 128 + self.buffer_low_latency = _get_buffer_low_latency( + self.group, + self.num_max_dispatch_tokens_per_rank, + self.hidden_size, + self.num_experts, + ) + self.return_recv_hook = return_recv_hook + + def dispatch( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + num_max_dispatch_tokens_per_rank: int, + ): + topk_idx = topk_idx.to(torch.int64) + expected_m = ( + hidden_states.shape[0] + * self.buffer_low_latency.group_size + * topk_idx.shape[1] + + num_experts + ) // num_experts + hidden_states, masked_m, event, hook = self._dispatch_low_latency( + hidden_states, + topk_idx, + num_max_dispatch_tokens_per_rank, + num_experts, + use_fp8=True, + ) + hook() if self.return_recv_hook else event.current_stream_wait() + + # TODO + # reorder_topk_ids = torch.empty( + # (0,), 
device=hidden_states.device, dtype=torch.int64 + # ) + # seg_indptr = torch.zeros( + # (num_experts + 1,), device=hidden_states.device, dtype=torch.int64 + # ) + reorder_topk_ids = seg_indptr = None + + return ( + hidden_states, + topk_idx, + topk_weights, + reorder_topk_ids, + seg_indptr, + masked_m, + expected_m, + ) + + def _dispatch_low_latency( self, hidden_states: torch.Tensor, topk_idx: torch.Tensor, @@ -351,62 +448,17 @@ class DeepEPDispatcher: hidden_states: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, - forward_mode: ForwardMode, ) -> torch.Tensor: - resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) - if resolved_deepep_mode == DeepEPMode.normal: - if hidden_states.shape[0] > 0: - num_tokens = self.src2dst.shape[0] // self.router_topk - output = torch.empty( - (num_tokens, hidden_states.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - deepep_post_reorder_triton_kernel[(num_tokens,)]( - hidden_states, - output, - self.src2dst, - topk_idx, - topk_weights, - self.router_topk, - hidden_states.shape[1], - BLOCK_SIZE=512, - ) - else: - output = torch.zeros( - (0, hidden_states.shape[1]), - device=hidden_states.device, - dtype=hidden_states.dtype, - ) - hidden_states, event = self.combine_normal( - output, - ) - event.current_stream_wait() if self.async_finish else () - elif resolved_deepep_mode == DeepEPMode.low_latency: - hidden_states, event, hook = self.combine_low_latency( - hidden_states, - topk_idx, - topk_weights, - ) - hook() if self.return_recv_hook else event.current_stream_wait() - else: - raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") + hidden_states, event, hook = self._combine_low_latency( + hidden_states, + topk_idx, + topk_weights, + ) + hook() if self.return_recv_hook else event.current_stream_wait() return hidden_states - def combine_normal(self, x: torch.Tensor): - previous_event = Buffer.capture() if self.async_finish else None - - combined_x, _, event = self.buffer_normal.combine( - x, - self.handle, - async_finish=self.async_finish, - previous_event=previous_event, - allocate_on_comm_stream=previous_event is not None, - ) - return combined_x, event - - def combine_low_latency( + def _combine_low_latency( self, hidden_states: torch.Tensor, topk_idx: torch.Tensor, @@ -423,3 +475,80 @@ class DeepEPDispatcher: ) ) return combined_hidden_states, event, hook + + +class DeepEPDispatcher: + def __init__( + self, + group: torch.distributed.ProcessGroup, + router_topk: int, + permute_fusion: bool = False, + num_experts: int = None, + num_local_experts: int = None, + hidden_size: int = None, + params_dtype: torch.dtype = None, + deepep_mode: DeepEPMode = DeepEPMode.auto, + async_finish: bool = False, + return_recv_hook: bool = False, + ): + self.deepep_mode = deepep_mode + + common_kwargs = dict( + group=group, + router_topk=router_topk, + permute_fusion=permute_fusion, + num_experts=num_experts, + num_local_experts=num_local_experts, + hidden_size=hidden_size, + params_dtype=params_dtype, + ) + + if self.deepep_mode.enable_normal(): + self._normal_dispatcher = _DeepEPDispatcherImplNormal( + async_finish=async_finish, + **common_kwargs, + ) + if self.deepep_mode.enable_low_latency(): + self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency( + return_recv_hook=return_recv_hook, + **common_kwargs, + ) + + def dispatch( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + num_max_dispatch_tokens_per_rank: int = 128, + forward_mode: 
ForwardMode = None, + ) -> Tuple: + return self._get_dispatcher(forward_mode).dispatch( + hidden_states=hidden_states, + topk_idx=topk_idx, + topk_weights=topk_weights, + num_experts=num_experts, + num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank, + ) + + def combine( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + forward_mode: ForwardMode, + ) -> torch.Tensor: + return self._get_dispatcher(forward_mode).combine( + hidden_states=hidden_states, + topk_idx=topk_idx, + topk_weights=topk_weights, + ) + + def _get_dispatcher(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase: + resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) + if resolved_deepep_mode == DeepEPMode.normal: + return self._normal_dispatcher + elif resolved_deepep_mode == DeepEPMode.low_latency: + return self._low_latency_dispatcher + else: + raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index adbc76b9e..2f78de492 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -188,35 +188,24 @@ class DeepseekV2MoE(nn.Module): if global_server_args_dict["enable_deepep_moe"] else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE) ) - if not global_server_args_dict["enable_deepep_moe"]: - self.experts = MoEImpl( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - correction_bias=self.gate.e_score_correction_bias, - prefix=add_prefix("experts", prefix), - ) - else: - self.experts = MoEImpl( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - correction_bias=self.gate.e_score_correction_bias, - prefix=add_prefix("experts", prefix), - deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]], - ) + self.experts = MoEImpl( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + correction_bias=self.gate.e_score_correction_bias, + prefix=add_prefix("experts", prefix), + **( + dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]]) + if global_server_args_dict["enable_deepep_moe"] + else {} + ), + ) if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts
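
---

Note (not part of the patch): the shape of the refactor above is a thin `DeepEPDispatcher` facade that owns one implementation object per enabled mode (`_DeepEPDispatcherImplNormal`, `_DeepEPDispatcherImplLowLatency`, both derived from `_DeepEPDispatcherImplBase`) and routes each `dispatch`/`combine` call through `_get_dispatcher(forward_mode)`. The stand-alone sketch below only illustrates that delegation structure; every name in it is a simplified stand-in rather than the real sglang/DeepEP API, and the real classes additionally carry the buffer setup, permutation kernels, and async/recv-hook handling shown in the diff.

```python
# Hypothetical, simplified illustration of the facade-plus-implementations
# structure introduced by this patch. All classes here are stand-ins.
from enum import Enum, auto


class Mode(Enum):  # stand-in for the resolved DeepEPMode
    normal = auto()
    low_latency = auto()


class _ImplBase:  # stand-in for _DeepEPDispatcherImplBase
    def dispatch(self, x):
        raise NotImplementedError


class _ImplNormal(_ImplBase):  # stand-in for _DeepEPDispatcherImplNormal
    def dispatch(self, x):
        return f"normal dispatch of {x}"


class _ImplLowLatency(_ImplBase):  # stand-in for _DeepEPDispatcherImplLowLatency
    def dispatch(self, x):
        return f"low-latency dispatch of {x}"


class Dispatcher:  # stand-in for the refactored DeepEPDispatcher facade
    def __init__(self):
        # The real facade only constructs the implementations that the
        # configured deepep_mode enables; both are created here for brevity.
        self._normal = _ImplNormal()
        self._low_latency = _ImplLowLatency()

    def dispatch(self, x, mode: Mode):
        # Per-call routing replaces the mode branching that previously lived
        # inside the monolithic dispatch()/combine() bodies.
        return self._get_impl(mode).dispatch(x)

    def _get_impl(self, mode: Mode) -> _ImplBase:
        if mode == Mode.normal:
            return self._normal
        elif mode == Mode.low_latency:
            return self._low_latency
        raise ValueError(f"Invalid mode: {mode}")


if __name__ == "__main__":
    d = Dispatcher()
    print(d.dispatch("tokens", Mode.normal))       # normal dispatch of tokens
    print(d.dispatch("tokens", Mode.low_latency))  # low-latency dispatch of tokens
```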