diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 814dc469e..dfecb63d9 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -38,7 +38,7 @@ from sglang.srt.layers.quantization.base_config import ( ) from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from sglang.srt.model_executor.forward_batch_info import ForwardMode -from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs +from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs _is_cuda = is_cuda() @@ -47,7 +47,6 @@ if _is_cuda: else: from vllm import _custom_ops as vllm_ops - logger = logging.getLogger(__name__) _is_hip = is_hip() @@ -814,7 +813,7 @@ class DeepEPMoE(EPMoE): correction_bias: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, activation: str = "silu", - deepep_mode: str = "auto", + deepep_mode: DeepEPMode = DeepEPMode.auto, ): super().__init__( num_experts, @@ -834,7 +833,7 @@ class DeepEPMoE(EPMoE): activation, ) self.deepep_mode = deepep_mode - if self.deepep_mode in ["low_latency", "auto"]: + if self.deepep_mode.enable_low_latency(): assert use_deep_gemm, f"DeepEP {self.deepep_mode} mode requires deep_gemm" self.w13_weight_fp8 = ( self.w13_weight, @@ -858,13 +857,10 @@ class DeepEPMoE(EPMoE): expected_m: int, forward_mode: ForwardMode, ): - if self.deepep_mode == "normal" or ( - self.deepep_mode == "auto" and not forward_mode.is_decode() - ): + resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) + if resolved_deepep_mode == DeepEPMode.normal: return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr) - elif self.deepep_mode == "low_latency" or ( - self.deepep_mode == "auto" and forward_mode.is_decode() - ): + elif resolved_deepep_mode == DeepEPMode.low_latency: return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m) else: raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}") diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py index f4e673535..2a2909816 100644 --- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py +++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py @@ -1,3 +1,5 @@ +from sglang.srt.utils import DeepEPMode + try: from deep_ep import Buffer @@ -98,7 +100,7 @@ class DeepEPDispatcher: num_local_experts: int = None, hidden_size: int = None, params_dtype: torch.dtype = None, - deepep_mode: str = "auto", + deepep_mode: DeepEPMode = DeepEPMode.auto, async_finish: bool = False, return_recv_hook: bool = False, ): @@ -120,13 +122,13 @@ class DeepEPDispatcher: self.deepep_mode = deepep_mode self.handle = None - if self.deepep_mode in ["normal", "auto"]: # for normal / auto mode + if self.deepep_mode.enable_normal(): self.buffer_normal = get_buffer_normal( self.group, self.hidden_size * self.params_bytes ) self.async_finish = async_finish self.src2dst = None - if self.deepep_mode in ["low_latency", "auto"]: # for low_latency / auto mode + if self.deepep_mode.enable_low_latency(): """ num_max_dispatch_tokens_per_rank: the actual batch size in the decoding engine should be less than 256 https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding @@ -196,9 +198,8 @@ class DeepEPDispatcher: ) expected_m = 0 - if self.deepep_mode == "normal" or ( - self.deepep_mode == "auto" and not forward_mode.is_decode() - ): + resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) + if resolved_deepep_mode == DeepEPMode.normal: ( hidden_states, topk_idx, @@ -210,9 +211,7 @@ class DeepEPDispatcher: reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute( hidden_states, topk_idx, fp8_dtype=hidden_states.dtype ) - elif self.deepep_mode == "low_latency" or ( - self.deepep_mode == "auto" and forward_mode.is_decode() - ): + elif resolved_deepep_mode == DeepEPMode.low_latency: expected_m = ( hidden_states.shape[0] * self.buffer_low_latency.group_size @@ -354,9 +353,8 @@ class DeepEPDispatcher: topk_weights: torch.Tensor, forward_mode: ForwardMode, ) -> torch.Tensor: - if self.deepep_mode == "normal" or ( - self.deepep_mode == "auto" and not forward_mode.is_decode() - ): + resolved_deepep_mode = self.deepep_mode.resolve(forward_mode) + if resolved_deepep_mode == DeepEPMode.normal: if hidden_states.shape[0] > 0: num_tokens = self.src2dst.shape[0] // self.router_topk output = torch.empty( @@ -384,9 +382,7 @@ class DeepEPDispatcher: output, ) event.current_stream_wait() if self.async_finish else () - elif self.deepep_mode == "low_latency" or ( - self.deepep_mode == "auto" and forward_mode.is_decode() - ): + elif resolved_deepep_mode == DeepEPMode.low_latency: hidden_states, event, hook = self.combine_low_latency( hidden_states, topk_idx, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 775b7413c..2fcd193d8 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -70,7 +70,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import add_prefix, is_cuda, is_hip +from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip _is_hip = is_hip() _is_cuda = is_cuda() @@ -215,7 +215,7 @@ class DeepseekV2MoE(nn.Module): topk_group=config.topk_group, correction_bias=self.gate.e_score_correction_bias, prefix=add_prefix("experts", prefix), - deepep_mode=global_server_args_dict["deepep_mode"], + deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]], ) if config.n_shared_experts is not None: @@ -264,7 +264,7 @@ class DeepseekV2MoE(nn.Module): num_local_experts=config.n_routed_experts // self.tp_size, hidden_size=config.hidden_size, params_dtype=config.torch_dtype, - deepep_mode=global_server_args_dict["deepep_mode"], + deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]], async_finish=True, # TODO return_recv_hook=True, ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1cd44862f..85f65eb74 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -20,7 +20,7 @@ import logging import os import random import tempfile -from typing import List, Optional +from typing import List, Literal, Optional from sglang.srt.hf_transformers_utils import check_gguf_file from sglang.srt.reasoning_parser import ReasoningParser @@ -161,7 +161,7 @@ class ServerArgs: enable_dp_attention: bool = False enable_ep_moe: bool = False enable_deepep_moe: bool = False - deepep_mode: Optional[str] = "auto" + deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto" enable_torch_compile: bool = False torch_compile_max_bs: int = 32 cuda_graph_max_bs: Optional[int] = None diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 498bc58cc..ba229b1ce 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -37,6 +37,7 @@ import time import traceback import warnings from contextlib import contextmanager +from enum import Enum from functools import lru_cache from importlib.metadata import PackageNotFoundError, version from importlib.util import find_spec @@ -1838,3 +1839,24 @@ def flatten_nested_list(nested_list): ] else: return [nested_list] + + +class DeepEPMode(Enum): + normal = "normal" + low_latency = "low_latency" + auto = "auto" + + def enable_normal(self): + return self in [DeepEPMode.normal, DeepEPMode.auto] + + def enable_low_latency(self): + return self in [DeepEPMode.low_latency, DeepEPMode.auto] + + def resolve(self, forward_mode): + if self != DeepEPMode.auto: + return self + + if forward_mode.is_decode(): + return DeepEPMode.low_latency + else: + return DeepEPMode.normal