[5/N] MoE Refactor: Update MoE parallelism arguments (#8658)
@@ -288,12 +288,14 @@ class _SinglePassGatherer(ABC):
)

if server_args.expert_distribution_recorder_mode == "stat_approx":
if server_args.enable_deepep_moe and (server_args.deepep_mode == "normal"):
if server_args.moe_a2a_backend is not None and (
server_args.deepep_mode == "normal"
):
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
else:
raise NotImplementedError

if server_args.enable_deepep_moe:
if server_args.moe_a2a_backend is not None:
if server_args.deepep_mode == "normal":
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
elif server_args.deepep_mode == "low_latency":
@@ -108,7 +108,7 @@ class LayerScatterModes:
if context.is_layer_sparse:
return (
ScatterMode.SCATTERED
if global_server_args_dict["enable_deepep_moe"]
if not global_server_args_dict["moe_a2a_backend"].is_standard()
else ScatterMode.FULL
)
else:
@@ -1,28 +1,17 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Optional

import torch

from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
gelu_and_mul_triton_kernel,
grouped_gemm_triton,
moe_ep_deepgemm_preprocess,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
pre_reorder_triton_kernel_for_cutlass_moe,
run_cutlass_moe_ep_preproess,
run_moe_ep_preproess,
silu_and_mul_masked_post_quant_fwd,
silu_and_mul_triton_kernel,
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import (

@@ -31,11 +20,9 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import (
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import (
Fp8Config,
Fp8MoEMethod,

@@ -44,23 +31,13 @@ from sglang.srt.layers.quantization.fp8 import (
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
sglang_per_token_group_quant_fp8,
sglang_per_token_quant_fp8,
)
from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import (
DeepEPMode,
ceil_div,
dispose_tensor,
get_bool_env_var,
is_hip,
is_npu,
)
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu

if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
from sglang.srt.layers.moe.token_dispatcher import (
DeepEPLLOutput,
DeepEPNormalOutput,
DispatchOutput,

@@ -119,7 +96,6 @@ class EPMoE(FusedMoE):
activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
enable_ep_moe=True,
)

self.start_expert_id = self.moe_ep_rank * self.num_local_experts

@@ -328,7 +304,7 @@ class DeepEPMoE(EPMoE):
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
deepep_mode: DeepEPMode = DeepEPMode.auto,
deepep_mode: DeepEPMode = DeepEPMode.AUTO,
):
super().__init__(
num_experts=num_experts,

@@ -348,7 +324,6 @@ class DeepEPMoE(EPMoE):

# TODO: move to the beginning of the file
from sglang.srt.distributed.parallel_state import get_tp_group
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher

self.deepep_dispatcher = MaybeTboDeepEPDispatcher(

@@ -762,11 +737,10 @@ class FlashInferEPMoE(EPMoE):


def get_moe_impl_class():
if global_server_args_dict["enable_deepep_moe"]:
if global_server_args_dict["moe_a2a_backend"].is_deepep():
return DeepEPMoE
if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
# Must come before EPMoE because FusedMoE also supports enable_ep_moe
return FusedMoE
if global_server_args_dict["enable_ep_moe"]:
if get_moe_expert_parallel_world_size() > 1:
return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
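For readers of the hunk above, a standalone sketch of the new selection order in get_moe_impl_class() after this refactor. Class names are shown as strings and the flag arguments are hypothetical parameters; the real function returns the classes imported in ep_moe/layer.py and reads the flags from global_server_args_dict.

def pick_moe_impl(
    a2a_is_deepep: bool,
    enable_flashinfer_cutlass_moe: bool,
    moe_ep_world_size: int,
    use_flashinfer_trtllm: bool,
) -> str:
    if a2a_is_deepep:
        return "DeepEPMoE"
    if enable_flashinfer_cutlass_moe:
        # Checked before the EP branch because FusedMoE also covers EP > 1.
        return "FusedMoE"
    if moe_ep_world_size > 1:
        return "FlashInferEPMoE" if use_flashinfer_trtllm else "EPMoE"
    return "FlashInferFusedMoE" if use_flashinfer_trtllm else "FusedMoE"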
@@ -14,8 +14,6 @@ from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_moe_tensor_parallel_rank,
get_moe_tensor_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata

@@ -94,7 +92,6 @@ class FusedMoE(torch.nn.Module):
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_cutlass_moe: Optional[bool] = False,
enable_ep_moe: Optional[bool] = False,
):
super().__init__()

@@ -112,7 +109,6 @@ class FusedMoE(torch.nn.Module):
if enable_flashinfer_cutlass_moe and quant_config is None:
logger.warning("Disable flashinfer MoE when quantization config is None.")
enable_flashinfer_cutlass_moe = False
enable_ep_moe = False

self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
self.moe_ep_size = get_moe_expert_parallel_world_size()

@@ -121,7 +117,7 @@ class FusedMoE(torch.nn.Module):
self.moe_tp_rank = get_moe_tensor_parallel_rank()
assert num_experts % self.moe_ep_size == 0
self.num_local_experts = num_experts // self.moe_ep_size
if enable_ep_moe:
if self.moe_ep_size > 1:
# TODO(ch-wan): support shared experts fusion
# Create a tensor of size num_experts filled with -1
self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
@@ -0,0 +1,23 @@
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
BaseDispatcher,
BaseDispatcherConfig,
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.token_dispatcher.deepep import (
DeepEPConfig,
DeepEPDispatcher,
DeepEPLLOutput,
DeepEPNormalOutput,
)

__all__ = [
"BaseDispatcher",
"BaseDispatcherConfig",
"DispatchOutput",
"DispatchOutputFormat",
"DeepEPConfig",
"DeepEPDispatcher",
"DeepEPNormalOutput",
"DeepEPLLOutput",
]
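This new package __init__ re-exports the dispatcher types, so callers elsewhere in this PR switch from the old ep_moe path to the consolidated one. A minimal import sketch, assuming the patched package is importable:

# Old path (still referenced in files not yet migrated):
#   from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
# New consolidated path introduced by this file:
from sglang.srt.layers.moe.token_dispatcher import (
    DeepEPDispatcher,
    DeepEPLLOutput,
    DeepEPNormalOutput,
    DispatchOutput,
)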
@@ -2,11 +2,22 @@ from __future__ import annotations

from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import TYPE_CHECKING, NamedTuple, Protocol, runtime_checkable
from typing import Protocol, runtime_checkable

import torch


class MoEA2ABackend(Enum):
none = "none"
deepep = "deepep"

def is_none(self):
return self == MoEA2ABackend.none

def is_deepep(self):
return self == MoEA2ABackend.deepep


class DispatchOutputFormat(Enum):
standard = auto()
deepep_normal = auto()
@@ -1,5 +1,3 @@
# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py

from __future__ import annotations

import logging

@@ -22,15 +20,10 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
DeepEPMode,
get_bool_env_var,
get_int_env_var,
is_hip,
load_json_config,
)
from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config

try:
from deep_ep import Buffer, Config

@@ -150,9 +143,9 @@ class DeepEPBuffer:
num_rdma_bytes,
)

if deepep_mode == DeepEPMode.normal:
if deepep_mode == DeepEPMode.NORMAL:
num_qps_per_rank = DeepEPConfig.get_instance().num_sms // 2
elif deepep_mode in [DeepEPMode.low_latency, DeepEPMode.auto]:
elif deepep_mode in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]:
num_qps_per_rank = num_experts // group.size()
else:
raise NotImplementedError

@@ -161,7 +154,7 @@ class DeepEPBuffer:
device="cuda"
).multi_processor_count
if (
(deepep_mode != DeepEPMode.low_latency)
(deepep_mode != DeepEPMode.LOW_LATENCY)
and not global_server_args_dict["enable_two_batch_overlap"]
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
):

@@ -611,7 +604,7 @@ class DeepEPDispatcher(BaseDispatcher):
num_local_experts: int = None,
hidden_size: int = None,
params_dtype: torch.dtype = None,
deepep_mode: DeepEPMode = DeepEPMode.auto,
deepep_mode: DeepEPMode = DeepEPMode.AUTO,
async_finish: bool = False,
return_recv_hook: bool = False,
):

@@ -697,9 +690,9 @@ class DeepEPDispatcher(BaseDispatcher):
resolved_deepep_mode = self.deepep_mode.resolve(
forward_batch.is_extend_in_batch
)
if resolved_deepep_mode == DeepEPMode.normal:
if resolved_deepep_mode == DeepEPMode.NORMAL:
return self._normal_dispatcher
elif resolved_deepep_mode == DeepEPMode.low_latency:
elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
return self._low_latency_dispatcher
else:
raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
python/sglang/srt/layers/moe/utils.py (new file, 43 lines)
@@ -0,0 +1,43 @@
from enum import Enum


class MoeA2ABackend(Enum):

STANDARD = ("standard", "none")
DEEPEP = "deepep"

@classmethod
def _missing_(cls, value):
if value is None:
return cls.STANDARD
for member in cls:
if value in member.value:
return member
raise ValueError(f"No {cls.__name__} member for value {value}")

def is_deepep(self):
return self == MoeA2ABackend.DEEPEP

def is_standard(self):
return self == MoeA2ABackend.STANDARD


class DeepEPMode(Enum):
NORMAL = "normal"
LOW_LATENCY = "low_latency"
AUTO = "auto"

def enable_normal(self):
return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]

def enable_low_latency(self):
return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]

def resolve(self, is_extend_in_batch: bool):
if self != DeepEPMode.AUTO:
return self

if is_extend_in_batch:
return DeepEPMode.NORMAL
else:
return DeepEPMode.LOW_LATENCY
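A small usage sketch based on the enum definitions above (assumes the patched sglang package is importable). Note that because the members are now uppercase and carry string values, lookup elsewhere in this PR switches from by-name (DeepEPMode[...]) to by-value (DeepEPMode(...)).

from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend

# "--moe-a2a-backend" left unset (None) maps to STANDARD via _missing_.
assert MoeA2ABackend(None) is MoeA2ABackend.STANDARD
assert MoeA2ABackend("deepep").is_deepep()

# DeepEPMode.AUTO resolves per batch: extend/prefill uses NORMAL, decode uses LOW_LATENCY.
assert DeepEPMode("auto").resolve(is_extend_in_batch=True) is DeepEPMode.NORMAL
assert DeepEPMode("auto").resolve(is_extend_in_batch=False) is DeepEPMode.LOW_LATENCY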
@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,

@@ -85,9 +86,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_dp_attention",
"enable_two_batch_overlap",
"enable_dp_lm_head",
"enable_deepep_moe",
"moe_a2a_backend",
"deepep_mode",
"enable_ep_moe",
"enable_flashinfer_cutlass_moe",
"enable_flashinfer_trtllm_moe",
"enable_flashinfer_allreduce_fusion",
@@ -64,6 +64,7 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,

@@ -137,7 +138,6 @@ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
from sglang.srt.utils import (
DeepEPMode,
DynamicGradMode,
broadcast_pyobj,
configure_gc_logger,

@@ -1762,8 +1762,10 @@ class Scheduler(
spec_algorithm=self.spec_algorithm,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
enable_deepep_moe=self.server_args.enable_deepep_moe,
deepep_mode=DeepEPMode[self.server_args.deepep_mode],
enable_deepep_moe=MoeA2ABackend(
self.server_args.moe_a2a_backend
).is_deepep(),
deepep_mode=DeepEPMode(self.server_args.deepep_mode),
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
)
@@ -38,6 +38,7 @@ import torch
import triton
import triton.language as tl

from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.dp_attention import (
DPPaddingMode,
get_attention_dp_rank,

@@ -839,7 +840,7 @@ class ForwardBatch:


def enable_num_token_non_padded(server_args):
return server_args.enable_ep_moe or server_args.enable_deepep_moe
return get_moe_expert_parallel_world_size() > 1


class PPProxyTensors:
@@ -60,6 +60,7 @@ from sglang.srt.layers.dp_attention import (
initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.layers.quantization import (
deep_gemm_wrapper,
monkey_patch_isinstance_for_vllm_base_layer,

@@ -217,6 +218,10 @@ class ModelRunner:
"use_mla_backend": self.use_mla_backend,
"speculative_algorithm": self.spec_algorithm,
}
| {
"moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
"deepep_mode": DeepEPMode(server_args.deepep_mode),
}
)

# CPU offload
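With the hunk above, global_server_args_dict carries enum instances instead of raw strings for these two keys, which is why the model code later in this PR can call .is_deepep() or forward deepep_mode directly. A small illustration (values are hypothetical; the keys come from GLOBAL_SERVER_ARGS_KEYS shown earlier):

from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend

# What the merged dict roughly contains after ModelRunner initialization:
server_args_sketch = {
    "moe_a2a_backend": MoeA2ABackend("deepep"),  # or MoeA2ABackend(None) -> STANDARD
    "deepep_mode": DeepEPMode("auto"),
}
assert server_args_sketch["moe_a2a_backend"].is_deepep()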
@@ -29,6 +29,7 @@ from tqdm import tqdm
from transformers import PretrainedConfig

from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,

@@ -61,7 +62,6 @@ from sglang.srt.layers.moe.ep_moe.layer import (
get_moe_impl_class,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig

@@ -96,7 +96,6 @@ from sglang.srt.two_batch_overlap import (
)
from sglang.srt.utils import (
BumpAllocator,
DeepEPMode,
LazyValue,
add_prefix,
bind_or_assign,

@@ -333,15 +332,14 @@ class DeepseekV2MoE(nn.Module):
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["enable_deepep_moe"]
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}

@@ -374,7 +372,7 @@ class DeepseekV2MoE(nn.Module):
prefix=add_prefix("shared_experts", prefix),
**(
dict(tp_rank=0, tp_size=1)
if global_server_args_dict["enable_deepep_moe"]
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
)

@@ -404,9 +402,9 @@ class DeepseekV2MoE(nn.Module):

self.top_k = config.num_experts_per_tok

if global_server_args_dict["enable_deepep_moe"]:
if global_server_args_dict["moe_a2a_backend"].is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_tensor_model_parallel_world_size()
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.n_routed_experts
+ global_server_args_dict["ep_num_redundant_experts"]

@@ -428,12 +426,12 @@ class DeepseekV2MoE(nn.Module):
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
deepep_mode=global_server_args_dict["deepep_mode"],
async_finish=True,
return_recv_hook=True,
)

self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()

def get_moe_weights(self):
return [

@@ -2104,11 +2102,8 @@ class DeepseekV2ForCausalLM(nn.Module):
or self.config.n_shared_experts != 1
):
disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif (
global_server_args_dict["enable_deepep_moe"]
or global_server_args_dict["enable_ep_moe"]
):
disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
elif get_moe_expert_parallel_world_size() > 1:
disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."

if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
@@ -23,6 +23,7 @@ from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,

@@ -50,7 +51,6 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import (
DeepEPMoE,
get_moe_impl_class,
should_use_flashinfer_trtllm_moe,
)

@@ -83,7 +83,6 @@ from sglang.srt.two_batch_overlap import (
)
from sglang.srt.utils import (
BumpAllocator,
DeepEPMode,
LazyValue,
add_prefix,
bind_or_assign,

@@ -443,15 +442,14 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["enable_deepep_moe"]
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}

@@ -484,7 +482,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
prefix=add_prefix("shared_experts", prefix),
**(
dict(tp_rank=0, tp_size=1)
if global_server_args_dict["enable_deepep_moe"]
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
)

@@ -502,9 +500,9 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):

self.top_k = config.num_experts_per_tok

if global_server_args_dict["enable_deepep_moe"]:
if global_server_args_dict["moe_a2a_backend"].is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_tensor_model_parallel_world_size()
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.n_routed_experts
+ global_server_args_dict["ep_num_redundant_experts"]

@@ -526,12 +524,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
deepep_mode=global_server_args_dict["deepep_mode"],
async_finish=True,
return_recv_hook=True,
)

self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()


class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):

@@ -737,11 +735,8 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
or self.config.n_shared_experts != 1
):
disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
elif (
global_server_args_dict["enable_deepep_moe"]
or global_server_args_dict["enable_ep_moe"]
):
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
elif get_moe_expert_parallel_world_size() > 1:
disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."

if disable_reason is not None:
global_server_args_dict["disable_shared_experts_fusion"] = True
@@ -29,6 +29,7 @@ from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,

@@ -117,7 +118,7 @@ class Grok1MoE(nn.Module):
)

kwargs = {}
if global_server_args_dict["enable_ep_moe"]:
if get_moe_expert_parallel_world_size() > 1:
MoEImpl = EPMoE
else:
MoEImpl = FusedMoE

@@ -616,8 +617,7 @@ class Grok1ForCausalLM(nn.Module):

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
@@ -24,6 +24,7 @@ from torch import nn
from transformers import MixtralConfig

from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,

@@ -94,7 +95,7 @@ class MixtralMoE(nn.Module):
renormalize=True,
)

MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
MoEImpl = EPMoE if get_moe_expert_parallel_world_size() > 1 else FusedMoE
self.experts = MoEImpl(
num_experts=num_experts,
top_k=top_k,

@@ -398,8 +399,7 @@ class MixtralForCausalLM(nn.Module):

# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
expert_params_mapping = MoEImpl.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="w1",
ckpt_down_proj_name="w2",
ckpt_up_proj_name="w3",
@@ -148,7 +148,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
**(
dict(
enable_flashinfer_cutlass_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}

@@ -616,9 +615,7 @@ class Qwen2MoeForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]

MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE

expert_params_mapping = MoEImpl.make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
@@ -24,6 +24,7 @@ import torch
from torch import nn

from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,

@@ -51,7 +52,6 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention

@@ -72,7 +72,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_non_idle_and_non_empty
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty

Qwen3MoeConfig = None

@@ -113,15 +113,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
quant_config=quant_config,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["enable_deepep_moe"]
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
enable_ep_moe=global_server_args_dict["enable_ep_moe"],
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}

@@ -136,9 +135,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
prefix=add_prefix("gate", prefix),
)

if global_server_args_dict["enable_deepep_moe"]:
if global_server_args_dict["moe_a2a_backend"].is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_tensor_model_parallel_world_size()
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
)

@@ -148,7 +147,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor:

if not global_server_args_dict["enable_deepep_moe"]:
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
return self.forward_normal(hidden_states)
else:
return self.forward_deepep(hidden_states, forward_batch)

@@ -146,7 +146,7 @@ class Step3TextMoEMLP(nn.Module):
prefix=add_prefix("gate", prefix),
)

if global_server_args_dict["enable_deepep_moe"]:
if global_server_args_dict["moe_a2a_backend"].is_deepep():
raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -4,7 +4,7 @@ from typing import List, Optional
import torch

from sglang.srt import operations
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPConfig
from sglang.srt.layers.moe.token_dispatcher import DeepEPConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.operations import Operation
@@ -172,12 +172,11 @@ class ServerArgs:

# Expert parallelism
ep_size: int = 1
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
moe_a2a_backend: Optional[Literal["deepep"]] = None
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
enable_flashinfer_allreduce_fusion: bool = False
deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
ep_num_redundant_experts: int = 0
ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
init_expert_location: str = "trivial"

@@ -272,7 +271,27 @@ class ServerArgs:
enable_pdmux: bool = False
sm_group_num: int = 3

# Deprecated arguments
enable_ep_moe: bool = False
enable_deepep_moe: bool = False

def __post_init__(self):

# Check deprecated arguments
def print_deprecated_warning(message: str):
logger.warning(f"\033[33m{message}\033[0m")

if self.enable_ep_moe:
self.ep_size = self.tp_size
print_deprecated_warning(
"NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead."
)
if self.enable_deepep_moe:
self.moe_a2a_backend = "deepep"
print_deprecated_warning(
"NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
)

# Set missing default values
if self.tokenizer_path is None:
self.tokenizer_path = self.model_path
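To summarize the hunk above, the deprecated flags are mapped onto the new arguments inside __post_init__. A standalone sketch of that mapping (simplified and hypothetical function name; the real code also emits the colored deprecation warnings shown above):

def normalize_moe_args(tp_size, ep_size, enable_ep_moe, enable_deepep_moe, moe_a2a_backend):
    if enable_ep_moe:  # deprecated --enable-ep-moe
        ep_size = tp_size  # new equivalent: --ep-size set to the same value as --tp-size
    if enable_deepep_moe:  # deprecated --enable-deepep-moe
        moe_a2a_backend = "deepep"  # new equivalent: --moe-a2a-backend deepep
    return ep_size, moe_a2a_backend

assert normalize_moe_args(8, 1, True, True, None) == (8, "deepep")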
@@ -455,14 +474,13 @@ class ServerArgs:
self.quantization == "modelopt_fp4"
), "modelopt_fp4 quantization is required for Flashinfer MOE"
os.environ["TRTLLM_ENABLE_PDL"] = "1"
if self.enable_ep_moe:
self.ep_size = self.tp_size
logger.warning(
f"Flashinfer cutlass MoE and EP MoE are enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
assert self.ep_size in [
1,
self.tp_size,
], "The expert parallel size must be 1 or the same as the tensor parallel size"

# DeepEP MoE
if self.enable_deepep_moe:
if self.moe_a2a_backend == "deepep":
if self.deepep_mode == "normal":
logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
self.disable_cuda_graph = True

@@ -486,7 +504,7 @@ class ServerArgs:
)

if self.enable_eplb:
assert self.enable_ep_moe or self.enable_deepep_moe
assert self.ep_size > 1 or self.moe_a2a_backend is not None

if self.enable_expert_distribution_metrics and (
self.expert_distribution_recorder_mode is None

@@ -1354,30 +1372,27 @@ class ServerArgs:
help="The expert parallelism size.",
)
parser.add_argument(
"--enable-ep-moe",
action="store_true",
help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
"--moe-a2a-backend",
type=str,
choices=["deepep"],
default=ServerArgs.moe_a2a_backend,
help="Choose the backend for MoE A2A.",
)
parser.add_argument(
"--enable-flashinfer-cutlass-moe",
action="store_true",
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action="store_true",
help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP with --enable-ep-moe",
help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-allreduce-fusion",
action="store_true",
help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
)
parser.add_argument(
"--enable-deepep-moe",
action="store_true",
help="Enabling DeepEP MoE implementation for EP MoE.",
)
parser.add_argument(
"--deepep-mode",
type=str,

@@ -1839,6 +1854,18 @@ class ServerArgs:
help="Disable mmap while loading weight using safetensors.",
)

# Deprecated arguments
parser.add_argument(
"--enable-ep-moe",
action="store_true",
help="(Deprecated) Enabling expert parallelism for moe. The ep size is equal to the tp size.",
)
parser.add_argument(
"--enable-deepep-moe",
action="store_true",
help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
)

@classmethod
def from_cli_args(cls, args: argparse.Namespace):
args.tp_size = args.tensor_parallel_size
@@ -13,17 +13,18 @@ from sglang.srt.layers.communicator import (
CommunicateSummableTensorPairFn,
ScatterMode,
)
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import BumpAllocator, DeepEPMode, get_bool_env_var
from sglang.srt.utils import BumpAllocator, get_bool_env_var

if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DispatchOutput
from sglang.srt.layers.moe.token_dispatcher import DispatchOutput

_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")

@@ -310,7 +311,7 @@ class TboDPAttentionPreparer:
and not local_batch.forward_mode.is_target_verify()
)
and enable_deepep_moe
and (resolved_deepep_mode == DeepEPMode.low_latency)
and (resolved_deepep_mode == DeepEPMode.LOW_LATENCY)
)
else:
self.local_tbo_split_seq_index = 0
@@ -2205,27 +2205,6 @@ def flatten_nested_list(nested_list):
return [nested_list]


class DeepEPMode(Enum):
normal = "normal"
low_latency = "low_latency"
auto = "auto"

def enable_normal(self):
return self in [DeepEPMode.normal, DeepEPMode.auto]

def enable_low_latency(self):
return self in [DeepEPMode.low_latency, DeepEPMode.auto]

def resolve(self, is_extend_in_batch: bool):
if self != DeepEPMode.auto:
return self

if is_extend_in_batch:
return DeepEPMode.normal
else:
return DeepEPMode.low_latency


def is_non_idle_and_non_empty(forward_mode, hidden_states):
return (
(forward_mode is not None)

@@ -2414,7 +2393,7 @@ def require_mlp_tp_gather(server_args):
return True
elif not server_args.enable_dp_lm_head:
return True
elif not server_args.enable_deepep_moe:
elif server_args.moe_a2a_backend is None:
return True
else:
return (

@@ -2430,7 +2409,7 @@ def require_attn_tp_gather(server_args):
Check if the input of attention is scattered.
"""
assert server_args.moe_dense_tp_size in [1, None]
if server_args.enable_deepep_moe or server_args.moe_dense_tp_size == 1:
if server_args.moe_a2a_backend is not None or server_args.moe_dense_tp_size == 1:
if server_args.enable_dp_attention:
return server_args.dp_size < server_args.tp_size
else:
@@ -499,7 +499,6 @@ class SRTRunner:
chunked_prefill_size: Optional[int] = None,
dp_size: int = 1,
tokenizer_path: Optional[str] = None,
enable_ep_moe: bool = False,
mem_fraction_static: float = 0.65,
trust_remote_code: bool = False,
speculative_draft_model_path: Optional[str] = None,

@@ -550,7 +549,6 @@ class SRTRunner:
enable_dp_attention=enable_dp_attention,
dp_size=dp_size,
tokenizer_path=tokenizer_path,
enable_ep_moe=enable_ep_moe,
disable_overlap_schedule=disable_overlap_schedule,
cuda_graph_max_bs=cuda_graph_max_bs,
disable_custom_all_reduce=disable_custom_all_reduce,