From f44db16c8e0fbee1b964e802f1ab493afb6f7996 Mon Sep 17 00:00:00 2001 From: Jinyan Chen <93358689+liz-badada@users.noreply.github.com> Date: Wed, 19 Mar 2025 23:16:31 +0800 Subject: [PATCH] [Feature] Integrate DeepEP into SGLang (#4232) Co-authored-by: Cheng Wan Co-authored-by: Xuting Zhou --- docker/Dockerfile.deepep | 77 +++ docs/backend/server_arguments.md | 1 + python/sglang/srt/layers/linear.py | 15 +- .../sglang/srt/layers/moe/ep_moe/kernels.py | 123 +++- python/sglang/srt/layers/moe/ep_moe/layer.py | 274 +++++++++ .../srt/layers/moe/ep_moe/token_dispatcher.py | 533 ++++++++++++++++++ python/sglang/srt/layers/parameter.py | 2 +- python/sglang/srt/managers/schedule_batch.py | 1 + .../sglang/srt/model_executor/model_runner.py | 7 + python/sglang/srt/models/deepseek_v2.py | 165 +++++- python/sglang/srt/server_args.py | 12 + test/srt/test_moe_deepep.py | 53 ++ 12 files changed, 1228 insertions(+), 35 deletions(-) create mode 100644 docker/Dockerfile.deepep create mode 100644 python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py mode change 100755 => 100644 python/sglang/srt/models/deepseek_v2.py create mode 100644 test/srt/test_moe_deepep.py diff --git a/docker/Dockerfile.deepep b/docker/Dockerfile.deepep new file mode 100644 index 000000000..bd32e37d4 --- /dev/null +++ b/docker/Dockerfile.deepep @@ -0,0 +1,77 @@ +FROM lmsysorg/sglang:latest + +# CMake +RUN apt-get update \ +&& apt-get install -y --no-install-recommends \ +build-essential \ +wget \ +libssl-dev \ +&& wget https://github.com/Kitware/CMake/releases/download/v3.27.4/cmake-3.27.4-linux-x86_64.sh \ +&& chmod +x cmake-3.27.4-linux-x86_64.sh \ +&& ./cmake-3.27.4-linux-x86_64.sh --skip-license --prefix=/usr/local \ +&& rm cmake-3.27.4-linux-x86_64.sh + +# Python +RUN apt-get update \ + && apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + && ln -s /usr/bin/python3 /usr/bin/python + +# GDRCopy +WORKDIR /tmp +RUN git clone https://github.com/NVIDIA/gdrcopy.git +WORKDIR /tmp/gdrcopy +RUN git checkout v2.4.4 + +RUN apt update +RUN apt install -y nvidia-dkms-535 +RUN apt install -y build-essential devscripts debhelper fakeroot pkg-config dkms +RUN apt install -y check libsubunit0 libsubunit-dev + +WORKDIR /tmp/gdrcopy/packages +RUN CUDA=/usr/local/cuda ./build-deb-packages.sh +RUN dpkg -i gdrdrv-dkms_*.deb +RUN dpkg -i libgdrapi_*.deb +RUN dpkg -i gdrcopy-tests_*.deb +RUN dpkg -i gdrcopy_*.deb + +ENV GDRCOPY_HOME=/usr/src/gdrdrv-2.4.4/ + +# IBGDA dependency +RUN ln -s /usr/lib/x86_64-linux-gnu/libmlx5.so.1 /usr/lib/x86_64-linux-gnu/libmlx5.so +RUN apt-get install -y libfabric-dev + +# DeepEP +WORKDIR /sgl-workspace +RUN git clone https://github.com/deepseek-ai/DeepEP.git + +# NVSHMEM +WORKDIR /sgl-workspace +RUN wget https://developer.download.nvidia.com/compute/redist/nvshmem/3.2.5/source/nvshmem_src_3.2.5-1.txz +RUN tar -xf nvshmem_src_3.2.5-1.txz \ + && mv nvshmem_src nvshmem + +WORKDIR /sgl-workspace/nvshmem +RUN git apply /sgl-workspace/DeepEP/third-party/nvshmem.patch + +WORKDIR /sgl-workspace/nvshmem +ENV CUDA_HOME=/usr/local/cuda +RUN NVSHMEM_SHMEM_SUPPORT=0 \ + NVSHMEM_UCX_SUPPORT=0 \ + NVSHMEM_USE_NCCL=0 \ + NVSHMEM_MPI_SUPPORT=0 \ + NVSHMEM_IBGDA_SUPPORT=1 \ + NVSHMEM_PMIX_SUPPORT=0 \ + NVSHMEM_TIMEOUT_DEVICE_POLLING=0 \ + NVSHMEM_USE_GDRCOPY=1 \ + cmake -S . 
-B build/ -DCMAKE_INSTALL_PREFIX=/sgl-workspace/nvshmem/install -DCMAKE_CUDA_ARCHITECTURES=90 \ + && cd build \ + && make install -j + +WORKDIR /sgl-workspace/DeepEP +ENV NVSHMEM_DIR=/sgl-workspace/nvshmem/install +RUN NVSHMEM_DIR=/sgl-workspace/nvshmem/install python setup.py install + +# Set workspace +WORKDIR /sgl-workspace diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 1749a6e72..c2e81eafe 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -89,6 +89,7 @@ Please consult the documentation below to learn more about the parameters you ma ### Expert parallelism * `enable_ep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for MoE models. * `ep_size`: The size of EP. Please shard the model weights with `tp_size=ep_size`, for detailed benchmarking refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). If not set, `ep_size` will be automatically set to `tp_size`. +* `enable_deepep_moe`: Enables expert parallelism that distributes the experts onto multiple GPUs for the DeepSeek-V3 model, based on deepseek-ai/DeepEP. Currently DeepEP is bound to DP attention, so please set `--enable-dp-attention --enable-deepep-moe`; `tp_size=dp_size=ep_size` is preferred. ## Memory and scheduling diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index a5bb6281d..32bcf1572 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -687,10 +687,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ): if loaded_shard_id is None: if isinstance(param, PerTensorScaleParameter): - param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) + param.load_merged_column_weight( + loaded_weight=loaded_weight, + shard_id=0, + tp_rank=self.tp_rank, + tp_size=self.tp_size, + ) return elif type(param) in (RowvLLMParameter, BasevLLMParameter): - param.load_merged_column_weight(loaded_weight=loaded_weight) + param.load_merged_column_weight( + loaded_weight=loaded_weight, + tp_rank=self.tp_rank, + tp_size=self.tp_size, + ) return # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) @@ -719,6 +728,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): shard_offset=shard_offset, shard_size=shard_size, use_presharded_weights=self.use_presharded_weights, + tp_rank=self.tp_rank, + tp_size=self.tp_size, ) diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index b455c05c3..6d6c432f8 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -5,6 +5,7 @@ import torch import triton import triton.language as tl +from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.utils import is_cuda @@ -16,6 +17,117 @@ if _is_cuda: logger = logging.getLogger(__name__) +@triton.jit +def compute_src2dst_triton_kernel( + reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(axis=0) + dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = dst_id < num_toks + src_id = tl.load(reorder_ids + dst_id, mask=mask) + tl.store(src2dst + src_id, dst_id, mask=mask) + + +@triton.jit +def deepep_compute_src2dst_triton_kernel( + reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(axis=0) + dst_id = pid * BLOCK_SIZE + 
tl.arange(0, BLOCK_SIZE) + mask = dst_id < num_toks + src_id = tl.load(reorder_ids + dst_id, mask=mask) + num_invalid = tl.load(num_minus_one) + tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask) + + +def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int): + reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True) + seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64) + src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32) + + # Find offet + expert_ids = torch.arange( + num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype + ) + torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr) + num_minus_one = seg_indptr[0] + seg_indptr = seg_indptr - num_minus_one + + BLOCK_SIZE = 512 + grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),) + deepep_compute_src2dst_triton_kernel[grid]( + reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE + ) + + reorder_topk_ids = reorder_topk_ids[num_minus_one:] + return reorder_topk_ids, src2dst, seg_indptr + + +@triton.jit +def deepep_permute_triton_kernel( + input_ptr, + gateup_input_ptr, + src2dst_ptr, + topk_ids_ptr, + a1_scales_ptr, + topk, + hidden_size, + BLOCK_SIZE: tl.constexpr, +): + OutDtype = gateup_input_ptr.dtype.element_ty + + src_idx = tl.program_id(0) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + + src_ptr = input_ptr + src_idx * hidden_size + + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32) + + for idx in range(topk): + dst_idx = tl.load(src2dst_ptr + idx) + if dst_idx >= 0: + dst_ptr = gateup_input_ptr + dst_idx * hidden_size + out_data = (in_data).to(OutDtype) + tl.store(dst_ptr + offset, out_data, mask=mask) + + +@triton.jit +def deepep_post_reorder_triton_kernel( + down_output_ptr, + output_ptr, + src2dst_ptr, + topk_ids_ptr, + topk_weights_ptr, + topk, + hidden_size, + BLOCK_SIZE: tl.constexpr, +): + InDtype = down_output_ptr.dtype.element_ty + + src_idx = tl.program_id(0) + src2dst_ptr = src2dst_ptr + src_idx * topk + topk_ids_ptr = topk_ids_ptr + src_idx * topk + topk_weights_ptr = topk_weights_ptr + src_idx * topk + + store_ptr = output_ptr + src_idx * hidden_size + for start_offset in tl.range(0, hidden_size, BLOCK_SIZE): + offset = start_offset + tl.arange(0, BLOCK_SIZE) + mask = offset < hidden_size + sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype) + for idx in range(topk): + dst_idx = tl.load(src2dst_ptr + idx) + if dst_idx >= 0: + weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype) + load_ptr = down_output_ptr + dst_idx * hidden_size + in_data = tl.load(load_ptr + offset, mask=mask) + sum_vec += in_data * weigh_scale + tl.store(store_ptr + offset, sum_vec, mask=mask) + + @triton.jit def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks): expert = tl.program_id(0) @@ -33,17 +145,6 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks): tl.store(seg_indptr + expert + 1, target_location + 1) -@triton.jit -def compute_src2dst_triton_kernel( - reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr -): - pid = tl.program_id(axis=0) - dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = dst_id < num_toks - src_id = tl.load(reorder_ids + dst_id, mask=mask) - tl.store(src2dst + src_id, dst_id, mask=mask) - - def 
run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int): reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True) seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64) diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index e862660d7..a9b443a75 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -2,6 +2,13 @@ import logging from typing import Callable, List, Optional, Tuple import torch + +# TODO: use deep_gemm masked kernel after low latency dispatch +# import deep_gemm +# from deep_gemm import ( +# get_col_major_tma_aligned_tensor, +# m_grouped_gemm_fp8_fp8_bf16_nt_masked, +# ) from torch.nn import Module from sglang.srt.custom_op import CustomOp @@ -25,6 +32,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod +from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs _is_cuda = is_cuda() @@ -39,6 +47,8 @@ logger = logging.getLogger(__name__) _is_hip = is_hip() +_buffer = None + class GroupedGemmRunner(torch.nn.Module): flashinfer_gemm_warpper = None @@ -773,3 +783,267 @@ class Fp8EPMoEMethod(Fp8MoEMethod): custom_routing_function: Optional[Callable] = None, ) -> torch.Tensor: raise NotImplementedError + + +class DeepEPMoE(EPMoE): + """ + MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main) + """ + + _has_printed = False + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = "", + correction_bias: Optional[torch.Tensor] = None, + custom_routing_function: Optional[Callable] = None, + activation: str = "silu", + ): + super().__init__( + num_experts, + top_k, + hidden_size, + intermediate_size, + params_dtype, + renormalize, + use_grouped_topk, + num_expert_group, + topk_group, + quant_config, + tp_size, + prefix, + correction_bias, + custom_routing_function, + activation, + ) + + def forward( + self, + hidden_states: torch.Tensor, + tokens_per_expert: torch.Tensor, + forward_mode: ForwardMode, + ): + # Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode) + if True: # not forward_mode.is_decode(): + return self.forward_normal(hidden_states, tokens_per_expert) + else: + return self.forward_deepgemm_masked(hidden_states, tokens_per_expert) + + def forward_normal( + self, + hidden_states: torch.Tensor, + tokens_per_expert: torch.Tensor, + ): + assert self.quant_method is not None + assert self.activation == "silu" + if self.grouped_gemm_runner is None: + self.grouped_gemm_runner = GroupedGemmRunner( + hidden_states.device, use_flashinfer=False # TODO: use flashinfer + ) + seg_indptr_cur_rank = torch.cat( + [ + torch.zeros( + 1, device=tokens_per_expert.device, dtype=tokens_per_expert.dtype + ), + torch.cumsum(tokens_per_expert, dim=0), + ] + ) + reorder_topk_ids = torch.repeat_interleave(tokens_per_expert) + if self.activation_scheme == "dynamic" and not self.use_block_quant: + max_value = ( + torch.max(hidden_states) + 
.repeat(self.num_experts_per_partition) + .to(torch.float32) + ) + self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max + weight_indices_cur_rank = torch.arange( + 0, + self.num_experts_per_partition, + device=hidden_states.device, + dtype=torch.int64, + ) + + # GroupGemm-0 + gateup_output = torch.empty( + hidden_states.shape[0], + self.w13_weight.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + if hidden_states.shape[0] > 0: + gateup_output = self.grouped_gemm_runner( + a=hidden_states, + b=self.w13_weight, + c=gateup_output, + batch_size=self.num_experts_per_partition, + weight_column_major=True, + seg_indptr=seg_indptr_cur_rank, + weight_indices=weight_indices_cur_rank, + use_fp8_w8a8=self.use_fp8_w8a8, + scale_a=self.w13_input_scale, + scale_b=( + self.w13_weight_scale_inv + if self.use_block_quant + else self.w13_weight_scale + ), + block_shape=self.block_shape, + ) + + # Act + down_input = torch.empty( + gateup_output.shape[0], + gateup_output.shape[1] // 2, + device=gateup_output.device, + dtype=( + self.fp8_dtype + if (self.use_fp8_w8a8 and not self.use_block_quant) + else hidden_states.dtype + ), + ) + if self.w2_input_scale is None and not self.use_block_quant: + self.w2_input_scale = torch.ones( + self.num_experts_per_partition, + dtype=torch.float32, + device=hidden_states.device, + ) + + if self.activation == "silu": + silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( + gateup_output, + down_input, + gateup_output.shape[1], + reorder_topk_ids, + self.w2_input_scale, + 0, + self.num_experts_per_partition - 1, + BLOCK_SIZE=512, + ) + else: + raise ValueError(f"Unsupported activation: {self.activation=}") + + # GroupGemm-1 + down_output = torch.empty( + down_input.shape[0], + self.w2_weight.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + if down_input.shape[0] > 0: + down_output = self.grouped_gemm_runner( + a=down_input, + b=self.w2_weight, + c=down_output, + batch_size=self.num_experts_per_partition, + weight_column_major=True, + seg_indptr=seg_indptr_cur_rank, + weight_indices=weight_indices_cur_rank, + use_fp8_w8a8=self.use_fp8_w8a8, + scale_a=self.w2_input_scale, + scale_b=( + self.w2_weight_scale_inv + if self.use_block_quant + else self.w2_weight_scale + ), + block_shape=self.block_shape, + ) + return down_output + + def forward_deepgemm_masked( + self, + hidden_states: torch.Tensor, + reorder_topk_ids: torch.Tensor, + seg_indptr: torch.Tensor, + ): + assert self.quant_method is not None + assert self.activation == "silu" + + if self.activation_scheme == "dynamic" and not self.use_block_quant: + max_value = ( + torch.max(hidden_states) + .repeat(self.num_experts_per_partition) + .to(torch.float32) + ) + self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max + + # GroupGemm-0 + gateup_output = torch.empty( + hidden_states.shape[0], + self.w13_weight.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + if hidden_states.shape[0] > 0: + # Transpose earlier so that the testing will not trigger transposing kernels + hidden_states = ( + hidden_states[0], + get_col_major_tma_aligned_tensor(hidden_states[1]), + ) + """ + gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( + hidden_states, self.w13_weight, out, masked_m, expected_m + ) + """ + + # Act + down_input = torch.empty( + gateup_output.shape[0], + gateup_output.shape[1] // 2, + device=gateup_output.device, + dtype=( + self.fp8_dtype + if (self.use_fp8_w8a8 and not self.use_block_quant) + else 
hidden_states.dtype + ), + ) + if self.w2_input_scale is None and not self.use_block_quant: + self.w2_input_scale = torch.ones( + self.num_experts_per_partition, + dtype=torch.float32, + device=hidden_states.device, + ) + + if self.activation == "silu": + silu_and_mul_triton_kernel[(gateup_output.shape[0],)]( + gateup_output, + down_input, + gateup_output.shape[1], + reorder_topk_ids, + self.w2_input_scale, + 0, + self.num_experts_per_partition - 1, + BLOCK_SIZE=512, + ) + else: + raise ValueError(f"Unsupported activation: {self.activation=}") + + # GroupGemm-1 + down_output = torch.empty( + down_input.shape[0], + self.w2_weight.shape[1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + if down_input.shape[0] > 0: + # Transpose earlier so that the testing will not trigger transposing kernels + down_input = ( + down_input[0], + get_col_major_tma_aligned_tensor(down_input[1]), + ) + """ + down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked( + down_input, self.w2_weight, out, masked_m, expected_m + ) + """ + + return down_output diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py new file mode 100644 index 000000000..c91ccd633 --- /dev/null +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -0,0 +1,533 @@ +try: + from deep_ep import Buffer + + use_deepep = True +except ImportError: + use_deepep = False + +import os +from typing import Optional, Tuple + +import torch +import torch.distributed as dist + +from sglang.srt.layers.moe.ep_moe.kernels import ( + compute_src2dst_triton_kernel, + deepep_permute_triton_kernel, + deepep_post_reorder_triton_kernel, + deepep_run_moe_deep_preprocess, +) +from sglang.srt.model_executor.forward_batch_info import ForwardMode + +_buffer_normal = None +_buffer_low_latency = None + + +def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int): + """ + Copy from DeepEP example usage in model inference prefilling. + https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling + """ + + global _buffer_normal + + num_nvl_bytes, num_rdma_bytes = 0, 0 + for config in ( + Buffer.get_dispatch_config(group.size()), + Buffer.get_combine_config(group.size()), + ): + num_nvl_bytes = max( + config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes + ) + num_rdma_bytes = max( + config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes + ) + + if ( + _buffer_normal is None + or _buffer_normal.group != group + or _buffer_normal.num_nvl_bytes < num_nvl_bytes + or _buffer_normal.num_rdma_bytes < num_rdma_bytes + ): + _buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes) + return _buffer_normal + + +def get_buffer_low_latency( + group: dist.ProcessGroup, + num_max_dispatch_tokens_per_rank: int, + hidden: int, + num_experts: int, +): + """ + Copy from DeepEP example usage in model inference decoding. 
+ https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding + """ + + global _buffer_low_latency + num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint( + num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts + ) + + if ( + _buffer_low_latency is None + or _buffer_low_latency.group != group + or not _buffer_low_latency.low_latency_mode + or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes + ): + assert num_experts % group.size() == 0 + _buffer_low_latency = Buffer( + group, + 0, + num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=num_experts // group.size(), + ) + return _buffer_low_latency + + +def permute( + tokens, + routing_map, + num_out_tokens: Optional[int] = None, + fused: bool = False, + drop_and_pad: bool = False, +): + """ + Copy from Megatron-Core moe for token permutation + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py + """ + + num_tokens, _ = tokens.shape + num_experts = routing_map.shape[1] + if drop_and_pad and not (num_out_tokens is None): + capacity = num_out_tokens // num_experts + assert not routing_map.requires_grad + routing_map = routing_map.to(dtype=torch.int8).T.contiguous() + sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[ + :, :capacity + ].contiguous() + sorted_indices = sorted_indices.view(-1) + else: + routing_map = routing_map.bool().T.contiguous() + token_indices = ( + torch.arange(num_tokens, device=routing_map.device) + .unsqueeze(0) + .expand(num_experts, -1) + ) + sorted_indices = token_indices.masked_select(routing_map) + permuted_input = tokens.index_select(0, sorted_indices) + + return permuted_input, sorted_indices + + +def unpermute( + permuted_tokens: torch.Tensor, + sorted_indices: torch.Tensor, + restore_shape: torch.Size, + probs: torch.Tensor = None, + routing_map: torch.Tensor = None, + fused: bool = False, + drop_and_pad: bool = False, +): + """ + Copy from Megatron-Core moe for token unpermutation + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/moe_utils.py + """ + + _, hidden = restore_shape + + if probs is not None: + assert routing_map is not None, "Mask must be provided to permute the probs." 
+ if drop_and_pad: + num_experts = routing_map.size(1) + num_permuted_tokens = sorted_indices.size(0) + capacity = num_permuted_tokens // num_experts + num_unpermuted_tokens = probs.size(0) + + probs_T_1D = probs.T.contiguous().view(-1) + + indices_dim0 = torch.arange( + num_experts, device=routing_map.device + ).unsqueeze(-1) + indices_dim1 = sorted_indices.view(num_experts, capacity) + indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1) + + permuted_probs = probs_T_1D.index_select(0, indices_1D) + else: + permuted_probs = probs.T.contiguous().masked_select( + routing_map.T.contiguous() + ) + permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1) + + output_tokens = torch.zeros( + restore_shape, device=permuted_tokens.device, dtype=permuted_tokens.dtype + ) + output_tokens.scatter_add_( + 0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens + ) + + return output_tokens + + +class DeepEPDispatcher: + """ + Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher + https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py + """ + + def __init__( + self, + group: torch.distributed.ProcessGroup, + router_topk: int, + permute_fusion: bool = False, + capacity_factor: float = None, + num_experts: int = None, + num_local_experts: int = None, + hidden_size: int = None, + params_dtype: torch.dtype = None, + ): + self.group = group + self.router_topk = router_topk + self.capacity_factor = capacity_factor + self.permute_fusion = permute_fusion + self.num_experts = num_experts + self.num_local_experts = num_local_experts + self.hidden_size = hidden_size + self.recv_expert_count = None + self.params_dtype = params_dtype + self.params_bytes = 2 + # Metadata + self.token_indices = None + self.token_probs = None + # Handle used for combine operation + self.handle = None + + # `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256 + # https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding + self.num_max_dispatch_tokens_per_rank = 128 + + if not use_deepep: + raise ImportError( + "DeepEP is not installed. Please install DeepEP package from " + "https://github.com/deepseek-ai/deepep." 
+ ) + self.buffer_normal = get_buffer_normal( + self.group, self.hidden_size * self.params_bytes + ) + self.buffer_low_latency = None + # Todo: enable low latency dispatch + """ + self.buffer_low_latency = get_buffer_low_latency( + self.group, + self.num_max_dispatch_tokens_per_rank, + self.hidden_size * self.params_bytes, + self.num_experts, + ) + """ + + def deepep_permute( + self, + topk_ids, + hidden_states, + num_experts, + top_k, + use_fp8_w8a8, + use_block_quant, + fp8_dtype, + ): + reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess( + topk_ids, num_experts + ) + num_total_tokens = reorder_topk_ids.numel() + gateup_input = torch.empty( + (int(num_total_tokens), hidden_states.shape[1]), + device=hidden_states.device, + dtype=( + fp8_dtype + if (use_fp8_w8a8 and not use_block_quant) + else hidden_states.dtype + ), + ) + # PreReorder + deepep_permute_triton_kernel[(hidden_states.shape[0],)]( + hidden_states, + gateup_input, + src2dst, + topk_ids, + None, + top_k, + hidden_states.shape[1], + BLOCK_SIZE=512, + ) + self.src2dst = src2dst + return reorder_topk_ids, seg_indptr, gateup_input + + def dispatch( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + forward_mode: ForwardMode, + previous_event=None, + num_max_dispatch_tokens_per_rank: int = 128, + ) -> Tuple[torch.Tensor, torch.Tensor]: + self.hidden_shape = hidden_states.shape + topk_idx = topk_idx.to(torch.int64) + # Todo: enable low latency dispatch + if True: # not forward_mode.is_decode(): + ( + hidden_states, + topk_idx, + topk_weights, + num_recv_tokens_per_expert_list, + handle, + event, + ) = self.dispatch_normal( + hidden_states, topk_idx, topk_weights, num_experts, previous_event + ) + self.tokens_per_expert = torch.tensor( + num_recv_tokens_per_expert_list, + device=hidden_states.device, + dtype=torch.int64, + ) + else: + hidden_states, recv_expert_count, handle, event, hook = ( + self.dispatch_low_latency( + hidden_states, + topk_idx, + num_max_dispatch_tokens_per_rank, + num_experts, + ) + ) + self.recv_expert_count = recv_expert_count + tokens_per_expert = self.get_number_of_tokens_per_expert() + self.handle = handle + self.topk_idx = topk_idx + self.topk_weights = topk_weights + if hidden_states.shape[0] > 0: + hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states) + return hidden_states, topk_idx, topk_weights, tokens_per_expert + + def dispatch_normal( + self, + x: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_experts: int, + previous_event=None, + ): + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert, + is_token_in_rank, + previous_event, + ) = self.buffer_normal.get_dispatch_layout( + topk_idx, + num_experts, + previous_event=previous_event, + async_finish=False, + allocate_on_comm_stream=False, + ) + + ( + recv_x, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + handle, + event, + ) = self.buffer_normal.dispatch( + x, + topk_idx=topk_idx, + topk_weights=topk_weights, + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=previous_event, + async_finish=False, + allocate_on_comm_stream=False, + ) + + return ( + recv_x, + recv_topk_idx, + recv_topk_weights, + num_recv_tokens_per_expert_list, + handle, + event, + ) + + def dispatch_low_latency( + self, + hidden_states: torch.Tensor, + 
topk_idx: torch.Tensor, + num_max_dispatch_tokens_per_rank: int, + num_experts: int, + ): + """ + # For H20, there will be an CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch' + # Please please make sure to change DeepEP code in internode_ll.cu dispatch / combine first and then reinstall! + # More details refer: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782 + + + diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu + index f60e933..cddaabf 100644 + --- a/csrc/kernels/internode_ll.cu + +++ b/csrc/kernels/internode_ll.cu + @@ -307,14 +307,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales, + int num_topk, int num_experts, int rank, int num_ranks, + void* workspace, cudaStream_t stream, int phases) { + constexpr int kNumMaxTopK = 9; + - constexpr int kNumWarpsPerGroup = 10; + - constexpr int kNumWarpGroups = 3; + + constexpr int kNumWarpsPerGroup = 8; + + constexpr int kNumWarpGroups = 4; + EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections"); + + + const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; + const auto num_sms = cell_div(num_experts, kNumWarpGroups); + EP_HOST_ASSERT(num_topk <= kNumMaxTopK); + - EP_HOST_ASSERT(cell_div(static_cast(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2); + + // EP_HOST_ASSERT(cell_div(static_cast(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2); + + + // Workspace checks + auto atomic_counter_per_expert = reinterpret_cast(workspace); + @@ -505,8 +505,8 @@ void combine(void* combined_x, + int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank, + int num_topk, int num_experts, int rank, int num_ranks, + void* workspace, cudaStream_t stream, int phases) { + - constexpr int kNumWarpsPerGroup = 10; + - constexpr int kNumWarpGroups = 3; + + constexpr int kNumWarpsPerGroup = 8; + + constexpr int kNumWarpGroups = 4; + constexpr int kNumMaxTopk = 9; + + + const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup; + """ + + recv_hidden_states, recv_expert_count, handle, event, hook = ( + self.buffer_low_latency.low_latency_dispatch( + hidden_states, + topk_idx, + num_max_dispatch_tokens_per_rank, + num_experts, + async_finish=False, + return_recv_hook=False, # True for double-batch overlapping, need call hook() + ) + ) + # hook() + return recv_hidden_states, recv_expert_count, handle, event, hook + + def combine( + self, hidden_states: torch.Tensor, forward_mode: ForwardMode + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + # Todo: enable low latency combine + if True: # not forward_mode.is_decode(): + if hidden_states.shape[0] > 0: + hidden_states = self.get_restored_hidden_states_by_experts( + hidden_states + ) + hidden_states, event = self.combine_normal(hidden_states, self.handle) + else: + hidden_states, event, hook = self.combine_low_latency( + hidden_states, self.topk_idx, self.topk_weights, self.handle + ) + self.handle = None + return hidden_states.view(self.hidden_shape) + + def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None): + combined_x, _, event = self.buffer_normal.combine( + x, + handle, + async_finish=False, + previous_event=previous_event, + allocate_on_comm_stream=False, + ) + return combined_x, event + + def combine_low_latency( + self, + hidden_states: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + handle: Tuple, + ): + combined_hidden_states, event_overlap, hook = ( + 
self.buffer_low_latency.low_latency_combine( + hidden_states, + topk_idx, + topk_weights, + handle, + async_finish=False, + return_recv_hook=False, # True for double-batch overlapping, need call hook() + ) + ) + # hook() + return combined_hidden_states, event_overlap, hook + + def _indices_to_multihot(self, indices, probs): + batch_size = indices.shape[0] + multihot_routing_map = torch.zeros( + (batch_size, self.num_local_experts), + dtype=torch.long, + device=indices.device, + ) + + multihot_probs = torch.zeros( + (batch_size, self.num_local_experts), + dtype=torch.float, + device=indices.device, + ) + + mask = indices != -1 + valid_indices = indices[mask] + row_indices = torch.arange(batch_size, device=indices.device).repeat_interleave( + mask.sum(dim=1) + ) + multihot_routing_map[row_indices, valid_indices] = 1 + multihot_probs[row_indices, valid_indices] = probs[mask] + return multihot_routing_map.bool(), multihot_probs + + def get_dispached_metadata(self) -> torch.Tensor: + return self.topk_idx, self.topk_weights + + def get_number_of_tokens_per_expert(self) -> torch.Tensor: + """ + Get the number of tokens per expert. + """ + return self.tokens_per_expert + + def get_permuted_hidden_states_by_experts( + self, hidden_states: torch.Tensor + ) -> torch.Tensor: + self.dispatched_routing_map, self.topk_weights = self._indices_to_multihot( + self.topk_idx, self.topk_weights + ) + self.hidden_shape_before_permute = hidden_states.shape + hidden_states, self.reversed_mapping_for_combine = permute( + hidden_states, + self.dispatched_routing_map, + num_out_tokens=self.tokens_per_expert.sum(), + fused=self.permute_fusion, + ) + return hidden_states + + def get_restored_hidden_states_by_experts( + self, hidden_states: torch.Tensor + ) -> torch.Tensor: + input_dtype = hidden_states.dtype + assert ( + self.topk_weights.dtype == torch.float32 + ), "DeepEP only supports float32 probs" + hidden_states = unpermute( + hidden_states, + self.reversed_mapping_for_combine, + restore_shape=self.hidden_shape_before_permute, + routing_map=self.dispatched_routing_map, + probs=self.topk_weights, + fused=self.permute_fusion, + ) + return hidden_states.to(input_dtype) diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py index b3fc6b440..33bd46c05 100644 --- a/python/sglang/srt/layers/parameter.py +++ b/python/sglang/srt/layers/parameter.py @@ -105,6 +105,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): shard_offset = kwargs.get("shard_offset") shard_size = kwargs.get("shard_size") + tp_rank = kwargs.get("tp_rank") use_presharded_weights = kwargs.get("use_presharded_weights") if ( isinstance(self, (PackedColumnParameter, PackedvLLMParameter)) @@ -116,7 +117,6 @@ class _ColumnvLLMParameter(BasevLLMParameter): param_data = self.data - tp_rank = get_tensor_model_parallel_rank() param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) if not use_presharded_weights: loaded_weight = loaded_weight.narrow( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 86656490a..3b8259bfc 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -67,6 +67,7 @@ global_server_args_dict = { "enable_nan_detection": ServerArgs.enable_nan_detection, "enable_dp_attention": ServerArgs.enable_dp_attention, "enable_ep_moe": ServerArgs.enable_ep_moe, + "enable_deepep_moe": ServerArgs.enable_deepep_moe, "device": ServerArgs.device, "speculative_accept_threshold_single": 
ServerArgs.speculative_accept_threshold_single, "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 6f2e05e26..eaaf2637f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -145,6 +145,7 @@ class ModelRunner: "enable_nan_detection": server_args.enable_nan_detection, "enable_dp_attention": server_args.enable_dp_attention, "enable_ep_moe": server_args.enable_ep_moe, + "enable_deepep_moe": server_args.enable_deepep_moe, "device": server_args.device, "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single, "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc, @@ -277,6 +278,12 @@ class ModelRunner: server_args.chunked_prefill_size = -1 server_args.disable_radix_cache = True + if server_args.enable_deepep_moe: + logger.info("DeepEP is turned on.") + assert ( + server_args.enable_dp_attention == True + ), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'" + def init_torch_distributed(self): logger.info("Init torch distributed begin.") diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py old mode 100755 new mode 100644 index 1cbd0097a..27a12c627 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -26,6 +26,7 @@ from transformers import PretrainedConfig from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, + parallel_state, tensor_model_parallel_all_reduce, ) from sglang.srt.layers.activation import SiluAndMul @@ -47,8 +48,10 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.ep_moe.layer import EPMoE +from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE +from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_utils import ( block_quant_to_tensor_quant, @@ -65,7 +68,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ) from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip @@ -87,6 +90,8 @@ class DeepseekV2MLP(nn.Module): quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, prefix: str = "", + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -95,6 +100,8 @@ class DeepseekV2MLP(nn.Module): bias=False, quant_config=quant_config, prefix=add_prefix("gate_up_proj", prefix), + tp_rank=tp_rank, + tp_size=tp_size, ) self.down_proj = RowParallelLinear( intermediate_size, @@ -103,6 +110,8 @@ class DeepseekV2MLP(nn.Module): quant_config=quant_config, reduce_results=reduce_results, prefix=add_prefix("down_proj", prefix), + tp_rank=tp_rank, + tp_size=tp_size, ) if hidden_act 
!= "silu": raise ValueError( @@ -167,7 +176,11 @@ class DeepseekV2MoE(nn.Module): self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix)) - MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + MoEImpl = ( + DeepEPMoE + if global_server_args_dict["enable_deepep_moe"] + else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE) + ) self.experts = MoEImpl( num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, @@ -184,16 +197,59 @@ class DeepseekV2MoE(nn.Module): if config.n_shared_experts is not None: intermediate_size = config.moe_intermediate_size * config.n_shared_experts - self.shared_experts = DeepseekV2MLP( - hidden_size=config.hidden_size, - intermediate_size=intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - reduce_results=False, - prefix=add_prefix("shared_experts", prefix), + # disable tp for shared experts when enable deepep moe + if not global_server_args_dict["enable_deepep_moe"]: + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=add_prefix("shared_experts", prefix), + ) + else: + self.shared_experts = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + prefix=add_prefix("shared_experts", prefix), + tp_rank=0, + tp_size=1, + ) + + if global_server_args_dict["enable_deepep_moe"]: + self.num_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + self.renormalize = config.norm_topk_prob + self.topk_group = config.topk_group + self.num_expert_group = config.n_group + self.correction_bias = ( + self.gate.e_score_correction_bias.data + if self.gate.e_score_correction_bias is not None + else None ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + self.deepep_dispatcher = DeepEPDispatcher( + group=parallel_state.get_tp_group().device_group, + router_topk=self.top_k, + permute_fusion=True, + num_experts=config.n_routed_experts, + num_local_experts=config.n_routed_experts // self.tp_size, + hidden_size=config.hidden_size, + params_dtype=config.torch_dtype, + ) + + def forward( + self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None + ) -> torch.Tensor: + if not global_server_args_dict["enable_deepep_moe"]: + return self.forward_normal(hidden_states) + else: + return self.forward_deepep(hidden_states, forward_mode) + + def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) if self.n_shared_experts is not None: @@ -208,6 +264,59 @@ class DeepseekV2MoE(nn.Module): final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_dim) + + def forward_deepep( + self, hidden_states: torch.Tensor, forward_mode: ForwardMode + ) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + shared_output = None + topk_idx = torch.full( + (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device + ) + topk_weights = torch.empty( + (0, self.top_k), dtype=torch.float32, device=hidden_states.device + ) + if forward_mode is not None and not 
forward_mode.is_idle(): + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states) + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + topk_weights, topk_idx = select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + use_grouped_topk=True, + renormalize=self.renormalize, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + correction_bias=self.correction_bias, + ) + if self.tp_size > 1: + recv_hidden_states, topk_idx, topk_weights, tokens_per_expert = ( + self.deepep_dispatcher.dispatch( + hidden_states, + topk_idx, + topk_weights, + self.num_experts, + forward_mode, + ) + ) + final_hidden_states = ( + self.experts( + hidden_states=recv_hidden_states, + tokens_per_expert=tokens_per_expert, + forward_mode=forward_mode, + ) + * self.routed_scaling_factor + ) + if self.tp_size > 1: + final_hidden_states = self.deepep_dispatcher.combine( + final_hidden_states, forward_mode + ) + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output return final_hidden_states.view(num_tokens, hidden_dim) @@ -959,15 +1068,25 @@ class DeepseekV2DecoderLayer(nn.Module): if get_tensor_model_parallel_world_size() > 1: # all gather and all reduce if self.dp_size != 1: - if get_attention_tp_rank() == 0: - hidden_states += residual - hidden_states, local_hidden_states = ( - forward_batch.gathered_buffer, - hidden_states, - ) - dp_gather_partial(hidden_states, local_hidden_states, forward_batch) - dp_scatter(residual, hidden_states, forward_batch) - hidden_states = self.post_attention_layernorm(hidden_states) + if global_server_args_dict["enable_deepep_moe"] and isinstance( + self.mlp, DeepseekV2MoE + ): + if hidden_states.shape[0] != 0: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) + hidden_states = self.mlp(hidden_states, forward_batch.forward_mode) + return hidden_states, residual + else: + if get_attention_tp_rank() == 0: + hidden_states += residual + hidden_states, local_hidden_states = ( + forward_batch.gathered_buffer, + hidden_states, + ) + dp_gather_partial(hidden_states, local_hidden_states, forward_batch) + dp_scatter(residual, hidden_states, forward_batch) + hidden_states = self.post_attention_layernorm(hidden_states) else: hidden_states = tensor_model_parallel_all_reduce(hidden_states) hidden_states, residual = self.post_attention_layernorm( @@ -1099,7 +1218,11 @@ class DeepseekV2ForCausalLM(nn.Module): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE + MoEImpl = ( + DeepEPMoE + if global_server_args_dict["enable_deepep_moe"] + else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE) + ) expert_params_mapping = MoEImpl.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 251618268..144b8b6ce 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -157,6 +157,7 @@ class ServerArgs: enable_mixed_chunk: bool = False enable_dp_attention: bool = False enable_ep_moe: bool = False + enable_deepep_moe: bool = False enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None @@ -281,6 +282,12 @@ class ServerArgs: logger.warning( f"DP attention is enabled. 
The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. " ) + # DeepEP MoE + if self.enable_deepep_moe: + self.ep_size = self.dp_size + logger.info( + f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the data parallel size[{self.dp_size}]." + ) # Speculative Decoding if self.speculative_algorithm == "NEXTN": @@ -1018,6 +1025,11 @@ class ServerArgs: default=ServerArgs.hicache_ratio, help="The ratio of the size of host KV cache memory pool to the size of device pool.", ) + parser.add_argument( + "--enable-deepep-moe", + action="store_true", + help="Enabling DeepEP MoE implementation for EP MoE.", + ) # Server warmups parser.add_argument( diff --git a/test/srt/test_moe_deepep.py b/test/srt/test_moe_deepep.py new file mode 100644 index 000000000..f89f49810 --- /dev/null +++ b/test/srt/test_moe_deepep.py @@ -0,0 +1,53 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MLA_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestDeepEPMoE(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--tp", + "2", + "--dp", + "2", + "--enable-dp-attention", + "--enable-deepep-moe", + "--disable-cuda-graph", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreater(metrics["score"], 0.5) + + +if __name__ == "__main__": + unittest.main()
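
Reviewer note: the sketch below summarizes how the pieces added in this patch fit together, mirroring the `DeepseekV2MoE.forward_deepep` path: top-k routing via `select_experts`, all-to-all token exchange through `DeepEPDispatcher.dispatch`, grouped expert GEMMs in `DeepEPMoE`, and the final `combine`. It is a minimal illustration only, not part of the patch: it assumes DeepEP is installed, that `torch.distributed` and the SGLang TP group are already initialized (as inside a running server), and the expert/hidden sizes shown are placeholder values in the spirit of DeepSeek-V3 rather than a required configuration.

```python
# Minimal sketch of the DeepEP MoE data flow introduced in this PR
# (routing -> dispatch -> local grouped GEMMs -> combine).
# Assumes an initialized TP/EP process group and the classes added above;
# all sizes below are illustrative assumptions, not values fixed by the patch.
import torch

from sglang.srt.distributed import parallel_state
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.model_executor.forward_batch_info import ForwardMode

group = parallel_state.get_tp_group().device_group
num_experts, top_k, hidden_size, intermediate_size = 256, 8, 7168, 2048

dispatcher = DeepEPDispatcher(
    group=group,
    router_topk=top_k,
    permute_fusion=True,
    num_experts=num_experts,
    num_local_experts=num_experts // group.size(),
    hidden_size=hidden_size,
    params_dtype=torch.bfloat16,
)
experts = DeepEPMoE(num_experts, top_k, hidden_size, intermediate_size)


def moe_forward(hidden_states, router_logits, forward_mode=ForwardMode.EXTEND):
    # Grouped top-k routing, as configured for DeepSeek-V3 in forward_deepep.
    topk_weights, topk_idx = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=True,
        renormalize=True,
        topk_group=4,
        num_expert_group=8,
    )
    # All-to-all dispatch of tokens to the ranks that own their experts.
    recv_hidden_states, topk_idx, topk_weights, tokens_per_expert = dispatcher.dispatch(
        hidden_states, topk_idx, topk_weights, num_experts, forward_mode
    )
    # Grouped GEMMs over the experts local to this rank.
    out = experts(recv_hidden_states, tokens_per_expert, forward_mode)
    # All-to-all combine back to the original token order.
    return dispatcher.combine(out, forward_mode)
```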