diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
index 100fa57fb..b10b1c98b 100644
--- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -7,6 +7,7 @@ try:
 except ImportError:
     use_deepep = False
 
+from enum import IntEnum, auto
 from typing import Optional, Tuple
 
 import torch
@@ -19,70 +20,95 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
-_buffer_normal = None
-_buffer_low_latency = None
+
+class DeepEPDispatchMode(IntEnum):
+    NORMAL = auto()
+    LOW_LATENCY = auto()
 
 
-def _get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
-    """
-    Copy from DeepEP example usage in model inference prefilling.
-    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
-    """
+class DeepEPBuffer:
 
-    global _buffer_normal
+    _buffer: Optional[Buffer] = None
+    _dispatch_mode: Optional[DeepEPDispatchMode] = None
+    _hidden_size: Optional[int] = None
+    _num_max_dispatch_tokens_per_rank: Optional[int] = None
+    _num_experts: Optional[int] = None
 
-    num_nvl_bytes, num_rdma_bytes = 0, 0
-    for config in (
-        Buffer.get_dispatch_config(group.size()),
-        Buffer.get_combine_config(group.size()),
+    @classmethod
+    def get_deepep_buffer(
+        cls,
+        group: dist.ProcessGroup,
+        hidden_size: int,
+        param_bytes: int,
+        deepep_mode: DeepEPMode,
+        num_max_dispatch_tokens_per_rank: Optional[int] = None,
+        num_experts: Optional[int] = None,
     ):
-        num_nvl_bytes = max(
-            config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
-        )
-        num_rdma_bytes = max(
-            config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
-        )
+        if cls._buffer is not None:
+            return cls._buffer
 
-    if (
-        _buffer_normal is None
-        or _buffer_normal.group != group
-        or _buffer_normal.num_nvl_bytes < num_nvl_bytes
-        or _buffer_normal.num_rdma_bytes < num_rdma_bytes
-    ):
-        _buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes)
-    return _buffer_normal
+        cls._hidden_size = hidden_size
+        cls._num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank
+        cls._num_experts = num_experts
 
+        num_nvl_bytes, num_rdma_bytes = 0, 0
+        if deepep_mode.enable_normal():
+            hidden_bytes = hidden_size * param_bytes
+            for config in (
+                Buffer.get_dispatch_config(group.size()),
+                Buffer.get_combine_config(group.size()),
+            ):
+                num_nvl_bytes = max(
+                    config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
+                    num_nvl_bytes,
+                )
+                num_rdma_bytes = max(
+                    config.get_rdma_buffer_size_hint(hidden_bytes, group.size()),
+                    num_rdma_bytes,
+                )
+        if deepep_mode.enable_low_latency():
+            assert num_max_dispatch_tokens_per_rank is not None
+            assert num_experts is not None and num_experts % group.size() == 0
+            num_rdma_bytes = max(
+                Buffer.get_low_latency_rdma_size_hint(
+                    num_max_dispatch_tokens_per_rank,
+                    hidden_size,
+                    group.size(),
+                    num_experts,
+                ),
+                num_rdma_bytes,
+            )
 
-def _get_buffer_low_latency(
-    group: dist.ProcessGroup,
-    num_max_dispatch_tokens_per_rank: int,
-    hidden: int,
-    num_experts: int,
-):
-    """
-    Copy from DeepEP example usage in model inference decoding.
-    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
-    """
-
-    global _buffer_low_latency
-    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
-        num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
-    )
-
-    if (
-        _buffer_low_latency is None
-        or _buffer_low_latency.group != group
-        or not _buffer_low_latency.low_latency_mode
-        or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes
-    ):
-        assert num_experts % group.size() == 0
-        _buffer_low_latency = Buffer(
+        cls._buffer = Buffer(
             group,
-            num_rdma_bytes=num_rdma_bytes,
-            low_latency_mode=True,
-            num_qps_per_rank=num_experts // group.size(),
+            num_nvl_bytes,
+            num_rdma_bytes,
+            low_latency_mode=deepep_mode.enable_low_latency(),
+            num_qps_per_rank=(
+                num_experts // group.size() if deepep_mode.enable_low_latency() else 1
+            ),
         )
-    return _buffer_low_latency
+        return cls._buffer
+
+    @classmethod
+    def clean_buffer(cls):
+        if not cls._buffer.low_latency_mode:
+            return
+        cls._buffer.clean_low_latency_buffer(
+            cls._num_max_dispatch_tokens_per_rank,
+            cls._hidden_size,
+            cls._num_experts,
+        )
+
+    @classmethod
+    def set_dispatch_mode_as_normal(cls):
+        cls._dispatch_mode = DeepEPDispatchMode.NORMAL
+
+    @classmethod
+    def set_dispatch_mode_as_low_latency(cls):
+        if cls._dispatch_mode == DeepEPDispatchMode.NORMAL:
+            cls.clean_buffer()
+        cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
 
 
 class _DeepEPDispatcherImplBase:
@@ -95,6 +121,7 @@ class _DeepEPDispatcherImplBase:
         num_local_experts: int,
         hidden_size: int,
         params_dtype: torch.dtype,
+        deepep_mode: DeepEPMode,
     ):
         if not use_deepep:
             raise ImportError(
@@ -109,7 +136,10 @@ class _DeepEPDispatcherImplBase:
         self.num_local_experts = num_local_experts
         self.hidden_size = hidden_size
         self.params_dtype = params_dtype
+        self.deepep_mode = deepep_mode
+
         self.params_bytes = 2
+        self.num_max_dispatch_tokens_per_rank = 128
 
         self.handle = None
 
@@ -118,8 +148,6 @@ class _DeepEPDispatcherImplBase:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
         raise NotImplementedError
 
@@ -137,14 +165,14 @@ class _DeepEPDispatcherImplBase:
     def combine_b(self, *args, **kwargs):
         raise NotImplementedError
 
+    def _get_buffer(self) -> Buffer:
+        raise NotImplementedError
+
 
 class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
     def __init__(self, async_finish: bool, **kwargs):
         super().__init__(**kwargs)
 
-        self.buffer_normal = _get_buffer_normal(
-            self.group, self.hidden_size * self.params_bytes
-        )
         self.async_finish = async_finish
         self.src2dst = None
 
@@ -153,24 +181,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
         topk_idx = topk_idx.to(torch.int64)
         previous_event = Buffer.capture() if self.async_finish else None
-        return hidden_states, topk_idx, topk_weights, num_experts, previous_event
+        return hidden_states, topk_idx, topk_weights, previous_event
 
-    def dispatch_b(
-        self, hidden_states, topk_idx, topk_weights, num_experts, previous_event
-    ):
+    def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
         (
             hidden_states,
             topk_idx,
             topk_weights,
             event,
-        ) = self._dispatch_core(
-            hidden_states, topk_idx, topk_weights, num_experts, previous_event
-        )
+        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
         event.current_stream_wait() if self.async_finish else ()
         if hidden_states.shape[0] > 0:
             reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
@@ -181,7 +203,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
                 (0,), device=hidden_states.device, dtype=torch.int64
             )
             seg_indptr = torch.zeros(
-                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
             )
             masked_m = expected_m = None
 
@@ -201,18 +223,18 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         x: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
         previous_event,
     ):
+        buffer = self._get_buffer()
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
             num_tokens_per_expert,
             is_token_in_rank,
             previous_event,
-        ) = self.buffer_normal.get_dispatch_layout(
+        ) = buffer.get_dispatch_layout(
             topk_idx,
-            num_experts,
+            self.num_experts,
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=previous_event is not None,
@@ -221,6 +243,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
         # However, doing this would incur an unknown synchronization error, but keeping
         # `handle` as a member variable works.
+
         (
             recv_x,
             recv_topk_idx,
@@ -228,7 +251,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             _,  # num_recv_tokens_per_expert_list
             self.handle,
             event,
-        ) = self.buffer_normal.dispatch(
+        ) = buffer.dispatch(
             x,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
@@ -327,7 +350,8 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         return hidden_states
 
     def _combine_core(self, x: torch.Tensor, previous_event):
-        combined_x, _, event = self.buffer_normal.combine(
+        buffer = self._get_buffer()
+        combined_x, _, event = buffer.combine(
             x,
             self.handle,
             async_finish=self.async_finish,
@@ -336,6 +360,17 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         )
         return combined_x, event
 
+    def _get_buffer(self):
+        DeepEPBuffer.set_dispatch_mode_as_normal()
+        return DeepEPBuffer.get_deepep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
 
 class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
     def __init__(self, return_recv_hook: bool, **kwargs):
@@ -345,14 +380,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256
         https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
         """
-        # TODO(ch-wan): allow users to set this value
-        self.num_max_dispatch_tokens_per_rank = 128
-        self.buffer_low_latency = _get_buffer_low_latency(
-            self.group,
-            self.num_max_dispatch_tokens_per_rank,
-            self.hidden_size,
-            self.num_experts,
-        )
         self.return_recv_hook = return_recv_hook
 
     def dispatch_a(
@@ -360,21 +387,16 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int,
     ):
+        buffer = self._get_buffer()
         topk_idx = topk_idx.to(torch.int64)
         expected_m = (
-            hidden_states.shape[0]
-            * self.buffer_low_latency.group_size
-            * topk_idx.shape[1]
-            + num_experts
-        ) // num_experts
+            hidden_states.shape[0] * buffer.group_size * topk_idx.shape[1]
+            + self.num_experts
+        ) // self.num_experts
         hidden_states, masked_m, event, hook = self._dispatch_core(
             hidden_states,
             topk_idx,
-            num_max_dispatch_tokens_per_rank,
-            num_experts,
             use_fp8=True,
         )
         return (
@@ -415,8 +437,6 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
-        num_max_dispatch_tokens_per_rank: int,
-        num_experts: int,
         use_fp8: bool = False,
     ):
         """
@@ -451,13 +471,13 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
         """
-
+        buffer = self._get_buffer()
         packed_recv_hidden, packed_recv_count, self.handle, event, hook = (
-            self.buffer_low_latency.low_latency_dispatch(
+            buffer.low_latency_dispatch(
                 hidden_states,
                 topk_idx,
-                num_max_dispatch_tokens_per_rank,
-                num_experts,
+                self.num_max_dispatch_tokens_per_rank,
+                self.num_experts,
                 use_fp8=use_fp8,
                 async_finish=not self.return_recv_hook,
                 return_recv_hook=self.return_recv_hook,
@@ -488,19 +508,29 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        combined_hidden_states, event, hook = (
-            self.buffer_low_latency.low_latency_combine(
-                hidden_states,
-                topk_idx,
-                topk_weights,
-                self.handle,
-                async_finish=not self.return_recv_hook,
-                return_recv_hook=self.return_recv_hook,
-            )
+        buffer = self._get_buffer()
+        combined_hidden_states, event, hook = buffer.low_latency_combine(
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            self.handle,
+            async_finish=not self.return_recv_hook,
+            return_recv_hook=self.return_recv_hook,
         )
         self.handle = None
         return combined_hidden_states, event, hook
 
+    def _get_buffer(self):
+        DeepEPBuffer.set_dispatch_mode_as_low_latency()
+        return DeepEPBuffer.get_deepep_buffer(
+            self.group,
+            self.hidden_size,
+            self.params_bytes,
+            self.deepep_mode,
+            self.num_max_dispatch_tokens_per_rank,
+            self.num_experts,
+        )
+
 
 class DeepEPDispatcher:
     def __init__(
@@ -526,18 +556,19 @@ class DeepEPDispatcher:
             num_local_experts=num_local_experts,
             hidden_size=hidden_size,
             params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
         )
 
-        if self.deepep_mode.enable_normal():
-            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
-                async_finish=async_finish,
-                **common_kwargs,
-            )
         if self.deepep_mode.enable_low_latency():
             self._low_latency_dispatcher = _DeepEPDispatcherImplLowLatency(
                 return_recv_hook=return_recv_hook,
                 **common_kwargs,
             )
+        if self.deepep_mode.enable_normal():
+            self._normal_dispatcher = _DeepEPDispatcherImplNormal(
+                async_finish=async_finish,
+                **common_kwargs,
+            )
 
     def dispatch(self, *args, **kwargs) -> Tuple:
         self.dispatch_a(*args, **kwargs)
@@ -548,16 +579,12 @@ class DeepEPDispatcher:
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        num_experts: int,
-        num_max_dispatch_tokens_per_rank: int = 128,
         forward_mode: ForwardMode = None,
     ):
         inner_state = self._get_impl(forward_mode).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
-            num_experts=num_experts,
-            num_max_dispatch_tokens_per_rank=num_max_dispatch_tokens_per_rank,
         )
         self._dispatch_intermediate_state = forward_mode, inner_state
 
@@ -589,7 +616,7 @@ class DeepEPDispatcher:
         del self._combine_intermediate_state
         return self._get_impl(forward_mode).combine_b(*inner_state)
 
-    def _get_impl(self, forward_mode: ForwardMode) -> "_DeepEPDispatcherImplBase":
+    def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
         resolved_deepep_mode = self.deepep_mode.resolve(forward_mode)
         if resolved_deepep_mode == DeepEPMode.normal:
             return self._normal_dispatcher
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 96a13a999..ae3e3eb8c 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -72,7 +72,7 @@ class ForwardMode(IntEnum):
     DUMMY_FIRST = auto()
 
     def is_prefill(self):
-        return self == ForwardMode.PREFILL
+        return self.is_extend()
 
     def is_extend(self):
         return (
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index d973f1b88..132210e27 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -324,6 +324,7 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.correction_bias,
         )
         if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
             (
                 hidden_states,
                 topk_idx,
@@ -336,7 +337,6 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states,
                 topk_idx,
                 topk_weights,
-                self.num_experts,
                 forward_mode=forward_mode,
             )
         final_hidden_states = (
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index a65c90de8..28539dcee 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -1101,6 +1101,7 @@ class ServerArgs:
             "--deepep-mode",
             type=str,
             choices=["normal", "low_latency", "auto"],
+            default="auto",
             help="Select the mode when enabling DeepEP MoE: `normal`, `low_latency`, or `auto`. Default is `auto`, which uses `low_latency` for decode batches and `normal` for prefill batches.",
         )
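
The heart of the token_dispatcher.py change is that the two mode-specific globals (`_buffer_normal`, `_buffer_low_latency`) collapse into one process-wide `DeepEPBuffer` that is allocated once, sized for every mode the server enables, and latched to the dispatch mode of the current batch; a NORMAL -> LOW_LATENCY transition must scrub the low-latency RDMA region before it is reused. A minimal standalone sketch of that lifecycle, where `MockBuffer` and `SharedBuffer` are hypothetical stand-ins (not deep_ep or sglang names) so the sketch runs without DeepEP, and the sizes passed to `clean_low_latency_buffer` are illustrative:

from typing import Optional


class MockBuffer:
    # Hypothetical stand-in for deep_ep.Buffer: only records clean calls.
    def __init__(self, low_latency_mode: bool):
        self.low_latency_mode = low_latency_mode
        self.cleaned = 0

    def clean_low_latency_buffer(self, num_tokens: int, hidden: int, experts: int):
        self.cleaned += 1


class SharedBuffer:
    # Mirrors the DeepEPBuffer pattern: one buffer plus a dispatch-mode latch.
    _buffer: Optional[MockBuffer] = None
    _mode: Optional[str] = None

    @classmethod
    def get(cls) -> MockBuffer:
        if cls._buffer is None:  # allocated exactly once, sized for all modes
            cls._buffer = MockBuffer(low_latency_mode=True)
        return cls._buffer

    @classmethod
    def set_mode(cls, mode: str) -> None:
        # Normal dispatch may leave the low-latency RDMA region dirty, so a
        # normal -> low_latency transition cleans it before reuse.
        if mode == "low_latency" and cls._mode == "normal":
            cls.get().clean_low_latency_buffer(128, 2048, 256)
        cls._mode = mode


SharedBuffer.set_mode("normal")       # prefill batch
SharedBuffer.set_mode("low_latency")  # decode batch: triggers one clean
assert SharedBuffer.get().cleaned == 1

Allocating once and keeping the buffer for the lifetime of the process is what lets `auto` mode alternate between prefill (normal) and decode (low-latency) batches without re-registering RDMA memory.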
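With `--deepep-mode` now defaulting to `auto`, the mode is resolved per batch at dispatch time via `self.deepep_mode.resolve(forward_mode)`. A self-contained sketch of the resolution policy stated in the help text; `Mode` is a simplified analog written for illustration, not the sglang `DeepEPMode` implementation:

from enum import Enum, auto


class Mode(Enum):
    NORMAL = auto()
    LOW_LATENCY = auto()
    AUTO = auto()

    def resolve(self, is_decode_batch: bool) -> "Mode":
        if self is not Mode.AUTO:  # fixed modes apply to every batch
            return self
        return Mode.LOW_LATENCY if is_decode_batch else Mode.NORMAL


assert Mode.AUTO.resolve(is_decode_batch=True) is Mode.LOW_LATENCY
assert Mode.AUTO.resolve(is_decode_batch=False) is Mode.NORMAL
assert Mode.NORMAL.resolve(is_decode_batch=True) is Mode.NORMAL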