[5/N] MoE Refactor: Update MoE parallelism arguments (#8658)

Author: Cheng Wan
Date: 2025-08-01 01:20:03 -07:00
Committed by: GitHub
Parent: c8d3a402c1
Commit: 6c88f6c8d9
38 changed files with 342 additions and 299 deletions

View File

@@ -288,12 +288,14 @@ class _SinglePassGatherer(ABC):
)
if server_args.expert_distribution_recorder_mode == "stat_approx":
if server_args.enable_deepep_moe and (server_args.deepep_mode == "normal"):
if server_args.moe_a2a_backend is not None and (
server_args.deepep_mode == "normal"
):
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
else:
raise NotImplementedError
if server_args.enable_deepep_moe:
if server_args.moe_a2a_backend is not None:
if server_args.deepep_mode == "normal":
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
elif server_args.deepep_mode == "low_latency":

View File

@@ -108,7 +108,7 @@ class LayerScatterModes:
if context.is_layer_sparse:
return (
ScatterMode.SCATTERED
if global_server_args_dict["enable_deepep_moe"]
if not global_server_args_dict["moe_a2a_backend"].is_standard()
else ScatterMode.FULL
)
else:

View File

@@ -1,28 +1,17 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Optional
import torch
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
gelu_and_mul_triton_kernel,
grouped_gemm_triton,
moe_ep_deepgemm_preprocess,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
pre_reorder_triton_kernel_for_cutlass_moe,
run_cutlass_moe_ep_preproess,
run_moe_ep_preproess,
silu_and_mul_masked_post_quant_fwd,
silu_and_mul_triton_kernel,
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import (
@@ -31,11 +20,9 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import (
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import (
Fp8Config,
Fp8MoEMethod,
@@ -44,23 +31,13 @@ from sglang.srt.layers.quantization.fp8 import (
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
sglang_per_token_group_quant_fp8,
sglang_per_token_quant_fp8,
)
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import (
DeepEPMode,
ceil_div,
dispose_tensor,
get_bool_env_var,
is_hip,
is_npu,
)
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
from sglang.srt.layers.moe.token_dispatcher import (
DeepEPLLOutput,
DeepEPNormalOutput,
DispatchOutput,
@@ -119,7 +96,6 @@ class EPMoE(FusedMoE):
activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
enable_ep_moe=True,
)
self.start_expert_id = self.moe_ep_rank * self.num_local_experts
@@ -328,7 +304,7 @@ class DeepEPMoE(EPMoE):
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
deepep_mode: DeepEPMode = DeepEPMode.auto,
deepep_mode: DeepEPMode = DeepEPMode.AUTO,
):
super().__init__(
num_experts=num_experts,
@@ -348,7 +324,6 @@ class DeepEPMoE(EPMoE):
# TODO: move to the beginning of the file
from sglang.srt.distributed.parallel_state import get_tp_group
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
@@ -762,11 +737,10 @@ class FlashInferEPMoE(EPMoE):
def get_moe_impl_class():
if global_server_args_dict["enable_deepep_moe"]:
if global_server_args_dict["moe_a2a_backend"].is_deepep():
return DeepEPMoE
if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
# Must come before EPMoE because FusedMoE also supports enable_ep_moe
return FusedMoE
if global_server_args_dict["enable_ep_moe"]:
if get_moe_expert_parallel_world_size() > 1:
return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
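
Read together, this hunk replaces boolean flag checks with state derived from the new arguments. A reconstruction of the post-commit selector, assuming the unchanged context lines stay as shown (a sketch, not the verbatim file):

def get_moe_impl_class():
    # The DeepEP A2A backend takes priority.
    if global_server_args_dict["moe_a2a_backend"].is_deepep():
        return DeepEPMoE
    # FlashInfer CUTLASS MoE is checked before the generic EP path.
    if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
        return FusedMoE
    # Expert parallelism is now inferred from the parallel state, not a flag.
    if get_moe_expert_parallel_world_size() > 1:
        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE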

View File

@@ -14,8 +14,6 @@ from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_moe_tensor_parallel_rank,
get_moe_tensor_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
@@ -94,7 +92,6 @@ class FusedMoE(torch.nn.Module):
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_cutlass_moe: Optional[bool] = False,
enable_ep_moe: Optional[bool] = False,
):
super().__init__()
@@ -112,7 +109,6 @@ class FusedMoE(torch.nn.Module):
if enable_flashinfer_cutlass_moe and quant_config is None:
logger.warning("Disable flashinfer MoE when quantization config is None.")
enable_flashinfer_cutlass_moe = False
enable_ep_moe = False
self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
self.moe_ep_size = get_moe_expert_parallel_world_size()
@@ -121,7 +117,7 @@ class FusedMoE(torch.nn.Module):
self.moe_tp_rank = get_moe_tensor_parallel_rank()
assert num_experts % self.moe_ep_size == 0
self.num_local_experts = num_experts // self.moe_ep_size
if enable_ep_moe:
if self.moe_ep_size > 1:
# TODO(ch-wan): support shared experts fusion
# Create a tensor of size num_experts filled with -1
self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
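
For context, a hedged illustration of what such an expert map typically holds; the exact fill logic is outside this hunk, so the slicing below is an assumption based on start_expert_id shown earlier:

import torch

num_experts, moe_ep_size, moe_ep_rank = 8, 2, 0
num_local_experts = num_experts // moe_ep_size            # 4
start = moe_ep_rank * num_local_experts                   # 0
expert_map_cpu = torch.full((num_experts,), -1, dtype=torch.int32)
expert_map_cpu[start:start + num_local_experts] = torch.arange(num_local_experts, dtype=torch.int32)
# rank 0 -> tensor([ 0,  1,  2,  3, -1, -1, -1, -1], dtype=torch.int32)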

View File

@@ -0,0 +1,23 @@
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
BaseDispatcher,
BaseDispatcherConfig,
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.token_dispatcher.deepep import (
DeepEPConfig,
DeepEPDispatcher,
DeepEPLLOutput,
DeepEPNormalOutput,
)
__all__ = [
"BaseDispatcher",
"BaseDispatcherConfig",
"DispatchOutput",
"DispatchOutputFormat",
"DeepEPConfig",
"DeepEPDispatcher",
"DeepEPNormalOutput",
"DeepEPLLOutput",
]
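
With this new __init__.py, the dispatcher types are importable from a single package path; later hunks in this commit switch call sites accordingly. A minimal import sketch using only the re-exported names:

from sglang.srt.layers.moe.token_dispatcher import (
    DeepEPConfig,
    DeepEPDispatcher,
    DispatchOutput,
)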

View File

@@ -2,11 +2,22 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
from typing import Protocol, runtime_checkable
import torch
class MoEA2ABackend(Enum):
none = "none"
deepep = "deepep"
def is_none(self):
return self == MoEA2ABackend.none
def is_deepep(self):
return self == MoEA2ABackend.deepep
class DispatchOutputFormat(Enum):
standard = auto()
deepep_normal = auto()

View File

@@ -1,5 +1,3 @@
# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
from __future__ import annotations
import logging
@@ -22,15 +20,10 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
DeepEPMode,
get_bool_env_var,
get_int_env_var,
is_hip,
load_json_config,
)
from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
try:
from deep_ep import Buffer, Config
@@ -150,9 +143,9 @@ class DeepEPBuffer:
num_rdma_bytes,
)
if deepep_mode == DeepEPMode.normal:
if deepep_mode == DeepEPMode.NORMAL:
num_qps_per_rank = DeepEPConfig.get_instance().num_sms // 2
elif deepep_mode in [DeepEPMode.low_latency, DeepEPMode.auto]:
elif deepep_mode in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]:
num_qps_per_rank = num_experts // group.size()
else:
raise NotImplementedError
@@ -161,7 +154,7 @@ class DeepEPBuffer:
device="cuda"
).multi_processor_count
if (
(deepep_mode != DeepEPMode.low_latency)
(deepep_mode != DeepEPMode.LOW_LATENCY)
and not global_server_args_dict["enable_two_batch_overlap"]
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
):
@@ -611,7 +604,7 @@ class DeepEPDispatcher(BaseDispatcher):
num_local_experts: int = None,
hidden_size: int = None,
params_dtype: torch.dtype = None,
deepep_mode: DeepEPMode = DeepEPMode.auto,
deepep_mode: DeepEPMode = DeepEPMode.AUTO,
async_finish: bool = False,
return_recv_hook: bool = False,
):
@@ -697,9 +690,9 @@ class DeepEPDispatcher(BaseDispatcher):
resolved_deepep_mode = self.deepep_mode.resolve(
forward_batch.is_extend_in_batch
)
if resolved_deepep_mode == DeepEPMode.normal:
if resolved_deepep_mode == DeepEPMode.NORMAL:
return self._normal_dispatcher
elif resolved_deepep_mode == DeepEPMode.low_latency:
elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
return self._low_latency_dispatcher
else:
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")

View File

@@ -0,0 +1,43 @@
from enum import Enum
class MoeA2ABackend(Enum):
STANDARD = ("standard", "none")
DEEPEP = "deepep"
@classmethod
def _missing_(cls, value):
if value is None:
return cls.STANDARD
for member in cls:
if value in member.value:
return member
raise ValueError(f"No {cls.__name__} member for value {value}")
def is_deepep(self):
return self == MoeA2ABackend.DEEPEP
def is_standard(self):
return self == MoeA2ABackend.STANDARD
class DeepEPMode(Enum):
NORMAL = "normal"
LOW_LATENCY = "low_latency"
AUTO = "auto"
def enable_normal(self):
return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
def enable_low_latency(self):
return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
def resolve(self, is_extend_in_batch: bool):
if self != DeepEPMode.AUTO:
return self
if is_extend_in_batch:
return DeepEPMode.NORMAL
else:
return DeepEPMode.LOW_LATENCY
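
A minimal usage sketch of the two enums above, assuming only the behavior shown in this hunk: the _missing_ hook maps None and the legacy "none" string to STANDARD, and AUTO resolves per batch.

from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend

backend = MoeA2ABackend(None)                      # _missing_ falls back to STANDARD
assert backend.is_standard() and not backend.is_deepep()
assert MoeA2ABackend("none") is MoeA2ABackend.STANDARD
assert MoeA2ABackend("deepep").is_deepep()

mode = DeepEPMode("auto")
assert mode.resolve(is_extend_in_batch=True) is DeepEPMode.NORMAL      # extend batches use normal mode
assert mode.resolve(is_extend_in_batch=False) is DeepEPMode.LOW_LATENCY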

View File

@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
@@ -85,9 +86,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_dp_attention",
"enable_two_batch_overlap",
"enable_dp_lm_head",
"enable_deepep_moe",
"moe_a2a_backend",
"deepep_mode",
"enable_ep_moe",
"enable_flashinfer_cutlass_moe",
"enable_flashinfer_trtllm_moe",
"enable_flashinfer_allreduce_fusion",

View File

@@ -64,6 +64,7 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
@@ -137,7 +138,6 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
from sglang.srt.utils import (
DeepEPMode,
DynamicGradMode,
broadcast_pyobj,
configure_gc_logger,
@@ -1762,8 +1762,10 @@ class Scheduler(
spec_algorithm=self.spec_algorithm,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
enable_deepep_moe=self.server_args.enable_deepep_moe,
deepep_mode=DeepEPMode[self.server_args.deepep_mode],
enable_deepep_moe=MoeA2ABackend(
self.server_args.moe_a2a_backend
).is_deepep(),
deepep_mode=DeepEPMode(self.server_args.deepep_mode),
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
)

View File

@@ -38,6 +38,7 @@ import torch
import triton
import triton.language as tl
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.dp_attention import (
DPPaddingMode,
get_attention_dp_rank,
@@ -839,7 +840,7 @@ class ForwardBatch:
def enable_num_token_non_padded(server_args):
return server_args.enable_ep_moe or server_args.enable_deepep_moe
return get_moe_expert_parallel_world_size() > 1
class PPProxyTensors:

View File

@@ -60,6 +60,7 @@ from sglang.srt.layers.dp_attention import (
initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.layers.quantization import (
deep_gemm_wrapper,
monkey_patch_isinstance_for_vllm_base_layer,
@@ -217,6 +218,10 @@ class ModelRunner:
"use_mla_backend": self.use_mla_backend,
"speculative_algorithm": self.spec_algorithm,
}
| {
"moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
"deepep_mode": DeepEPMode(server_args.deepep_mode),
}
)
# CPU offload

View File

@@ -29,6 +29,7 @@ from tqdm import tqdm
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
@@ -61,7 +62,6 @@ from sglang.srt.layers.moe.ep_moe.layer import (
get_moe_impl_class,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -96,7 +96,6 @@ from sglang.srt.two_batch_overlap import (
)
from sglang.srt.utils import (
BumpAllocator,
DeepEPMode,
LazyValue,
add_prefix,
bind_or_assign,
@@ -333,15 +332,14 @@ class DeepseekV2MoE(nn.Module):
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["enable_deepep_moe"]
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
@@ -374,7 +372,7 @@ class DeepseekV2MoE(nn.Module):
prefix=add_prefix("shared_experts", prefix),
**(
dict(tp_rank=0, tp_size=1)
if global_server_args_dict["enable_deepep_moe"]
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
)
@@ -404,9 +402,9 @@ class DeepseekV2MoE(nn.Module):
self.top_k = config.num_experts_per_tok
if global_server_args_dict["enable_deepep_moe"]:
if global_server_args_dict["moe_a2a_backend"].is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_tensor_model_parallel_world_size()
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.n_routed_experts
+ global_server_args_dict["ep_num_redundant_experts"]
@@ -428,12 +426,12 @@ class DeepseekV2MoE(nn.Module):
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
deepep_mode=global_server_args_dict["deepep_mode"],
async_finish=True,
return_recv_hook=True,
)
self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
def get_moe_weights(self):
return [
@@ -2104,11 +2102,8 @@ class DeepseekV2ForCausalLM(nn.Module):
or self.config.n_shared_experts != 1
):
disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif (
global_server_args_dict["enable_deepep_moe"]
or global_server_args_dict["enable_ep_moe"]
):
disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
elif get_moe_expert_parallel_world_size() > 1:
disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True

View File

@@ -23,6 +23,7 @@ from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
@@ -50,7 +51,6 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import (
DeepEPMoE,
get_moe_impl_class,
should_use_flashinfer_trtllm_moe,
)
@@ -83,7 +83,6 @@ from sglang.srt.two_batch_overlap import (
)
from sglang.srt.utils import (
BumpAllocator,
DeepEPMode,
LazyValue,
add_prefix,
bind_or_assign,
@@ -443,15 +442,14 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["enable_deepep_moe"]
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
@@ -484,7 +482,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
prefix=add_prefix("shared_experts", prefix),
**(
dict(tp_rank=0, tp_size=1)
if global_server_args_dict["enable_deepep_moe"]
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
)
@@ -502,9 +500,9 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.top_k = config.num_experts_per_tok
if global_server_args_dict["enable_deepep_moe"]:
if global_server_args_dict["moe_a2a_backend"].is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_tensor_model_parallel_world_size()
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.n_routed_experts
+ global_server_args_dict["ep_num_redundant_experts"]
@@ -526,12 +524,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
deepep_mode=global_server_args_dict["deepep_mode"],
async_finish=True,
return_recv_hook=True,
)
self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
@@ -737,11 +735,8 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
or self.config.n_shared_experts != 1
):
disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif (
global_server_args_dict["enable_deepep_moe"]
or global_server_args_dict["enable_ep_moe"]
):
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
elif get_moe_expert_parallel_world_size() > 1:
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True

View File

@@ -29,6 +29,7 @@ from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
@@ -117,7 +118,7 @@ class Grok1MoE(nn.Module):
)
kwargs = {}
if global_server_args_dict["enable_ep_moe"]:
if get_moe_expert_parallel_world_size() > 1:
MoEImpl = EPMoE
else:
MoEImpl = FusedMoE
@@ -616,8 +617,7 @@ class Grok1ForCausalLM(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",

View File

@@ -24,6 +24,7 @@ from torch import nn
from transformers import MixtralConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
@@ -94,7 +95,7 @@ class MixtralMoE(nn.Module):
renormalize=True,
)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
MoEImpl = EPMoE if get_moe_expert_parallel_world_size() > 1 else FusedMoE
self.experts = MoEImpl(
num_experts=num_experts,
top_k=top_k,
@@ -398,8 +399,7 @@ class MixtralForCausalLM(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",

View File

@@ -148,7 +148,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
**(
dict(
enable_flashinfer_cutlass_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
@@ -616,9 +615,7 @@ class Qwen2MoeForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",

View File

@@ -24,6 +24,7 @@ import torch
from torch import nn
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -51,7 +52,6 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
@@ -72,7 +72,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_non_idle_and_non_empty
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
Qwen3MoeConfig = None
@@ -113,15 +113,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
quant_config=quant_config,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["enable_deepep_moe"]
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
@@ -136,9 +135,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
prefix=add_prefix("gate", prefix),
)
if global_server_args_dict["enable_deepep_moe"]:
if global_server_args_dict["moe_a2a_backend"].is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_tensor_model_parallel_world_size()
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
)
@@ -148,7 +147,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor:
if not global_server_args_dict["enable_deepep_moe"]:
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
return self.forward_normal(hidden_states)
else:
return self.forward_deepep(hidden_states, forward_batch)

View File

@@ -146,7 +146,7 @@ class Step3TextMoEMLP(nn.Module):
prefix=add_prefix("gate", prefix),
)
if global_server_args_dict["enable_deepep_moe"]:
if global_server_args_dict["moe_a2a_backend"].is_deepep():
raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

View File

@@ -4,7 +4,7 @@ from typing import List, Optional
import torch
from sglang.srt import operations
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPConfig
from sglang.srt.layers.moe.token_dispatcher import DeepEPConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.operations import Operation

View File

@@ -172,12 +172,11 @@ class ServerArgs:
# Expert parallelism
ep_size: int = 1
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
moe_a2a_backend: Optional[Literal["deepep"]] = None
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
enable_flashinfer_allreduce_fusion: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
ep_num_redundant_experts: int = 0
ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
init_expert_location: str = "trivial"
@@ -272,7 +271,27 @@ class ServerArgs:
enable_pdmux: bool = False
sm_group_num: int = 3
# Deprecated arguments
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
def __post_init__(self):
# Check deprecated arguments
def print_deprecated_warning(message: str):
logger.warning(f"\033[33m{message}\033[0m")
if self.enable_ep_moe:
self.ep_size = self.tp_size
print_deprecated_warning(
"NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead."
)
if self.enable_deepep_moe:
self.moe_a2a_backend = "deepep"
print_deprecated_warning(
"NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
)
# Set missing default values
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
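
The shim above rewrites the deprecated booleans into the new arguments before anything else reads them. An illustrative standalone helper performing the same translation (the function name here is hypothetical, not part of the commit):

def translate_deprecated_moe_flags(tp_size, enable_ep_moe, enable_deepep_moe):
    # --enable-ep-moe meant "ep size equals tp size";
    # --enable-deepep-moe meant "use the deepep A2A backend".
    ep_size = tp_size if enable_ep_moe else 1
    moe_a2a_backend = "deepep" if enable_deepep_moe else None
    return ep_size, moe_a2a_backend

assert translate_deprecated_moe_flags(8, True, True) == (8, "deepep")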
@@ -455,14 +474,13 @@ class ServerArgs:
self.quantization == "modelopt_fp4"
), "modelopt_fp4 quantization is required for Flashinfer MOE"
os.environ["TRTLLM_ENABLE_PDL"] = "1"
if self.enable_ep_moe:
self.ep_size = self.tp_size
logger.warning(
f"Flashinfer cutlass MoE and EP MoE are enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
assert self.ep_size in [
1,
self.tp_size,
], "The expert parallel size must be 1 or the same as the tensor parallel size"
# DeepEP MoE
if self.enable_deepep_moe:
if self.moe_a2a_backend == "deepep":
if self.deepep_mode == "normal":
logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
self.disable_cuda_graph = True
@@ -486,7 +504,7 @@ class ServerArgs:
)
if self.enable_eplb:
assert self.enable_ep_moe or self.enable_deepep_moe
assert self.ep_size > 1 or self.moe_a2a_backend is not None
if self.enable_expert_distribution_metrics and (
self.expert_distribution_recorder_mode is None
@@ -1354,30 +1372,27 @@ class ServerArgs:
help="The expert parallelism size.",
)
parser.add_argument(
"--enable-ep-moe",
action="store_true",
help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
"--moe-a2a-backend",
type=str,
choices=["deepep"],
default=ServerArgs.moe_a2a_backend,
help="Choose the backend for MoE A2A.",
)
parser.add_argument(
"--enable-flashinfer-cutlass-moe",
action="store_true",
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action="store_true",
help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP with --enable-ep-moe",
help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-allreduce-fusion",
action="store_true",
help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
)
parser.add_argument(
"--enable-deepep-moe",
action="store_true",
help="Enabling DeepEP MoE implementation for EP MoE.",
)
parser.add_argument(
"--deepep-mode",
type=str,
@@ -1839,6 +1854,18 @@ class ServerArgs:
help="Disable mmap while loading weight using safetensors.",
)
# Deprecated arguments
parser.add_argument(
"--enable-ep-moe",
action="store_true",
help="(Deprecated) Enabling expert parallelism for moe. The ep size is equal to the tp size.",
)
parser.add_argument(
"--enable-deepep-moe",
action="store_true",
help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
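
Putting the flag changes together, a hedged before/after launch example (assuming the usual sglang.launch_server entrypoint; the model path and parallel sizes are placeholders):

# Before this commit:
#   python -m sglang.launch_server --model-path <model> --tp-size 8 --enable-ep-moe --enable-deepep-moe
# After this commit (the deprecated flags are still accepted and rewritten to the new ones with a warning):
#   python -m sglang.launch_server --model-path <model> --tp-size 8 --ep-size 8 --moe-a2a-backend deepep --deepep-mode auto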

View File

@@ -13,17 +13,18 @@ from sglang.srt.layers.communicator import (
CommunicateSummableTensorPairFn,
ScatterMode,
)
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
from sglang.srt.utils import BumpAllocator, get_bool_env_var
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DispatchOutput
from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
@@ -310,7 +311,7 @@ class TboDPAttentionPreparer:
and not local_batch.forward_mode.is_target_verify()
)
and enable_deepep_moe
and (resolved_deepep_mode == DeepEPMode.low_latency)
and (resolved_deepep_mode == DeepEPMode.LOW_LATENCY)
)
else:
self.local_tbo_split_seq_index = 0

View File

@@ -2205,27 +2205,6 @@ def flatten_nested_list(nested_list):
return [nested_list]
class DeepEPMode(Enum):
normal = "normal"
low_latency = "low_latency"
auto = "auto"
def enable_normal(self):
return self in [DeepEPMode.normal, DeepEPMode.auto]
def enable_low_latency(self):
return self in [DeepEPMode.low_latency, DeepEPMode.auto]
def resolve(self, is_extend_in_batch: bool):
if self != DeepEPMode.auto:
return self
if is_extend_in_batch:
return DeepEPMode.normal
else:
return DeepEPMode.low_latency
def is_non_idle_and_non_empty(forward_mode, hidden_states):
return (
(forward_mode is not None)
@@ -2414,7 +2393,7 @@ def require_mlp_tp_gather(server_args):
return True
elif not server_args.enable_dp_lm_head:
return True
elif not server_args.enable_deepep_moe:
elif server_args.moe_a2a_backend is None:
return True
else:
return (
@@ -2430,7 +2409,7 @@ def require_attn_tp_gather(server_args):
Check if the input of attention is scattered.
"""
assert server_args.moe_dense_tp_size in [1, None]
if server_args.enable_deepep_moe or server_args.moe_dense_tp_size == 1:
if server_args.moe_a2a_backend is not None or server_args.moe_dense_tp_size == 1:
if server_args.enable_dp_attention:
return server_args.dp_size < server_args.tp_size
else:

View File

@@ -499,7 +499,6 @@ class SRTRunner:
chunked_prefill_size: Optional[int] = None,
dp_size: int = 1,
tokenizer_path: Optional[str] = None,
enable_ep_moe: bool = False,
mem_fraction_static: float = 0.65,
trust_remote_code: bool = False,
speculative_draft_model_path: Optional[str] = None,
@@ -550,7 +549,6 @@ class SRTRunner:
enable_dp_attention=enable_dp_attention,
dp_size=dp_size,
tokenizer_path=tokenizer_path,
enable_ep_moe=enable_ep_moe,
disable_overlap_schedule=disable_overlap_schedule,
cuda_graph_max_bs=cuda_graph_max_bs,
disable_custom_all_reduce=disable_custom_all_reduce,