diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 3d2aae8f2..3c96a6816 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -91,6 +91,7 @@ Please consult the documentation below to learn more about the parameters you ma * `enable_ep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for MoE models. * `ep_size`: The size of EP. Please shard the model weights with `tp_size=ep_size`, for detailed benchmarking refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). If not set, `ep_size` will be automatically set to `tp_size`. * `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for DeepSeek-V3 model based on deepseek-ai/DeepEP. +* `deepep_mode`: Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. ## Memory and scheduling diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index 30c9eb6a7..3ea6b4b2f 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -244,6 +244,148 @@ def silu_and_mul_triton_kernel( tl.store(down_input_ptr + offset, silu_mul_output, mask=mask) +# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py +@triton.jit +def _silu_and_mul_post_quant_kernel( + input_ptr, + stride_input_0, + stride_input_1, + stride_input_2, + output_ptr, + stride_output_0, + stride_output_1, + stride_output_2, + output_scale_ptr, + stride_output_scale_0, + stride_output_scale_1, + stride_output_scale_2, + masked_m_ptr, + size_n, + fp8_max, + fp8_min, + BLOCK_N: tl.constexpr, + NUM_STAGE: tl.constexpr, +): + expert_id = tl.program_id(2) + token_id = tl.program_id(1) + hidden_dim_block_index = tl.program_id(0) + + block_num_per_expert = tl.num_programs(1) + + token_num_cur_expert = tl.load(masked_m_ptr + expert_id) + + stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64) + stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64) + stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64) + stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64) + + offs_in_d = hidden_dim_block_index * BLOCK_N + tl.arange(0, BLOCK_N) + input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d + output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d + output_scale_offs = ( + output_scale_ptr + + expert_id * stride_output_scale_0 + + hidden_dim_block_index * stride_output_scale_2 + ) + + for token_index in tl.range( + token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE + ): + gate = tl.load( + input_ptr_offs + token_index * stride_input_1, + mask=offs_in_d < size_n, + other=0.0, + ).to(tl.float32) + up = tl.load( + input_ptr_offs + token_index * stride_input_1 + size_n, + mask=offs_in_d < size_n, + other=0.0, + ) + gate = gate / (1 + tl.exp(-gate)) + gate = gate.to(input_ptr.dtype.element_ty) + gate_up = up * gate + _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) + output_s = _absmax / fp8_max + output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to( + output_ptr.dtype.element_ty + ) + tl.store( + output_ptr_offs + token_index * stride_output_1, + output_q, + mask=offs_in_d < size_n, + ) + tl.store( + output_scale_offs + token_index * stride_output_scale_1, + 
output_s, + ) + + +def silu_and_mul_masked_post_quant_fwd( + input: torch.Tensor, + output: torch.Tensor, + output_scale: torch.Tensor, + quant_group_size: int, + masked_m: torch.Tensor, +): + """ + input shape [expert_num, token_num_padded, hidden_dim] + output shape [expert_num, token_num_padded, hidden_dim // 2], dtype fp8 + output_scale [expert_num token_num_paddded, hidden_dim // 2 // 128] dtype float32 + quant_group_size int, + masked_m shape [expert_num], + """ + + assert input.is_contiguous() + assert output.dtype == torch.float8_e4m3fn + assert output.is_contiguous() + assert len(input.shape) == 3 + assert input.shape[0] == masked_m.shape[0] + assert input.shape[-1] % 2 == 0 + + size_n = input.shape[-1] // 2 + assert size_n % quant_group_size == 0 + + expert_num = len(masked_m) + + if expert_num < 4: + BLOCK_NUM_PER_EXPERT = 64 + else: + BLOCK_NUM_PER_EXPERT = 32 + + BLOCK_N = quant_group_size + num_warps = 1 + NUM_STAGES = 6 + hidden_dim_split_block_num = triton.cdiv(size_n, BLOCK_N) + assert BLOCK_N % quant_group_size == 0 + + grid = ( + hidden_dim_split_block_num, + BLOCK_NUM_PER_EXPERT, + expert_num, + ) + + finfo = torch.finfo(torch.float8_e4m3fn) + fp8_max = finfo.max + fp8_min = -fp8_max + + _silu_and_mul_post_quant_kernel[grid]( + input, + *input.stride(), + output, + *output.stride(), + output_scale, + *output_scale.stride(), + masked_m, + size_n, + fp8_max, + fp8_min, + BLOCK_N=BLOCK_N, + NUM_STAGE=NUM_STAGES, + num_warps=num_warps, + ) + return + + @triton.jit def tanh(x): return 2 * tl.sigmoid(2 * x) - 1 diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index f0595bfb1..814dc469e 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -3,12 +3,16 @@ from typing import Callable, List, Optional, Tuple import torch -# TODO: use deep_gemm masked kernel after low latency dispatch -# import deep_gemm -# from deep_gemm import ( -# get_col_major_tma_aligned_tensor, -# m_grouped_gemm_fp8_fp8_bf16_nt_masked, -# ) +try: + from deep_gemm import ( + get_col_major_tma_aligned_tensor, + m_grouped_gemm_fp8_fp8_bf16_nt_masked, + ) + + use_deep_gemm = True +except ImportError: + use_deep_gemm = False + from torch.nn import Module from sglang.srt.custom_op import CustomOp @@ -22,6 +26,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( post_reorder_triton_kernel, pre_reorder_triton_kernel, run_moe_ep_preproess, + silu_and_mul_masked_post_quant_fwd, silu_and_mul_triton_kernel, ) from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported @@ -809,6 +814,7 @@ class DeepEPMoE(EPMoE): correction_bias: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, activation: str = "silu", + deepep_mode: str = "auto", ): super().__init__( num_experts, @@ -827,21 +833,41 @@ class DeepEPMoE(EPMoE): custom_routing_function, activation, ) + self.deepep_mode = deepep_mode + if self.deepep_mode in ["low_latency", "auto"]: + assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm" + self.w13_weight_fp8 = ( + self.w13_weight, + ( + self.w13_weight_scale_inv + if self.use_block_quant + else self.w13_weight_scale + ), + ) + self.w2_weight_fp8 = ( + self.w2_weight, + self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale, + ) def forward( self, hidden_states: torch.Tensor, reorder_topk_ids: torch.Tensor, seg_indptr: torch.Tensor, + masked_m: torch.Tensor, + expected_m: int, forward_mode: ForwardMode, ): - # Todo: 
use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode) - if True: # not forward_mode.is_decode(): + if self.deepep_mode == "normal" or ( + self.deepep_mode == "auto" and not forward_mode.is_decode() + ): return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr) + elif self.deepep_mode == "low_latency" or ( + self.deepep_mode == "auto" and forward_mode.is_decode() + ): + return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m) else: - return self.forward_deepgemm_masked( - hidden_states, reorder_topk_ids, seg_indptr - ) + raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") def forward_normal( self, @@ -958,89 +984,66 @@ class DeepEPMoE(EPMoE): def forward_deepgemm_masked( self, - hidden_states: torch.Tensor, - reorder_topk_ids: torch.Tensor, - seg_indptr: torch.Tensor, + hidden_states_fp8: Tuple[torch.Tensor, torch.Tensor], + masked_m: torch.Tensor, + expected_m: int, ): assert self.quant_method is not None assert self.activation == "silu" - - if self.activation_scheme == "dynamic" and not self.use_block_quant: - max_value = ( - torch.max(hidden_states) - .repeat(self.num_experts_per_partition) - .to(torch.float32) - ) - self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max + assert ( + hidden_states_fp8[0].size(0) % 4 == 0 + ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}" # GroupGemm-0 + num_groups, m, k = hidden_states_fp8[0].size() + n = self.w13_weight.size(1) + expected_m = min(expected_m, m) gateup_output = torch.empty( - hidden_states.shape[0], - self.w13_weight.shape[1], - device=hidden_states.device, - dtype=hidden_states.dtype, + (num_groups, m, n), device=hidden_states_fp8[0].device, dtype=torch.bfloat16 + ) + m_grouped_gemm_fp8_fp8_bf16_nt_masked( + hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m ) - if hidden_states.shape[0] > 0: - # Transpose earlier so that the testing will not trigger transposing kernels - hidden_states = ( - hidden_states[0], - get_col_major_tma_aligned_tensor(hidden_states[1]), - ) - """ - gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( - hidden_states, self.w13_weight, out, masked_m, expected_m - ) - """ # Act down_input = torch.empty( - gateup_output.shape[0], - gateup_output.shape[1] // 2, - device=gateup_output.device, - dtype=( - self.fp8_dtype - if (self.use_fp8_w8a8 and not self.use_block_quant) - else hidden_states.dtype - ), - ) - if self.w2_input_scale is None and not self.use_block_quant: - self.w2_input_scale = torch.ones( - self.num_experts_per_partition, - dtype=torch.float32, - device=hidden_states.device, - ) - - if self.activation == "silu": - silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( - gateup_output, - down_input, + ( + gateup_output.shape[0], gateup_output.shape[1], - reorder_topk_ids, - self.w2_input_scale, - 0, - self.num_experts_per_partition - 1, - BLOCK_SIZE=512, - ) - else: - raise ValueError(f"Unsupported activation: {self.activation=}") + gateup_output.shape[2] // 2, + ), + device=gateup_output.device, + dtype=self.fp8_dtype, + ) + scale_block_size = 128 + down_input_scale = torch.empty( + ( + gateup_output.shape[0], + gateup_output.shape[1], + gateup_output.shape[2] // 2 // scale_block_size, + ), + device=gateup_output.device, + dtype=torch.float32, + ) + silu_and_mul_masked_post_quant_fwd( + gateup_output, + down_input, + down_input_scale, + scale_block_size, + masked_m, + ) # GroupGemm-1 - down_output = torch.empty( - down_input.shape[0], - self.w2_weight.shape[1], - 
device=hidden_states.device, - dtype=hidden_states.dtype, + n = self.w2_weight.size(1) + down_input_fp8 = ( + down_input, + get_col_major_tma_aligned_tensor(down_input_scale), + ) + down_output = torch.empty( + (num_groups, m, n), device=down_input.device, dtype=torch.bfloat16 + ) + m_grouped_gemm_fp8_fp8_bf16_nt_masked( + down_input_fp8, self.w2_weight_fp8, down_output, masked_m, expected_m ) - if down_input.shape[0] > 0: - # Transpose earlier so that the testing will not trigger transposing kernels - down_input = ( - down_input[0], - get_col_major_tma_aligned_tensor(down_input[1]), - ) - """ - down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( - down_input, self.w2_weight, out, masked_m, expected_m - ) - """ return down_output diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index 6b67f6cea..f4e673535 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -76,8 +76,7 @@ def get_buffer_low_latency( assert num_experts % group.size() == 0 _buffer_low_latency = Buffer( group, - 0, - num_rdma_bytes, + num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_experts // group.size(), ) @@ -95,62 +94,63 @@ class DeepEPDispatcher: group: torch.distributed.ProcessGroup, router_topk: int, permute_fusion: bool = False, - capacity_factor: float = None, num_experts: int = None, num_local_experts: int = None, hidden_size: int = None, params_dtype: torch.dtype = None, + deepep_mode: str = "auto", async_finish: bool = False, + return_recv_hook: bool = False, ): - self.group = group - self.router_topk = router_topk - self.capacity_factor = capacity_factor - self.permute_fusion = permute_fusion - self.num_experts = num_experts - self.num_local_experts = num_local_experts - self.hidden_size = hidden_size - self.recv_expert_count = None - self.params_dtype = params_dtype - self.params_bytes = 2 - # Metadata - self.token_indices = None - self.token_probs = None - # Handle used for combine operation - self.handle = None - self.async_finish = async_finish - - # `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256 - # https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding - self.num_max_dispatch_tokens_per_rank = 128 - if not use_deepep: raise ImportError( "DeepEP is not installed. Please install DeepEP package from " "https://github.com/deepseek-ai/deepep." 
) - self.buffer_normal = get_buffer_normal( - self.group, self.hidden_size * self.params_bytes - ) - self.buffer_low_latency = None - # Todo: enable low latency dispatch - """ - self.buffer_low_latency = get_buffer_low_latency( - self.group, - self.num_max_dispatch_tokens_per_rank, - self.hidden_size * self.params_bytes, - self.num_experts, - ) - """ + + self.group = group + self.router_topk = router_topk + self.permute_fusion = permute_fusion + self.num_experts = num_experts + self.num_local_experts = num_local_experts + self.hidden_size = hidden_size + self.params_dtype = params_dtype + self.params_bytes = 2 + + self.deepep_mode = deepep_mode + self.handle = None + + if self.deepep_mode in ["normal", "auto"]: # for normal / auto mode + self.buffer_normal = get_buffer_normal( + self.group, self.hidden_size * self.params_bytes + ) + self.async_finish = async_finish + self.src2dst = None + if self.deepep_mode in ["low_latency", "auto"]: # for low_latency / auto mode + """ + num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256 + https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding + """ + # TODO(ch-wan): allow users to set this value + self.num_max_dispatch_tokens_per_rank = 128 + self.buffer_low_latency = get_buffer_low_latency( + self.group, + self.num_max_dispatch_tokens_per_rank, + self.hidden_size, + self.num_experts, + ) + self.return_recv_hook = return_recv_hook def deepep_permute( self, - hidden_states, - fp8_dtype=None, - use_fp8_w8a8=False, - use_block_quant=False, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + fp8_dtype: Optional[torch.dtype] = None, + use_fp8_w8a8: bool = False, + use_block_quant: bool = False, ): - reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess( - self.topk_idx, self.num_experts + reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess( + topk_idx, self.num_experts ) num_total_tokens = reorder_topk_ids.numel() gateup_input = torch.empty( @@ -166,14 +166,13 @@ class DeepEPDispatcher: deepep_permute_triton_kernel[(hidden_states.shape[0],)]( hidden_states, gateup_input, - src2dst, - self.topk_idx, + self.src2dst, + topk_idx, None, self.router_topk, hidden_states.shape[1], BLOCK_SIZE=512, ) - self.src2dst = src2dst return reorder_topk_ids, seg_indptr, gateup_input def dispatch( @@ -182,54 +181,64 @@ class DeepEPDispatcher: topk_idx: torch.Tensor, topk_weights: torch.Tensor, num_experts: int, - forward_mode: ForwardMode, num_max_dispatch_tokens_per_rank: int = 128, - ) -> Tuple[torch.Tensor, torch.Tensor]: + forward_mode: ForwardMode = None, + ) -> Tuple: topk_idx = topk_idx.to(torch.int64) - # Todo: enable low latency dispatch - if True: # not forward_mode.is_decode(): + reorder_topk_ids = torch.empty( + (0,), device=hidden_states.device, dtype=torch.int64 + ) + seg_indptr = torch.zeros( + (num_experts + 1,), device=hidden_states.device, dtype=torch.int64 + ) + masked_m = torch.empty( + (self.num_local_experts,), device=hidden_states.device, dtype=torch.int64 + ) + expected_m = 0 + + if self.deepep_mode == "normal" or ( + self.deepep_mode == "auto" and not forward_mode.is_decode() + ): ( hidden_states, topk_idx, topk_weights, - num_recv_tokens_per_expert_list, - handle, event, ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts) - self.tokens_per_expert = torch.tensor( - num_recv_tokens_per_expert_list, - device=hidden_states.device, - dtype=torch.int64, - ) - else: - hidden_states, recv_expert_count, 
handle, event, hook = ( - self.dispatch_low_latency( - hidden_states, - topk_idx, - num_max_dispatch_tokens_per_rank, - num_experts, + event.current_stream_wait() if self.async_finish else () + if hidden_states.shape[0] > 0: + reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute( + hidden_states, topk_idx, fp8_dtype=hidden_states.dtype ) + elif self.deepep_mode == "low_latency" or ( + self.deepep_mode == "auto" and forward_mode.is_decode() + ): + expected_m = ( + hidden_states.shape[0] + * self.buffer_low_latency.group_size + * topk_idx.shape[1] + + num_experts + ) // num_experts + hidden_states, masked_m, event, hook = self.dispatch_low_latency( + hidden_states, + topk_idx, + num_max_dispatch_tokens_per_rank, + num_experts, + use_fp8=True, ) - self.recv_expert_count = recv_expert_count - - if self.async_finish: - event.current_stream_wait() - - self.handle = handle - self.topk_idx = topk_idx - self.topk_weights = topk_weights - if hidden_states.shape[0] > 0: - reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute( - hidden_states, fp8_dtype=hidden_states.dtype - ) + hook() if self.return_recv_hook else event.current_stream_wait() else: - reorder_topk_ids = torch.empty( - (0,), device=hidden_states.device, dtype=torch.int64 - ) - seg_indptr = torch.zeros( - (num_experts + 1,), device=hidden_states.device, dtype=torch.int64 - ) - return hidden_states, reorder_topk_ids, seg_indptr + raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") + + return ( + hidden_states, + topk_idx, + topk_weights, + reorder_topk_ids, + seg_indptr, + masked_m, + expected_m, + ) def dispatch_normal( self, @@ -254,12 +263,15 @@ class DeepEPDispatcher: allocate_on_comm_stream=previous_event is not None, ) + # FIXME: `handle` should be transmitted with tokens from dispatch to combine. + # However, doing this would incur an unknown synchronization error, but keeping + # `handle` as a member variable works. ( recv_x, recv_topk_idx, recv_topk_weights, - num_recv_tokens_per_expert_list, - handle, + _, # num_recv_tokens_per_expert_list + self.handle, event, ) = self.buffer_normal.dispatch( x, @@ -278,8 +290,6 @@ class DeepEPDispatcher: recv_x, recv_topk_idx, recv_topk_weights, - num_recv_tokens_per_expert_list, - handle, event, ) @@ -289,18 +299,19 @@ class DeepEPDispatcher: topk_idx: torch.Tensor, num_max_dispatch_tokens_per_rank: int, num_experts: int, + use_fp8: bool = False, ): """ - # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch' - # Please please make sure to change DeepEP code in internode_ll.cu dispatch / combine first and then reinstall! + # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'. + # Please make sure to change DeepEP code in internode_ll.cu dispatch / combine as below first and then reinstall. 
# More details refer: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782 - + + diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu - index f60e933..cddaabf 100644 + index 76ae2e2..8ecd08f 100644 --- a/csrc/kernels/internode_ll.cu +++ b/csrc/kernels/internode_ll.cu - @@ -307,14 +307,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, - int num_topk, int num_experts, int rank, int num_ranks, + @@ -310,8 +310,8 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, + int num_topk, int num_experts, int rank, int num_ranks, bool use_fp8, void* workspace, cudaStream_t stream, int phases) { constexpr int kNumMaxTopK = 9; - constexpr int kNumWarpsPerGroup = 10; @@ -308,16 +319,9 @@ class DeepEPDispatcher: + constexpr int kNumWarpsPerGroup = 8; + constexpr int kNumWarpGroups = 4; EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections"); - + + const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; - const auto num_sms = cell_div(num_experts, kNumWarpGroups); - EP_HOST_ASSERT(num_topk <= kNumMaxTopK); - - EP_HOST_ASSERT(cell_div(static_cast(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2); - + // EP_HOST_ASSERT(cell_div(static_cast(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2); - + - // Workspace checks - auto atomic_counter_per_expert = reinterpret_cast(workspace); - @@ -505,8 +505,8 @@ void combine(void* combined_x, + @@ -501,8 +501,8 @@ void combine(void* combined_x, int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, int num_topk, int num_experts, int rank, int num_ranks, void* workspace, cudaStream_t stream, int phases) { @@ -326,28 +330,33 @@ class DeepEPDispatcher: + constexpr int kNumWarpsPerGroup = 8; + constexpr int kNumWarpGroups = 4; constexpr int kNumMaxTopk = 9; - + + const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; """ - recv_hidden_states, recv_expert_count, handle, event, hook = ( + packed_recv_hidden, packed_recv_count, self.handle, event, hook = ( self.buffer_low_latency.low_latency_dispatch( hidden_states, topk_idx, num_max_dispatch_tokens_per_rank, num_experts, - async_finish=self.async_finish, - return_recv_hook=False, # True for double-batch overlapping, need call hook() + use_fp8=use_fp8, + async_finish=not self.return_recv_hook, + return_recv_hook=self.return_recv_hook, ) ) - # hook() - return recv_hidden_states, recv_expert_count, handle, event, hook + return packed_recv_hidden, packed_recv_count, event, hook def combine( - self, hidden_states: torch.Tensor, forward_mode: ForwardMode - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # Todo: enable low latency combine - if True: # not forward_mode.is_decode(): + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + forward_mode: ForwardMode, + ) -> torch.Tensor: + if self.deepep_mode == "normal" or ( + self.deepep_mode == "auto" and not forward_mode.is_decode() + ): if hidden_states.shape[0] > 0: num_tokens = self.src2dst.shape[0] // self.router_topk output = torch.empty( @@ -359,8 +368,8 @@ class DeepEPDispatcher: hidden_states, output, self.src2dst, - self.topk_idx, - self.topk_weights, + topk_idx, + topk_weights, self.router_topk, hidden_states.shape[1], BLOCK_SIZE=512, @@ -371,24 +380,30 @@ class DeepEPDispatcher: device=hidden_states.device, dtype=hidden_states.dtype, ) - hidden_states, event = self.combine_normal(output, self.handle) - else: - hidden_states, event, hook = self.combine_low_latency( - 
hidden_states, self.topk_idx, self.topk_weights, self.handle + hidden_states, event = self.combine_normal( + output, ) + event.current_stream_wait() if self.async_finish else () + elif self.deepep_mode == "low_latency" or ( + self.deepep_mode == "auto" and forward_mode.is_decode() + ): + hidden_states, event, hook = self.combine_low_latency( + hidden_states, + topk_idx, + topk_weights, + ) + hook() if self.return_recv_hook else event.current_stream_wait() + else: + raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") - if self.async_finish: - event.current_stream_wait() - - self.handle = None return hidden_states - def combine_normal(self, x: torch.Tensor, handle: Tuple): + def combine_normal(self, x: torch.Tensor): previous_event = Buffer.capture() if self.async_finish else None combined_x, _, event = self.buffer_normal.combine( x, - handle, + self.handle, async_finish=self.async_finish, previous_event=previous_event, allocate_on_comm_stream=previous_event is not None, @@ -400,17 +415,15 @@ class DeepEPDispatcher: hidden_states: torch.Tensor, topk_idx: torch.Tensor, topk_weights: torch.Tensor, - handle: Tuple, ): - combined_hidden_states, event_overlap, hook = ( + combined_hidden_states, event, hook = ( self.buffer_low_latency.low_latency_combine( hidden_states, topk_idx, topk_weights, - handle, - async_finish=self.async_finish, - return_recv_hook=False, # True for double-batch overlapping, need call hook() + self.handle, + async_finish=not self.return_recv_hook, + return_recv_hook=self.return_recv_hook, ) ) - # hook() - return combined_hidden_states, event_overlap, hook + return combined_hidden_states, event, hook diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ab8b81602..991ec0551 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -72,6 +72,7 @@ global_server_args_dict = { "enable_dp_attention": ServerArgs.enable_dp_attention, "enable_ep_moe": ServerArgs.enable_ep_moe, "enable_deepep_moe": ServerArgs.enable_deepep_moe, + "deepep_mode": ServerArgs.deepep_mode, "device": ServerArgs.device, "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single, "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f5405c9af..f42ea02d5 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -147,6 +147,7 @@ class ModelRunner: "enable_dp_attention": server_args.enable_dp_attention, "enable_ep_moe": server_args.enable_ep_moe, "enable_deepep_moe": server_args.enable_deepep_moe, + "deepep_mode": server_args.deepep_mode, "device": server_args.device, "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single, "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc, @@ -272,7 +273,7 @@ class ModelRunner: server_args.disable_radix_cache = True if server_args.enable_deepep_moe: - logger.info("DeepEP is turned on.") + logger.info(f"DeepEP is turned on. 
DeepEP mode: {server_args.deepep_mode}") def init_torch_distributed(self): logger.info("Init torch distributed begin.") diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 37760407b..6aaa3744a 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -188,19 +188,35 @@ class DeepseekV2MoE(nn.Module): if global_server_args_dict["enable_deepep_moe"] else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE) ) - self.experts = MoEImpl( - num_experts=config.n_routed_experts, - top_k=config.num_experts_per_tok, - hidden_size=config.hidden_size, - intermediate_size=config.moe_intermediate_size, - renormalize=config.norm_topk_prob, - quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - correction_bias=self.gate.e_score_correction_bias, - prefix=add_prefix("experts", prefix), - ) + if not global_server_args_dict["enable_deepep_moe"]: + self.experts = MoEImpl( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + correction_bias=self.gate.e_score_correction_bias, + prefix=add_prefix("experts", prefix), + ) + else: + self.experts = MoEImpl( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + correction_bias=self.gate.e_score_correction_bias, + prefix=add_prefix("experts", prefix), + deepep_mode=global_server_args_dict["deepep_mode"], + ) if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts @@ -227,6 +243,8 @@ class DeepseekV2MoE(nn.Module): ) if global_server_args_dict["enable_deepep_moe"]: + # TODO: we will support tp < ep in the future + self.ep_size = get_tensor_model_parallel_world_size() self.num_experts = config.n_routed_experts self.top_k = config.num_experts_per_tok self.renormalize = config.norm_topk_prob @@ -246,7 +264,9 @@ class DeepseekV2MoE(nn.Module): num_local_experts=config.n_routed_experts // self.tp_size, hidden_size=config.hidden_size, params_dtype=config.torch_dtype, + deepep_mode=global_server_args_dict["deepep_mode"], async_finish=True, # TODO + return_recv_hook=True, ) def forward( @@ -301,28 +321,39 @@ class DeepseekV2MoE(nn.Module): num_expert_group=self.num_expert_group, correction_bias=self.correction_bias, ) - if self.tp_size > 1: - recv_hidden_states, reorder_topk_ids, seg_indptr = ( - self.deepep_dispatcher.dispatch( - hidden_states, - topk_idx, - topk_weights, - self.num_experts, - forward_mode, - ) + if self.ep_size > 1: + ( + hidden_states, + topk_idx, + topk_weights, + reorder_topk_ids, + seg_indptr, + masked_m, + expected_m, + ) = self.deepep_dispatcher.dispatch( + hidden_states, + topk_idx, + topk_weights, + self.num_experts, + forward_mode=forward_mode, ) final_hidden_states = ( self.experts( - hidden_states=recv_hidden_states, + hidden_states=hidden_states, reorder_topk_ids=reorder_topk_ids, seg_indptr=seg_indptr, + masked_m=masked_m, + expected_m=expected_m, forward_mode=forward_mode, ) * 
self.routed_scaling_factor ) - if self.tp_size > 1: + if self.ep_size > 1: final_hidden_states = self.deepep_dispatcher.combine( - final_hidden_states, forward_mode + final_hidden_states, + topk_idx, + topk_weights, + forward_mode, ) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 6f4725487..1a19bbea2 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -161,6 +161,7 @@ class ServerArgs: enable_dp_attention: bool = False enable_ep_moe: bool = False enable_deepep_moe: bool = False + deepep_mode: Optional[str] = "auto" enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None @@ -285,6 +286,13 @@ class ServerArgs: if self.grammar_backend is None: self.grammar_backend = "xgrammar" + # Expert parallelism + if self.enable_ep_moe: + self.ep_size = self.tp_size + logger.info( + f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]." + ) + # Data parallelism attention if self.enable_dp_attention: self.schedule_conservativeness = self.schedule_conservativeness * 0.3 @@ -300,6 +308,10 @@ class ServerArgs: self.enable_sp_layernorm = False # DeepEP MoE if self.enable_deepep_moe: + if self.deepep_mode == "auto": + assert ( + not self.enable_dp_attention + ), "DeepEP MoE `auto` mode is not supported with DP Attention." self.ep_size = self.tp_size self.enable_sp_layernorm = ( self.dp_size < self.tp_size if self.enable_dp_attention else True @@ -1082,6 +1094,12 @@ class ServerArgs: action="store_true", help="Enabling DeepEP MoE implementation for EP MoE.", ) + parser.add_argument( + "--deepep-mode", + type=str, + choices=["normal", "low_latency", "auto"], + help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.", + ) # Server warmups parser.add_argument(
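
Supplementary notes on the behavior introduced above; the snippets below are illustrative sketches, not part of the patch.

The `auto` mode added to `DeepEPMoE.forward` and `DeepEPDispatcher.dispatch`/`combine` reduces to one decision rule: prefill batches take the normal path (Triton permute plus contiguous grouped GEMM), decode batches take the low-latency path (masked dispatch plus DeepGEMM masked grouped GEMM), and the explicit `normal`/`low_latency` settings force one path for every batch. A minimal sketch of that rule follows; `resolve_deepep_path` is a hypothetical helper written for illustration and does not exist in the diff, where the branching is inlined.

```python
# Sketch of the deepep_mode decision rule shared by DeepEPMoE.forward and
# DeepEPDispatcher.dispatch/combine. `resolve_deepep_path` is a hypothetical
# helper for illustration; the real code inlines this branching.
def resolve_deepep_path(deepep_mode: str, is_decode_batch: bool) -> str:
    if deepep_mode == "normal" or (deepep_mode == "auto" and not is_decode_batch):
        # Normal path: Triton permute + contiguous grouped GEMM (prefill-friendly).
        return "normal"
    if deepep_mode == "low_latency" or (deepep_mode == "auto" and is_decode_batch):
        # Low-latency path: masked dispatch + m_grouped_gemm_fp8_fp8_bf16_nt_masked.
        return "low_latency"
    raise ValueError(f"Invalid deepep_mode: {deepep_mode}")

# `auto` (the default) picks the path per batch.
assert resolve_deepep_path("auto", is_decode_batch=False) == "normal"
assert resolve_deepep_path("auto", is_decode_batch=True) == "low_latency"
```

At launch time the mode is selected with the new flag, e.g. `--enable-deepep-moe --deepep-mode auto`.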
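
In the low-latency path, `expected_m` estimates how many valid tokens each local expert will receive after dispatch. The diff computes it as `(tokens * group_size * topk + num_experts) // num_experts`, i.e. the floor of the per-expert average plus one, and `forward_deepgemm_masked` later clamps it with `min(expected_m, m)`. A small numeric sketch with made-up sizes (none of these numbers come from the PR):

```python
# Illustrative only: the formula matches DeepEPDispatcher.dispatch in this PR,
# but the concrete sizes are invented for the example.
num_local_tokens = 8    # hidden_states.shape[0] on this rank (decode batch)
group_size = 16         # buffer_low_latency.group_size (EP group size)
topk = 8                # topk_idx.shape[1], experts selected per token
num_experts = 256       # total routed experts

# Floor of the average routed assignments per expert, plus one.
expected_m = (num_local_tokens * group_size * topk + num_experts) // num_experts
print(expected_m)  # 5; forward_deepgemm_masked then applies min(expected_m, m)
```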
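
The new `silu_and_mul_masked_post_quant_fwd` wrapper fixes the buffer layout between the two grouped GEMMs: a bf16 gate/up projection per expert goes in, and an FP8 activation plus per-128-channel float32 scales come out, matching what `forward_deepgemm_masked` allocates. The shapes below restate the wrapper's docstring with made-up sizes; the actual call needs a CUDA device and the Triton kernel above, so it is shown commented out.

```python
import torch

# Illustrative shapes only (E, T, H are invented); dtypes and layouts follow the
# silu_and_mul_masked_post_quant_fwd docstring in this PR.
E, T, H = 4, 128, 1024   # experts, padded tokens per expert, 2 * intermediate size
group = 128              # quant_group_size (one FP8 scale per 128 channels)

gateup_output = torch.empty((E, T, H), dtype=torch.bfloat16)           # kernel input
down_input = torch.empty((E, T, H // 2), dtype=torch.float8_e4m3fn)    # quantized output
down_input_scale = torch.empty((E, T, H // 2 // group), dtype=torch.float32)
masked_m = torch.full((E,), 96, dtype=torch.int64)                     # valid tokens per expert

# On a CUDA device, the diff invokes:
# silu_and_mul_masked_post_quant_fwd(gateup_output, down_input,
#                                    down_input_scale, group, masked_m)
```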