[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)

Cheng Wan
2025-08-14 21:14:53 -07:00
committed by GitHub
parent 584e1ab2d0
commit 295895120d
69 changed files with 956 additions and 1037 deletions


@@ -37,6 +37,7 @@ from sglang.srt.utils import (
is_hip,
is_port_available,
is_remote_url,
is_triton_kernels_available,
is_valid_ipv6_address,
nullable_str,
)
@@ -175,9 +176,15 @@ class ServerArgs:
# Expert parallelism
ep_size: int = 1
moe_a2a_backend: Optional[Literal["deepep"]] = None
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
moe_a2a_backend: Literal["none", "deepep"] = "none"
moe_runner_backend: Literal[
"auto",
"triton",
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
] = "auto"
enable_flashinfer_allreduce_fusion: bool = False
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
ep_num_redundant_experts: int = 0
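
These field changes fold the separate `enable_flashinfer_cutlass_moe`, `enable_flashinfer_trtllm_moe`, and (below) `enable_triton_kernel_moe` booleans into a single `moe_runner_backend` enum, and give `moe_a2a_backend` an explicit `"none"` default instead of `Optional[...] = None`. A minimal sketch, not the actual dispatch code, of how a call site can branch on the one field:

from typing import Literal

MoeRunnerBackend = Literal[
    "auto", "triton", "triton_kernel",
    "flashinfer_trtllm", "flashinfer_cutlass", "flashinfer_mxfp4",
]

def describe_runner(backend: MoeRunnerBackend) -> str:
    # Illustrative dispatch: one enum check replaces three boolean flags.
    if backend.startswith("flashinfer_"):
        return f"FlashInfer runner ({backend.removeprefix('flashinfer_')})"
    if backend == "triton_kernel":
        return "triton_kernels grouped-GEMM runner"
    return "default fused Triton runner"  # covers "auto" and "triton"

print(describe_runner("flashinfer_trtllm"))  # FlashInfer runner (trtllm)
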
@@ -250,8 +257,6 @@ class ServerArgs:
disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
enable_return_hidden_states: bool = False
enable_triton_kernel_moe: bool = False
enable_flashinfer_mxfp4_moe: bool = False
scheduler_recv_interval: int = 1
# Debug tensor dumps
@@ -282,6 +287,9 @@ class ServerArgs:
# Deprecated arguments
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
enable_triton_kernel_moe: bool = False
def __post_init__(self):
# Check deprecated arguments
@@ -298,6 +306,21 @@ class ServerArgs:
print_deprecated_warning(
"NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
)
if self.enable_triton_kernel_moe:
self.moe_runner_backend = "triton_kernel"
print_deprecated_warning(
"NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
)
if self.enable_flashinfer_cutlass_moe:
self.moe_runner_backend = "flashinfer_cutlass"
print_deprecated_warning(
"NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
)
if self.enable_flashinfer_trtllm_moe:
self.moe_runner_backend = "flashinfer_trtllm"
print_deprecated_warning(
"NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
)
# Set missing default values
if self.tokenizer_path is None:
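
The deprecated booleans are kept only as shims: `__post_init__` maps each one onto the new enum and prints a deprecation warning. Illustrative migration of a launch command (model path is a placeholder):

# Old spelling, still accepted with a deprecation warning:
python -m sglang.launch_server --model-path <model> --enable-flashinfer-trtllm-moe

# New spelling:
python -m sglang.launch_server --model-path <model> --moe-runner-backend flashinfer_trtllm
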
@@ -517,7 +540,7 @@ class ServerArgs:
), "Please enable dp attention when setting enable_dp_lm_head. "
# MoE kernel
if self.enable_flashinfer_cutlass_moe:
if self.moe_runner_backend == "flashinfer_cutlass":
assert (
self.quantization == "modelopt_fp4"
), "modelopt_fp4 quantization is required for Flashinfer MOE"
@@ -527,7 +550,7 @@ class ServerArgs:
self.tp_size,
], "The expert parallel size must be 1 or the same as the tensor parallel size"
if self.enable_flashinfer_trtllm_moe:
if self.moe_runner_backend == "flashinfer_trtllm":
if not self.disable_shared_experts_fusion:
self.disable_shared_experts_fusion = True
logger.warning(
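
With the rename, the existing backend checks key off the enum value instead of the old booleans: `flashinfer_cutlass` still requires `modelopt_fp4` quantization and an EP size of 1 or equal to the TP size, and `flashinfer_trtllm` still disables shared-experts fusion with a warning. An illustrative command that satisfies the CUTLASS constraints (model path and parallel sizes are placeholders):

python -m sglang.launch_server --model-path <fp4-model> \
    --quantization modelopt_fp4 \
    --moe-runner-backend flashinfer_cutlass \
    --tp-size 8 --ep-size 8
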
@@ -556,7 +579,7 @@ class ServerArgs:
self.ep_dispatch_algorithm = "static"
if self.enable_eplb:
assert self.ep_size > 1 or self.moe_a2a_backend is not None
assert self.ep_size > 1
if self.enable_expert_distribution_metrics and (
self.expert_distribution_recorder_mode is None
@@ -1446,19 +1469,22 @@ class ServerArgs:
parser.add_argument(
"--moe-a2a-backend",
type=str,
choices=["deepep"],
choices=["none", "deepep"],
default=ServerArgs.moe_a2a_backend,
help="Choose the backend for MoE A2A.",
)
parser.add_argument(
"--enable-flashinfer-cutlass-moe",
action="store_true",
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action="store_true",
help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
"--moe-runner-backend",
type=str,
choices=[
"auto",
"triton",
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
],
default=ServerArgs.moe_runner_backend,
help="Choose the runner backend for MoE.",
)
parser.add_argument(
"--enable-flashinfer-allreduce-fusion",
@@ -1825,11 +1851,6 @@ class ServerArgs:
action="store_true",
help="Enable returning hidden states with responses.",
)
parser.add_argument(
"--enable-triton-kernel-moe",
action="store_true",
help="Use triton moe grouped gemm kernel.",
)
parser.add_argument(
"--enable-flashinfer-mxfp4-moe",
action="store_true",
@@ -1965,6 +1986,21 @@ class ServerArgs:
action="store_true",
help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
)
parser.add_argument(
"--enable-flashinfer-cutlass-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
)
parser.add_argument(
"--enable-triton-kernel-moe",
action="store_true",
help="(Deprecated) Use triton moe grouped gemm kernel.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
@@ -2143,18 +2179,21 @@ class ServerArgs:
)
if is_sm100_supported() and is_mxfp4_quant_format:
self.enable_flashinfer_mxfp4_moe = True
self.enable_triton_kernel_moe = False
self.moe_runner_backend = "flashinfer_mxfp4"
logger.warning(
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
)
else:
if self.enable_triton_kernel_moe:
if self.moe_runner_backend == "triton_kernel":
assert (
self.ep_size == 1
), "Triton kernel MoE is only supported when ep_size == 1"
if not self.enable_triton_kernel_moe and self.ep_size == 1:
self.enable_triton_kernel_moe = True
if (
self.moe_runner_backend == "auto"
and self.ep_size == 1
and is_triton_kernels_available()
):
self.moe_runner_backend = "triton_kernel"
logger.warning(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)
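
For GPT-OSS models the auto-detection now writes to `moe_runner_backend` instead of toggling booleans: SM100 plus MXFP4 quantization selects `flashinfer_mxfp4`; otherwise `triton_kernel` is chosen only if the backend is still `auto`, `ep_size == 1`, and the Triton kernels are importable. A condensed sketch of that precedence (the function and its arguments are illustrative, not part of the codebase):

def pick_gpt_oss_moe_backend(
    current: str,
    is_sm100: bool,
    is_mxfp4_quant: bool,
    ep_size: int,
    triton_kernels_available: bool,
) -> str:
    # Mirrors the order of checks in the hunk above.
    if is_sm100 and is_mxfp4_quant:
        return "flashinfer_mxfp4"
    if current == "triton_kernel":
        assert ep_size == 1, "Triton kernel MoE is only supported when ep_size == 1"
    if current == "auto" and ep_size == 1 and triton_kernels_available:
        return "triton_kernel"
    return current

print(pick_gpt_oss_moe_backend("auto", False, False, 1, True))  # triton_kernel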