[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)

This commit is contained in:
Cheng Wan
2025-08-14 21:14:53 -07:00
committed by GitHub
parent 584e1ab2d0
commit 295895120d
69 changed files with 956 additions and 1037 deletions

View File

@@ -61,7 +61,6 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed.parallel_state import destroy_distributed_environment
from sglang.srt.entrypoints.engine import _set_envs_and_config
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -300,11 +299,6 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
spec_algorithm=SpeculativeAlgorithm.NONE,
speculative_num_draft_tokens=None,
enable_two_batch_overlap=model_runner.server_args.enable_two_batch_overlap,
enable_deepep_moe=MoeA2ABackend(
model_runner.server_args.moe_a2a_backend
).is_deepep(),
deepep_mode=DeepEPMode(model_runner.server_args.deepep_mode),
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
)

View File

@@ -25,7 +25,6 @@ import torch
import torch.distributed
from sglang.srt.eplb.expert_location import ExpertLocationMetadata
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, get_bool_env_var
@@ -288,14 +287,14 @@ class _SinglePassGatherer(ABC):
)
if server_args.expert_distribution_recorder_mode == "stat_approx":
if server_args.moe_a2a_backend is not None and (
if server_args.moe_a2a_backend != "none" and (
server_args.deepep_mode == "normal"
):
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
else:
raise NotImplementedError
if server_args.moe_a2a_backend is not None:
if server_args.moe_a2a_backend != "none":
if server_args.deepep_mode == "normal":
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
elif server_args.deepep_mode == "low_latency":

View File

@@ -17,7 +17,7 @@ from enum import Enum, auto
from functools import partial
from typing import Dict, Optional
import torch.distributed
import torch
from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
@@ -35,6 +35,7 @@ from sglang.srt.layers.dp_attention import (
get_global_dp_buffer,
get_local_dp_buffer,
)
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -111,7 +112,7 @@ class LayerScatterModes:
if context.is_layer_sparse:
return (
ScatterMode.SCATTERED
if not global_server_args_dict["moe_a2a_backend"].is_standard()
if not get_moe_a2a_backend().is_none()
else ScatterMode.FULL
)
else:

View File

@@ -0,0 +1,29 @@
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.utils import (
DeepEPMode,
MoeA2ABackend,
MoeRunnerBackend,
get_deepep_config,
get_deepep_mode,
get_moe_a2a_backend,
get_moe_runner_backend,
get_tbo_token_distribution_threshold,
initialize_moe_config,
is_tbo_enabled,
should_use_flashinfer_trtllm_moe,
)
__all__ = [
"DeepEPMode",
"MoeA2ABackend",
"MoeRunnerConfig",
"MoeRunnerBackend",
"initialize_moe_config",
"get_moe_a2a_backend",
"get_moe_runner_backend",
"get_deepep_mode",
"should_use_flashinfer_trtllm_moe",
"is_tbo_enabled",
"get_tbo_token_distribution_threshold",
"get_deepep_config",
]

View File

@@ -1,11 +1,17 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union
import torch
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
@@ -16,14 +22,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import (
Fp8Config,
Fp8MoEMethod,
get_tile_tokens_dim,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
sglang_per_token_group_quant_fp8,
@@ -89,12 +90,11 @@ class EPMoE(FusedMoE):
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_clamp_limit: Optional[float] = None,
with_bias: bool = False,
):
super().__init__(
@@ -106,13 +106,12 @@ class EPMoE(FusedMoE):
top_k=top_k,
params_dtype=params_dtype,
quant_config=quant_config,
tp_size=tp_size,
prefix=prefix,
activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
gemm1_alpha=gemm1_alpha,
gemm1_clamp_limit=gemm1_clamp_limit,
with_bias=with_bias,
)
@@ -163,7 +162,8 @@ class EPMoE(FusedMoE):
)
assert self.quant_method is not None
assert self.activation == "silu"
assert self.moe_runner_config.activation == "silu"
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
@@ -327,8 +327,8 @@ class EPMoE(FusedMoE):
m_max * self.start_expert_id,
BLOCK_SIZE=512,
)
if self.routed_scaling_factor is not None:
output *= self.routed_scaling_factor
if self.moe_runner_config.routed_scaling_factor is not None:
output *= self.moe_runner_config.routed_scaling_factor
return output
@@ -349,11 +349,9 @@ class DeepEPMoE(EPMoE):
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
deepep_mode: DeepEPMode = DeepEPMode.AUTO,
):
super().__init__(
num_experts=num_experts,
@@ -364,12 +362,11 @@ class DeepEPMoE(EPMoE):
num_fused_shared_experts=num_fused_shared_experts,
params_dtype=params_dtype,
quant_config=quant_config,
tp_size=tp_size,
prefix=prefix,
activation=activation,
routed_scaling_factor=routed_scaling_factor,
)
self.deepep_mode = deepep_mode
self.deepep_mode = get_deepep_mode()
# TODO: move to the beginning of the file
from sglang.srt.distributed.parallel_state import get_tp_group
@@ -383,7 +380,7 @@ class DeepEPMoE(EPMoE):
num_local_experts=self.num_local_experts,
hidden_size=hidden_size,
params_dtype=params_dtype,
deepep_mode=deepep_mode,
deepep_mode=self.deepep_mode,
async_finish=True, # TODO
return_recv_hook=True,
)
@@ -458,15 +455,19 @@ class DeepEPMoE(EPMoE):
)
def moe_impl(self, dispatch_output: DispatchOutput):
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
if _use_aiter:
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
# In forward_aiter, we skip token permutation and unpermutation, which have been fused inside the aiter kernel.
return self.forward_aiter(dispatch_output)
if _is_npu:
assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
return self.forward_npu(dispatch_output)
if dispatch_output.format.is_deepep_normal():
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_contiguous(dispatch_output)
elif dispatch_output.format.is_deepep_ll():
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_masked(dispatch_output)
else:
@@ -490,7 +491,7 @@ class DeepEPMoE(EPMoE):
def forward_aiter(
self,
dispatch_output: DeepEPNormalOutput,
dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
):
hidden_states, topk_idx, topk_weights = (
dispatch_output.hidden_states,
@@ -516,7 +517,7 @@ class DeepEPMoE(EPMoE):
quant_type=QuantType.per_128x128,
activation=(
ActivationType.Silu
if self.activation == "silu"
if self.moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
expert_mask=self.expert_mask,
@@ -531,7 +532,7 @@ class DeepEPMoE(EPMoE):
)
hidden_states_fp8, hidden_states_scale = hidden_states_fp8
assert self.quant_method is not None
assert self.activation == "silu"
assert self.moe_runner_config.activation == "silu"
if num_recv_tokens_per_expert is None:
return hidden_states_fp8.bfloat16()
all_tokens = sum(num_recv_tokens_per_expert)
@@ -652,7 +653,7 @@ class DeepEPMoE(EPMoE):
):
hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.activation == "silu"
assert self.moe_runner_config.activation == "silu"
# GroupGemm-0
num_groups, m, k = hidden_states_fp8[0].size()
@@ -783,12 +784,12 @@ class DeepEPMoE(EPMoE):
def get_moe_impl_class():
if global_server_args_dict["moe_a2a_backend"].is_deepep():
if get_moe_a2a_backend().is_deepep():
return DeepEPMoE
# NEW: Direct FP4 detection (bypasses EP requirements)
# Check for FP4 quantization with TRTLLM flag, regardless of EP
if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False):
if get_moe_runner_backend().is_flashinfer_trtllm():
try:
# Check the quantization argument directly
quantization = global_server_args_dict.get("quantization")
@@ -803,7 +804,7 @@ def get_moe_impl_class():
if should_use_flashinfer_trtllm_moe():
return FlashInferFusedMoE
if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
if get_moe_runner_backend().is_flashinfer_cutlass():
return FusedMoE
if get_moe_expert_parallel_world_size() > 1:
return EPMoE

View File

@@ -3,28 +3,22 @@ Torch-native implementation for FusedMoE. This is used for torch.compile.
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
"""
from typing import Callable, Optional
import torch
from torch.nn import functional as F
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
def fused_moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
if apply_router_weight_on_input:
if moe_runner_config.apply_router_weight_on_input:
raise NotImplementedError()
topk_weights, topk_ids, _ = topk_output
@@ -33,12 +27,12 @@ def fused_moe_forward_native(
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = layer.w2_weight[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
if activation == "silu":
if moe_runner_config.activation == "silu":
x1 = F.silu(x1)
elif activation == "gelu":
elif moe_runner_config.activation == "gelu":
x1 = F.gelu(x1)
else:
raise ValueError(f"Unsupported activation: {activation=}")
raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -47,16 +41,11 @@ def fused_moe_forward_native(
def moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
if apply_router_weight_on_input:
if moe_runner_config.apply_router_weight_on_input:
raise NotImplementedError()
topk_weights, topk_ids, _ = topk_output
@@ -72,12 +61,12 @@ def moe_forward_native(
sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy()
if activation == "silu":
if moe_runner_config.activation == "silu":
act = SiluAndMul()
elif activation == "gelu":
elif moe_runner_config.activation == "gelu":
act = GeluAndMul()
else:
raise ValueError(f"Unsupported activation: {activation=}")
raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
outputs = []
start_idx = 0

View File

@@ -2,17 +2,20 @@
"""Fused MoE kernel."""
from __future__ import annotations
import functools
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
scaled_fp8_quant,
@@ -1025,8 +1028,8 @@ def inplace_fused_experts(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
) -> None:
fused_experts_impl(
hidden_states,
@@ -1053,8 +1056,8 @@ def inplace_fused_experts(
block_shape,
False,
routed_scaling_factor,
activation_alpha,
swiglu_limit,
gemm1_alpha,
gemm1_limit,
)
@@ -1081,8 +1084,8 @@ def inplace_fused_experts_fake(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
) -> None:
pass
@@ -1119,8 +1122,8 @@ def outplace_fused_experts(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
@@ -1147,8 +1150,8 @@ def outplace_fused_experts(
block_shape,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
gemm1_alpha=gemm1_alpha,
gemm1_limit=gemm1_limit,
)
@@ -1176,8 +1179,8 @@ def outplace_fused_experts_fake(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -1194,12 +1197,10 @@ def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_output: TopKOutput,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
@@ -1212,14 +1213,10 @@ def fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
):
topk_weights, topk_ids, _ = topk_output
if inplace:
assert not no_combine, "no combine + inplace makes no sense"
if moe_runner_config.inplace:
assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
torch.ops.sglang.inplace_fused_experts(
hidden_states,
w1,
@@ -1228,8 +1225,8 @@ def fused_experts(
topk_ids,
b1,
b2,
activation,
apply_router_weight_on_input,
moe_runner_config.activation,
moe_runner_config.apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
@@ -1242,9 +1239,9 @@ def fused_experts(
a1_scale,
a2_scale,
block_shape,
routed_scaling_factor,
activation_alpha,
swiglu_limit,
moe_runner_config.routed_scaling_factor,
moe_runner_config.gemm1_alpha,
moe_runner_config.gemm1_clamp_limit,
)
return hidden_states
else:
@@ -1256,8 +1253,8 @@ def fused_experts(
topk_ids,
b1,
b2,
activation,
apply_router_weight_on_input,
moe_runner_config.activation,
moe_runner_config.apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
use_int8_w8a16,
@@ -1270,10 +1267,10 @@ def fused_experts(
a1_scale,
a2_scale,
block_shape,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
no_combine=moe_runner_config.no_combine,
routed_scaling_factor=moe_runner_config.routed_scaling_factor,
gemm1_alpha=moe_runner_config.gemm1_alpha,
gemm1_limit=moe_runner_config.gemm1_clamp_limit,
)
@@ -1370,11 +1367,11 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
@torch.compile
def swiglu_with_alpha_and_limit(x, alpha, limit):
def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(min=None, max=limit)
up = up.clamp(min=-limit, max=limit)
return gate * torch.sigmoid(gate * alpha) * (up + 1)
gate = gate.clamp(min=None, max=gemm1_limit)
up = up.clamp(min=-gemm1_limit, max=gemm1_limit)
return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1)
def fused_experts_impl(
@@ -1402,8 +1399,8 @@ def fused_experts_impl(
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_limit: Optional[float] = None,
):
padded_size = padding_size
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
@@ -1533,12 +1530,12 @@ def fused_experts_impl(
block_shape=block_shape,
)
if activation == "silu":
if activation_alpha is not None:
assert swiglu_limit is not None
if gemm1_alpha is not None:
assert gemm1_limit is not None
intermediate_cache2 = swiglu_with_alpha_and_limit(
intermediate_cache1.view(-1, N),
activation_alpha,
swiglu_limit,
gemm1_alpha,
gemm1_limit,
)
elif _is_cuda:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
@@ -1547,10 +1544,8 @@ def fused_experts_impl(
intermediate_cache2, intermediate_cache1.view(-1, N)
)
elif activation == "gelu":
assert (
activation_alpha is None
), "activation_alpha is not supported for gelu"
assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
if _is_cuda:
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else:
@@ -1641,12 +1636,10 @@ def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_output: TopKOutput,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig = MoeRunnerConfig(),
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
@@ -1659,10 +1652,6 @@ def fused_moe(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1672,11 +1661,10 @@ def fused_moe(
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- topk_output (TopKOutput): The top-k output of the experts.
- topk_output (StandardTopKOutput): The top-k output of the experts.
- moe_runner_config (MoeRunnerConfig): The configuration for the MoE runner.
- b1 (Optional[torch.Tensor]): Optional bias for w1.
- b2 (Optional[torch.Tensor]): Optional bias for w2.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
@@ -1696,9 +1684,9 @@ def fused_moe(
a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise
quantization.
- activation_alpha (Optional[float]): Optional alpha for the activation
- gemm1_alpha (Optional[float]): Optional alpha scaling factor for the activation
function.
- swiglu_limit (Optional[float]): Optional limit for the swiglu activation
- gemm1_limit (Optional[float]): Optional clamp limit for the SwiGLU activation
function.
Returns:
@@ -1710,11 +1698,9 @@ def fused_moe(
w1,
w2,
topk_output,
moe_runner_config=moe_runner_config,
b1=b1,
b2=b2,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
@@ -1727,8 +1713,4 @@ def fused_moe(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)

View File

@@ -1,10 +1,6 @@
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
import datetime
import glob
import logging
import os
import sys
from enum import Enum
from typing import List, Optional, Tuple
@@ -22,8 +18,12 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.moe import (
MoeRunnerConfig,
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
@@ -126,7 +126,6 @@ class FusedMoE(torch.nn.Module):
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
activation: str = "silu",
apply_router_weight_on_input: bool = False,
@@ -134,9 +133,8 @@ class FusedMoE(torch.nn.Module):
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_cutlass_moe: Optional[bool] = False,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_clamp_limit: Optional[float] = None,
use_weight_loader_fused: bool = False,
with_bias=False,
):
@@ -153,9 +151,17 @@ class FusedMoE(torch.nn.Module):
self.expert_map_cpu = None
self.expert_map_gpu = None
# For activation
self.activation_alpha = activation_alpha
self.swiglu_limit = swiglu_limit
self.moe_runner_config = MoeRunnerConfig(
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
gemm1_alpha=gemm1_alpha,
gemm1_clamp_limit=gemm1_clamp_limit,
)
enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
if enable_flashinfer_cutlass_moe and quant_config is None:
logger.warning("Disable flashinfer MoE when quantization config is None.")
@@ -184,20 +190,12 @@ class FusedMoE(torch.nn.Module):
* self.num_local_experts
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
self.routed_scaling_factor = routed_scaling_factor
assert intermediate_size % self.moe_tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
self.reduce_results = reduce_results
self.activation = activation
self.apply_router_weight_on_input = apply_router_weight_on_input
self.use_presharded_weights = use_presharded_weights
self.inplace = inplace
self.no_combine = no_combine
self.use_triton_kernels = (
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
)
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
self.use_triton_kernels
@@ -207,14 +205,12 @@ class FusedMoE(torch.nn.Module):
assert self.quant_method is not None
self.quant_config = quant_config
self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
"enable_flashinfer_mxfp4_moe", False
)
self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
# TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
if (
self.quant_config is not None
and self.quant_config.get_name() == "mxfp4"
and self.use_enable_flashinfer_mxfp4_moe
and self.use_flashinfer_mxfp4_moe
):
hidden_size = round_up(hidden_size, 256)
self.quant_method.create_weights(
@@ -794,7 +790,7 @@ class FusedMoE(torch.nn.Module):
f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
)
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
origin_hidden_states_dim = hidden_states.shape[-1]
assert self.quant_method is not None
@@ -803,40 +799,22 @@ class FusedMoE(torch.nn.Module):
# If we are in EP mode, we need to move the expert map to GPU.
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
if self.expert_map_gpu is not None and isinstance(
topk_output, StandardTopKOutput
):
topk_output = topk_output._replace(
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
)
if self.expert_map_gpu is not None:
if TopKOutputChecker.format_is_standard(topk_output):
topk_output = topk_output._replace(
topk_ids=self.expert_map_gpu[topk_output.topk_ids]
)
elif TopKOutputChecker.format_is_triton_kernel(topk_output):
raise NotImplementedError()
# Matrix multiply.
with use_symmetric_memory(get_tp_group()) as sm:
kwargs = {}
if self.activation_alpha is not None:
kwargs["activation_alpha"] = self.activation_alpha
if self.swiglu_limit is not None:
kwargs["swiglu_limit"] = self.swiglu_limit
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
topk_output=topk_output,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
routed_scaling_factor=self.routed_scaling_factor,
**(
dict(
tp_rank=self.moe_tp_rank,
tp_size=self.moe_tp_size,
ep_rank=self.moe_ep_rank,
ep_size=self.moe_ep_size,
)
if self.quant_method.__class__.__name__
== "ModelOptNvFp4FusedMoEMethod"
else {}
),
**kwargs,
moe_runner_config=self.moe_runner_config,
)
sm.tag(final_hidden_states)
@@ -944,24 +922,10 @@ class FusedMoE(torch.nn.Module):
class FlashInferFusedMoE(FusedMoE):
def __init__(self, *args, **kwargs):
renormalize = kwargs.pop("renormalize", True)
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
num_expert_group = kwargs.pop("num_expert_group", None)
topk_group = kwargs.pop("topk_group", None)
correction_bias = kwargs.pop("correction_bias", None)
super().__init__(*args, **kwargs)
self.renormalize = renormalize
self.num_fused_shared_experts = num_fused_shared_experts
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
def forward(self, hidden_states: torch.Tensor, topk_output: tuple):
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
assert self.use_flashinfer_trtllm_moe
assert (
self.activation == "silu"
@@ -974,20 +938,14 @@ class FlashInferFusedMoE(FusedMoE):
self.num_fused_shared_experts == 0
), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
# TRTLLM mode expects (TopK_config, router_logits) tuple
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
raise ValueError(
f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
)
_, router_logits = topk_output
assert TopKOutputChecker.format_is_bypassed(topk_output)
# Matrix multiply.
final_hidden_states = self.quant_method.apply_with_router_logits(
layer=self,
x=hidden_states,
router_logits=router_logits,
activation=self.activation,
routed_scaling_factor=self.routed_scaling_factor,
topk_output=topk_output,
moe_runner_config=self.moe_runner_config,
)
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
@@ -1000,28 +958,8 @@ class FlashInferFP4MoE(FusedMoE):
"""FP4 TRTLLM MoE implementation using FlashInfer."""
def __init__(self, *args, **kwargs):
# Extract DeepSeek-specific parameters
renormalize = kwargs.pop("renormalize", True)
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
num_expert_group = kwargs.pop("num_expert_group", None)
topk_group = kwargs.pop("topk_group", None)
correction_bias = kwargs.pop("correction_bias", None)
# Extract additional TopK parameters that were previously extracted in forward
routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
super().__init__(*args, **kwargs)
# Store DeepSeek parameters
self.renormalize = renormalize
self.num_fused_shared_experts = num_fused_shared_experts
self.use_grouped_topk = use_grouped_topk
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor
# ---------------------------------------------------------------------
# Helper: quantize hidden states to FP4 each forward pass
# ---------------------------------------------------------------------
@@ -1052,21 +990,17 @@ class FlashInferFP4MoE(FusedMoE):
return hs_fp4, hs_sf
def forward(self, hidden_states: torch.Tensor, topk_output):
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
"""Forward pass using FP4 TRTLLM kernel.
Args:
hidden_states: Input tensor
topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode
topk_output: TopKOutput object with Bypassed format
"""
assert TopKOutputChecker.format_is_bypassed(topk_output)
# TRTLLM mode expects (TopK_config, router_logits) tuple
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
raise ValueError(
f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
)
_, router_logits = topk_output
router_logits = topk_output.router_logits
topk_config = topk_output.topk_config
hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
@@ -1074,7 +1008,7 @@ class FlashInferFP4MoE(FusedMoE):
result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits,
routing_bias=self.correction_bias.to(hidden_states.dtype),
routing_bias=topk_config.correction_bias.to(hidden_states.dtype),
hidden_states=hs_fp4,
hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
@@ -1094,15 +1028,15 @@ class FlashInferFP4MoE(FusedMoE):
output1_scale_gate_scalar=self.g1_alphas.data,
output2_scale_scalar=self.g2_alphas.data,
num_experts=self.num_experts,
top_k=self.top_k,
n_group=self.num_expert_group,
topk_group=self.topk_group,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.moe_ep_rank * self.num_local_experts,
local_num_experts=self.num_local_experts,
routed_scaling_factor=self.routed_scaling_factor,
routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
tile_tokens_dim=_get_tile_tokens_dim(
hidden_states.shape[0], self.top_k, self.num_local_experts
hidden_states.shape[0], topk_config.top_k, self.num_local_experts
),
routing_method_type=RoutingMethodType.DeepSeekV3,
do_finalize=True,

View File

@@ -18,6 +18,7 @@ from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
from triton_kernels.swiglu import swiglu_fn
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
@@ -55,8 +56,7 @@ def triton_kernel_moe_forward(
w1: torch.Tensor,
w2: torch.Tensor,
topk_output: TopKOutput,
inplace: bool = False,
activation: str = "silu",
moe_runner_config: MoeRunnerConfig,
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
@@ -69,7 +69,10 @@ def triton_kernel_moe_forward(
block_shape: Optional[list[int]] = None,
) -> torch.Tensor:
assert topk_output.format.is_triton_kernel()
from sglang.srt.layers.moe.topk import TopKOutputChecker
assert TopKOutputChecker.format_is_triton_kernel(topk_output)
routing_data, gather_idx, scatter_idx = topk_output
return triton_kernel_fused_experts(
@@ -79,8 +82,8 @@ def triton_kernel_moe_forward(
routing_data,
gather_idx,
scatter_idx,
inplace=inplace,
activation=activation,
inplace=False, # triton kernel doesn't support inplace
activation=moe_runner_config.activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant,
@@ -192,8 +195,7 @@ def triton_kernel_moe_with_bias_forward(
w2_pcg,
b2: torch.Tensor,
topk_output: TopKOutput,
inplace: bool = False,
activation: str = "silu",
moe_runner_config: MoeRunnerConfig,
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
@@ -203,10 +205,11 @@ def triton_kernel_moe_with_bias_forward(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[int] = None,
) -> torch.Tensor:
assert topk_output.format.is_triton_kernel()
from sglang.srt.layers.moe.topk import TopKOutputChecker
assert TopKOutputChecker.format_is_triton_kernel(topk_output)
routing_data, gather_idx, scatter_idx = topk_output
return triton_kernel_fused_experts_with_bias(
@@ -220,8 +223,8 @@ def triton_kernel_moe_with_bias_forward(
routing_data=routing_data,
gather_indx=gather_idx,
scatter_indx=scatter_idx,
inplace=inplace,
activation=activation,
inplace=False, # triton kernel doesn't support inplace
activation=moe_runner_config.activation,
use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
@@ -231,8 +234,8 @@ def triton_kernel_moe_with_bias_forward(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
gemm1_alpha=moe_runner_config.gemm1_alpha,
gemm1_clamp_limit=moe_runner_config.gemm1_clamp_limit,
)
@@ -258,10 +261,9 @@ def triton_kernel_fused_experts_with_bias(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[int] = None,
gemm1_alpha: Optional[float] = None,
gemm1_clamp_limit: Optional[float] = None,
) -> torch.Tensor:
# print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
assert per_channel_quant == False, "per_channel_quant is not supported"
assert expert_map == None, "expert_map is not supported"
@@ -307,7 +309,7 @@ def triton_kernel_fused_experts_with_bias(
act = FusedActivation(
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
(activation_alpha, swiglu_limit),
(gemm1_alpha, gemm1_clamp_limit),
2,
)

View File

@@ -0,0 +1,3 @@
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
__all__ = ["MoeRunnerConfig"]

View File

@@ -0,0 +1,13 @@
from dataclasses import dataclass
from typing import Optional
@dataclass
class MoeRunnerConfig:
    """Runtime configuration shared by MoE runner backends.

    Bundles the per-layer execution knobs that were previously passed as
    loose keyword arguments into the quantization methods' ``apply`` calls.
    """

    # Activation function name used between the grouped GEMMs (default "silu").
    activation: str = "silu"
    # If True, router weights are applied to the input before the expert GEMMs.
    apply_router_weight_on_input: bool = False
    # If True, backends may write results into the input buffer when supported
    # (the triton-kernels path ignores this and always runs out-of-place).
    inplace: bool = True
    # If True, skip the final combine step -- TODO confirm semantics with callers.
    no_combine: bool = False
    # Scaling factor applied to routed expert outputs (None = no scaling).
    routed_scaling_factor: Optional[float] = None
    # Alpha passed to the swiglu activation of gemm1 (None = backend default).
    gemm1_alpha: Optional[float] = None
    # Clamp limit passed to the swiglu activation of gemm1 (None = no clamp).
    gemm1_clamp_limit: Optional[float] = None

View File

@@ -2,20 +2,26 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
BaseDispatcher,
BaseDispatcherConfig,
DispatchOutput,
DispatchOutputChecker,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.token_dispatcher.deepep import (
AscendDeepEPLLOutput,
DeepEPConfig,
DeepEPDispatcher,
DeepEPLLOutput,
DeepEPNormalOutput,
)
from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
__all__ = [
"AscendDeepEPLLOutput",
"BaseDispatcher",
"BaseDispatcherConfig",
"DispatchOutput",
"DispatchOutputFormat",
"DispatchOutputChecker",
"StandardDispatchOutput",
"DeepEPConfig",
"DeepEPDispatcher",
"DeepEPNormalOutput",

View File

@@ -2,35 +2,76 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Protocol, runtime_checkable
from typing import TYPE_CHECKING, Protocol, TypeGuard, Union, runtime_checkable
import torch
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
AscendDeepEPLLOutput,
DeepEPLLOutput,
DeepEPNormalOutput,
StandardDispatchOutput,
)
class MoEA2ABackend(Enum):
none = "none"
deepep = "deepep"
def is_none(self):
return self == MoEA2ABackend.none
class DispatchOutputChecker:
def is_deepep(self):
return self == MoEA2ABackend.deepep
@staticmethod
def format_is_standard(
dispatch_output: DispatchOutput,
) -> TypeGuard[StandardDispatchOutput]:
return dispatch_output.format.is_standard()
@staticmethod
def format_is_deepep_normal(
dispatch_output: DispatchOutput,
) -> TypeGuard[DeepEPNormalOutput]:
return dispatch_output.format.is_deepep_normal()
@staticmethod
def format_is_deepep_ll(
dispatch_output: DispatchOutput,
) -> TypeGuard[DeepEPLLOutput]:
return dispatch_output.format.is_deepep_ll()
@staticmethod
def format_is_deepep(
dispatch_output: DispatchOutput,
) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
return dispatch_output.format.is_deepep()
@staticmethod
def format_is_ascent_ll(
dispatch_output: DispatchOutput,
) -> TypeGuard[AscendDeepEPLLOutput]:
return dispatch_output.format.is_ascent_ll()
class DispatchOutputFormat(Enum):
standard = auto()
deepep_normal = auto()
deepep_ll = auto()
STANDARD = auto()
DEEPEP_NORMAL = auto()
DEEPEP_LL = auto()
ASCENT_LL = auto()
def is_standard(self) -> bool:
return self == DispatchOutputFormat.standard
return self == DispatchOutputFormat.STANDARD
def is_deepep_normal(self) -> bool:
return self == DispatchOutputFormat.deepep_normal
return self == DispatchOutputFormat.DEEPEP_NORMAL
def is_deepep_ll(self) -> bool:
return self == DispatchOutputFormat.deepep_ll
return self == DispatchOutputFormat.DEEPEP_LL
def is_deepep(self) -> bool:
return self in [
DispatchOutputFormat.DEEPEP_NORMAL,
DispatchOutputFormat.DEEPEP_LL,
]
def is_ascent_ll(self) -> bool:
return self == DispatchOutputFormat.ASCENT_LL
@runtime_checkable

View File

@@ -2,27 +2,17 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
List,
NamedTuple,
Optional,
Protocol,
Tuple,
Union,
runtime_checkable,
)
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe import DeepEPMode, get_deepep_config, is_tbo_enabled
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
BaseDispatcher,
BaseDispatcherConfig,
DispatchOutput,
DispatchOutputFormat,
)
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
get_bool_env_var,
get_int_env_var,
@@ -72,7 +62,7 @@ class DeepEPNormalOutput(NamedTuple):
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_normal
return DispatchOutputFormat.DEEPEP_NORMAL
class DeepEPLLOutput(NamedTuple):
@@ -86,7 +76,7 @@ class DeepEPLLOutput(NamedTuple):
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_ll
return DispatchOutputFormat.DEEPEP_LL
class AscendDeepEPLLOutput(NamedTuple):
@@ -101,7 +91,7 @@ class AscendDeepEPLLOutput(NamedTuple):
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_ll
return DispatchOutputFormat.ASCENT_LL
assert isinstance(DeepEPNormalOutput, DispatchOutput)
@@ -128,8 +118,8 @@ class DeepEPBuffer:
hidden_size: int,
param_bytes: int,
deepep_mode: DeepEPMode,
num_max_dispatch_tokens_per_rank: int = None,
num_experts: int = None,
num_max_dispatch_tokens_per_rank: int = -1,
num_experts: int = -1,
):
if cls._buffer is not None:
return cls._buffer
@@ -156,8 +146,8 @@ class DeepEPBuffer:
num_rdma_bytes,
)
if deepep_mode.enable_low_latency():
assert num_max_dispatch_tokens_per_rank is not None
assert num_experts is not None and num_experts % group.size() == 0
assert num_max_dispatch_tokens_per_rank != -1
assert num_experts != -1 and num_experts % group.size() == 0
num_rdma_bytes = max(
Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank,
@@ -181,7 +171,7 @@ class DeepEPBuffer:
).multi_processor_count
if (
(deepep_mode != DeepEPMode.LOW_LATENCY)
and not global_server_args_dict["enable_two_batch_overlap"]
and not is_tbo_enabled()
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
):
logger.warning(
@@ -226,7 +216,7 @@ class DeepEPConfig(BaseDispatcherConfig):
_instance = None
def __init__(self):
config_str = global_server_args_dict["deepep_config"]
config_str = get_deepep_config()
if config_str:
config_parsed = load_json_config(config_str)
if torch.distributed.get_rank() == 0:

View File

@@ -13,7 +13,7 @@ class StandardDispatchOutput(NamedTuple):
@property
def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.standard
return DispatchOutputFormat.STANDARD
assert isinstance(StandardDispatchOutput, DispatchOutput)

View File

@@ -14,9 +14,18 @@
from __future__ import annotations
import logging
import math
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
from typing import (
Callable,
NamedTuple,
Optional,
Protocol,
TypeGuard,
runtime_checkable,
)
import torch
import torch.nn.functional as F
@@ -28,7 +37,10 @@ from sglang.srt.eplb.expert_location_dispatch import (
ExpertLocationDispatchInfo,
topk_ids_logical_to_physical,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.layers.moe import (
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
@@ -43,6 +55,7 @@ try:
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
except ImportError:
pass
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
@@ -65,13 +78,48 @@ if _use_aiter:
if _is_npu:
import torch_npu
# -------------------------------- TopKConfig ---------------------------------------
@dataclass
class TopKConfig:
    """Static routing configuration captured once at ``TopK`` construction time."""

    # Number of experts selected per token.
    top_k: int
    # Use grouped (two-stage) top-k routing over expert groups.
    use_grouped_topk: bool = False
    # Number of groups kept during grouped top-k (meaningful only when grouped).
    topk_group: int = 0
    # Total number of expert groups (meaningful only when grouped).
    num_expert_group: int = 0
    # Renormalize the selected expert weights after top-k selection.
    renormalize: bool = True
    # Count of shared experts fused into the routed expert set.
    num_fused_shared_experts: int = 0
    # Optional user-supplied routing function overriding the default selection.
    custom_routing_function: Optional[Callable] = None
    # Optional per-expert correction bias used during expert selection.
    correction_bias: Optional[torch.Tensor] = None
    # Mutated per call by TopK.forward_*: route through the pure-PyTorch path.
    torch_native: bool = False
    # Scaling factor applied to routed expert outputs (None = no scaling).
    routed_scaling_factor: Optional[float] = None
    # Apply the routed scaling factor on the output rather than on the weights.
    apply_routed_scaling_factor_on_output: bool = False
# -------------------------------- TopKOutput ---------------------------------------
class TopKOutputChecker:
    """Static ``TypeGuard`` helpers that narrow a ``TopKOutput`` to a concrete format."""

    @staticmethod
    def format_is_standard(topk_output: TopKOutput) -> TypeGuard[StandardTopKOutput]:
        """True iff *topk_output* is in the standard (weights / ids / logits) format."""
        return topk_output.format.is_standard()

    @staticmethod
    def format_is_triton_kernel(
        topk_output: TopKOutput,
    ) -> TypeGuard[TritonKernelTopKOutput]:
        """True iff *topk_output* carries triton-kernels routing/gather/scatter data."""
        return topk_output.format.is_triton_kernel()

    @staticmethod
    def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
        """True iff routing was bypassed and raw logits are forwarded to the kernel."""
        return topk_output.format.is_bypassed()
class TopKOutputFormat(Enum):
STANDARD = auto()
TRITON_KERNEL = auto()
BYPASSED = auto()
def is_standard(self) -> bool:
return self == TopKOutputFormat.STANDARD
@@ -79,6 +127,9 @@ class TopKOutputFormat(Enum):
def is_triton_kernel(self) -> bool:
return self == TopKOutputFormat.TRITON_KERNEL
def is_bypassed(self) -> bool:
return self == TopKOutputFormat.BYPASSED
@runtime_checkable
class TopKOutput(Protocol):
@@ -114,6 +165,20 @@ class TritonKernelTopKOutput(NamedTuple):
return TopKOutputFormat.TRITON_KERNEL
class BypassedTopKOutput(NamedTuple):
    """Bypassed top-k output format.

    Carries the raw router logits together with the ``TopKConfig`` so that
    backends that fuse routing into the kernel (e.g. FlashInfer TRTLLM /
    MXFP4) can select experts themselves instead of in ``TopK.forward``.
    """

    # Unrouted activations entering the MoE layer.
    hidden_states: torch.Tensor
    # Raw router logits; expert selection happens later inside the kernel.
    router_logits: torch.Tensor
    # Routing parameters captured at TopK construction time.
    topk_config: TopKConfig
    # Optional count of real (non-padded) tokens in the batch.
    num_token_non_padded: Optional[torch.Tensor] = None
    # Optional expert-location dispatch metadata (EPLB) -- may be None.
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None

    @property
    def format(self) -> TopKOutputFormat:
        return TopKOutputFormat.BYPASSED
# -------------------------------- TopK ---------------------------------------
@@ -124,8 +189,8 @@ class TopK(CustomOp):
top_k: int,
*,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
topk_group: int = 0,
num_expert_group: int = 0,
renormalize: bool = True,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
@@ -136,19 +201,23 @@ class TopK(CustomOp):
# NOTE: scoring_func is not used for now, but we keep it for future use
# see https://github.com/sgl-project/sglang/pull/4505 for more details
super().__init__()
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.top_k = top_k
self.use_grouped_topk = use_grouped_topk
self.renormalize = renormalize
self.topk_group = topk_group
self.num_expert_group = num_expert_group
self.num_fused_shared_experts = num_fused_shared_experts
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
self.topk_config = TopKConfig(
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
def forward_native(
self,
@@ -158,20 +227,11 @@ class TopK(CustomOp):
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
torch_native = True
self.topk_config.torch_native = True
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
@@ -187,24 +247,28 @@ class TopK(CustomOp):
if self.use_triton_kernels:
# renormalize=True is equivalent to sm_first=False
routing_data, gather_idx, scatter_idx = routing(
router_logits, self.top_k, sm_first=not self.renormalize
router_logits,
self.topk_config.top_k,
sm_first=not self.topk_config.renormalize,
)
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
elif (
should_use_flashinfer_trtllm_moe()
or get_moe_runner_backend().is_flashinfer_mxfp4()
):
return BypassedTopKOutput(
hidden_states=hidden_states,
router_logits=router_logits,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
else:
torch_native = False
self.topk_config.torch_native = False
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
@@ -220,15 +284,7 @@ class TopK(CustomOp):
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
@@ -244,35 +300,29 @@ class TopK(CustomOp):
global_num_experts = router_logits.shape[-1]
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
if global_num_experts == 256 and self.topk_config.renormalize is False:
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
router_logits = router_logits.to(torch.float32)
return torch_npu.npu_moe_gating_top_k(
router_logits,
k=self.top_k,
bias=self.correction_bias.to(torch.float32),
k_group=self.topk_group,
group_count=self.num_expert_group,
k=self.topk_config.top_k,
bias=self.topk_config.correction_bias.to(torch.float32),
k_group=self.topk_config.topk_group,
group_count=self.topk_config.num_expert_group,
group_select_mode=1,
renorm=0,
norm_type=1,
routed_scaling_factor=1,
routed_scaling_factor=routed_scaling_factor,
eps=float(1e-20),
)
else:
torch_native = True
self.topk_config.torch_native = True
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
@@ -670,20 +720,23 @@ else:
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
topk_config: TopKConfig,
*,
use_grouped_topk: bool = False,
renormalize: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
torch_native: bool = False,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
) -> StandardTopKOutput:
top_k = topk_config.top_k
use_grouped_topk = topk_config.use_grouped_topk
topk_group = topk_config.topk_group
num_expert_group = topk_config.num_expert_group
renormalize = topk_config.renormalize
num_fused_shared_experts = topk_config.num_fused_shared_experts
custom_routing_function = topk_config.custom_routing_function
correction_bias = topk_config.correction_bias
torch_native = topk_config.torch_native
routed_scaling_factor = topk_config.routed_scaling_factor
router_logits, correction_bias = (
expert_location_dispatch.transform_select_experts_inputs(
router_logits=router_logits,

View File

@@ -1,55 +1,80 @@
from __future__ import annotations
import importlib.util
from enum import Enum
from functools import lru_cache
from typing import TYPE_CHECKING, Optional
from packaging import version as pkg_version
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import logger
@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
not importlib.util.find_spec("flashinfer")
or pkg_version.parse(__import__("flashinfer").__version__)
>= pkg_version.parse("0.2.9rc1")
)
return result
if TYPE_CHECKING:
from sglang.srt.server_args import ServerArgs
class MoeA2ABackend(Enum):
STANDARD = ("standard", "none")
NONE = "none"
DEEPEP = "deepep"
@classmethod
def _missing_(cls, value):
if value is None:
return cls.STANDARD
return cls.NONE
for member in cls:
if value in member.value:
if value == member.value:
return member
raise ValueError(f"No {cls.__name__} member for value {value}")
def is_none(self):
return self == MoeA2ABackend.NONE
def is_deepep(self):
return self == MoeA2ABackend.DEEPEP
def is_standard(self):
return self == MoeA2ABackend.STANDARD
class MoeRunnerBackend(Enum):
    """Kernel backend used to execute the fused MoE experts."""

    AUTO = "auto"
    TRITON = "triton"
    TRITON_KERNEL = "triton_kernel"
    FLASHINFER = "flashinfer_trtllm"
    FLASHINFER_CUTLASS = "flashinfer_cutlass"
    FLASHINFER_MXFP4 = "flashinfer_mxfp4"

    def is_auto(self) -> bool:
        """Whether the backend should be chosen automatically."""
        return self is MoeRunnerBackend.AUTO

    def is_triton(self) -> bool:
        """Whether the plain triton fused-MoE path is selected."""
        return self is MoeRunnerBackend.TRITON

    def is_triton_kernel(self) -> bool:
        """Whether the external triton-kernels routing path is selected."""
        return self is MoeRunnerBackend.TRITON_KERNEL

    def is_flashinfer_trtllm(self) -> bool:
        """Whether the FlashInfer TRTLLM path is selected."""
        return self is MoeRunnerBackend.FLASHINFER

    def is_flashinfer_cutlass(self) -> bool:
        """Whether the FlashInfer CUTLASS path is selected."""
        return self is MoeRunnerBackend.FLASHINFER_CUTLASS

    def is_flashinfer_mxfp4(self) -> bool:
        """Whether the FlashInfer MXFP4 path is selected."""
        return self is MoeRunnerBackend.FLASHINFER_MXFP4
class DeepEPMode(Enum):
NORMAL = "normal"
LOW_LATENCY = "low_latency"
AUTO = "auto"
def enable_normal(self):
def enable_normal(self) -> bool:
return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
def enable_low_latency(self):
def enable_low_latency(self) -> bool:
return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
def resolve(self, is_extend_in_batch: bool):
def resolve(self, is_extend_in_batch: bool) -> DeepEPMode:
if self != DeepEPMode.AUTO:
return self
@@ -57,3 +82,96 @@ class DeepEPMode(Enum):
return DeepEPMode.NORMAL
else:
return DeepEPMode.LOW_LATENCY
def is_normal(self) -> bool:
return self == DeepEPMode.NORMAL
def is_low_latency(self) -> bool:
return self == DeepEPMode.LOW_LATENCY
def is_auto(self) -> bool:
return self == DeepEPMode.AUTO
MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
DEEPEP_MODE: Optional[DeepEPMode] = None
IS_TBO_ENABLED: Optional[bool] = None
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
DEEPEP_CONFIG: Optional[str] = None
def initialize_moe_config(server_args: ServerArgs):
    """Snapshot MoE-related server arguments into module-level globals.

    Should be called once at startup, before any of the getters below, so the
    MoE layers can read configuration without threading ``server_args``
    through every call site.
    """
    global MOE_A2A_BACKEND
    global MOE_RUNNER_BACKEND
    global DEEPEP_MODE
    global DEEPEP_CONFIG
    global IS_TBO_ENABLED
    global TBO_TOKEN_DISTRIBUTION_THRESHOLD

    MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
    MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
    DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
    # Normalize a missing deepep_config to the empty string.
    DEEPEP_CONFIG = server_args.deepep_config or ""
    IS_TBO_ENABLED = server_args.enable_two_batch_overlap
    TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
def get_moe_a2a_backend() -> MoeA2ABackend:
    """Return the configured MoE all-to-all backend.

    Falls back to the default backend (with a warning) when
    ``initialize_moe_config`` was never called.
    """
    global MOE_A2A_BACKEND
    if MOE_A2A_BACKEND is None:
        logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
        # MoeA2ABackend(None) resolves to the "none" backend via _missing_.
        MOE_A2A_BACKEND = MoeA2ABackend(None)
    return MOE_A2A_BACKEND
def get_moe_runner_backend() -> MoeRunnerBackend:
    """Return the configured MoE runner backend.

    Falls back to the triton backend (with a warning) when
    ``initialize_moe_config`` was never called.
    """
    global MOE_RUNNER_BACKEND
    if MOE_RUNNER_BACKEND is None:
        logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
        MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
    return MOE_RUNNER_BACKEND
def get_deepep_mode() -> DeepEPMode:
    """Return the configured DeepEP mode.

    Falls back to AUTO mode (with a warning) when ``initialize_moe_config``
    was never called.
    """
    global DEEPEP_MODE
    if DEEPEP_MODE is None:
        logger.warning("DEEPEP_MODE is not initialized, using auto mode")
        DEEPEP_MODE = DeepEPMode("auto")
    return DEEPEP_MODE
def get_deepep_config() -> str:
    """Return the DeepEP config string ("" when unset).

    Falls back to the empty string (with a warning) when
    ``initialize_moe_config`` was never called.
    """
    global DEEPEP_CONFIG
    if DEEPEP_CONFIG is None:
        logger.warning("DEEPEP_CONFIG is not initialized, using default config")
        DEEPEP_CONFIG = ""
    return DEEPEP_CONFIG
def is_tbo_enabled() -> bool:
    """Return whether two-batch overlap (TBO) is enabled.

    Falls back to False (with a warning) when ``initialize_moe_config`` was
    never called.
    """
    global IS_TBO_ENABLED
    if IS_TBO_ENABLED is None:
        logger.warning("IS_TBO_ENABLED is not initialized, using False")
        IS_TBO_ENABLED = False
    return IS_TBO_ENABLED
def get_tbo_token_distribution_threshold() -> float:
    """Return the TBO token-distribution threshold.

    Falls back to 0.48 (with a warning) when ``initialize_moe_config`` was
    never called.
    """
    global TBO_TOKEN_DISTRIBUTION_THRESHOLD
    if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
        logger.warning(
            "TBO_TOKEN_DISTRIBUTION_THRESHOLD is not initialized, using 0.48"
        )
        TBO_TOKEN_DISTRIBUTION_THRESHOLD = 0.48
    return TBO_TOKEN_DISTRIBUTION_THRESHOLD
@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe() -> bool:
    """True when the FlashInfer TRTLLM runner is selected and flashinfer is new enough.

    NOTE(review): the result is cached on the first call; calling this before
    ``initialize_moe_config`` would freeze the uninitialized-default answer --
    confirm the startup call ordering.
    """
    result = get_moe_runner_backend().is_flashinfer_trtllm() and (
        # If flashinfer is not installed, the backend selection alone decides;
        # otherwise additionally require flashinfer >= 0.2.9rc1.
        not importlib.util.find_spec("flashinfer")
        or pkg_version.parse(__import__("flashinfer").__version__)
        >= pkg_version.parse("0.2.9rc1")
    )
    return result

View File

@@ -33,7 +33,8 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.utils import is_cuda, is_hip
@@ -739,13 +740,12 @@ class AWQMoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
**kwargs,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
# The input must currently be float16
orig_dtype = x.dtype

View File

@@ -9,6 +9,7 @@ import torch
from torch import nn
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
@@ -100,12 +101,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
raise NotImplementedError

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from torch.nn import Module
@@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -348,12 +349,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -363,15 +359,11 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
moe_runner_config=moe_runner_config,
use_int8_w8a8=True,
w1_scale=(layer.w13_weight_scale_inv),
w2_scale=(layer.w2_weight_scale_inv),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)

View File

@@ -23,6 +23,7 @@ from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig,
@@ -269,12 +270,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
@@ -283,8 +279,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
inplace=inplace,
activation=activation,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy
== QuantizationStrategy.CHANNEL,
@@ -292,8 +287,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
)
@@ -601,12 +594,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
**kwargs,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
topk_weights, topk_ids, router_logits = topk_output

View File

@@ -41,6 +41,7 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
logger = logging.getLogger(__name__)
@@ -220,22 +221,10 @@ class MxFp4LinearMethod(LinearMethodBase):
return out
class MxFp4MoEMethod:
def __new__(cls, *args, **kwargs):
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
class MxFp4MoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: Mxfp4Config):
self.quant_config = quant_config
@staticmethod
def get_moe_method(
@@ -364,12 +353,7 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
topk_weights, topk_ids, _ = topk_output
@@ -383,7 +367,9 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
ActivationType.Silu
if moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
doweight_stage1=False,
)
@@ -497,12 +483,7 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
topk_weights, topk_ids, _ = topk_output
@@ -516,7 +497,9 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu
ActivationType.Silu
if moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
doweight_stage1=False,
)

View File

@@ -79,6 +79,7 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
@@ -982,12 +983,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -996,7 +992,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
moe_runner_config.apply_router_weight_on_input, topk_weights, x
)
return torch.ops.sgl_kernel.fused_experts_cpu(
@@ -1021,8 +1017,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer,
x,
topk_output,
activation,
no_combine,
moe_runner_config.activation,
moe_runner_config.no_combine,
)
if ret is not None:
return ret
@@ -1060,8 +1056,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
use_fp8_blockscale=True,
)
# TODO: Fuse into select_experts
if routed_scaling_factor is not None:
output *= routed_scaling_factor
if moe_runner_config.routed_scaling_factor is not None:
output *= moe_runner_config.routed_scaling_factor
return output
# Expert fusion with FP8 quantization
return fused_experts(
@@ -1069,9 +1065,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
@@ -1084,26 +1078,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
def apply_with_router_logits(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
*,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
topk_output: TopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
activation = moe_runner_config.activation
routed_scaling_factor = moe_runner_config.routed_scaling_factor
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
from sglang.srt.layers.moe.topk import TopKOutputChecker
assert TopKOutputChecker.format_is_bypassed(topk_output)
router_logits = topk_output.router_logits
topk_config = topk_output.topk_config
assert (
activation == "silu"
), "Only silu is supported for flashinfer blockscale fp8 moe"
a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
# NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous()
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
return trtllm_fp8_block_scale_moe(
routing_logits=router_logits.to(torch.float32),
@@ -1115,9 +1115,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
gemm2_weights=layer.w2_weight,
gemm2_weights_scale=layer.w2_weight_scale_inv,
num_experts=layer.num_experts,
top_k=layer.top_k,
n_group=layer.num_expert_group,
topk_group=layer.topk_group,
top_k=topk_config.top_k,
n_group=topk_config.num_expert_group,
topk_group=topk_config.topk_group,
intermediate_size=layer.w2_weight.shape[2],
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts,

View File

@@ -113,6 +113,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
return weight, weight_scale, input_scale
# TODO(ch-wan): define these backends in --moe-runner-backend
def cutlass_block_fp8_supported() -> bool:
if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
return False

View File

@@ -44,6 +44,7 @@ from sglang.srt.layers.quantization.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.utils import is_cuda
@@ -1056,13 +1057,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
**kwargs,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
# Delay the import to avoid circular dependency
assert activation == "silu", "Only SiLU activation is supported."
assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
# The input must currently be float16
orig_dtype = x.dtype

View File

@@ -28,6 +28,7 @@ from sglang.srt.utils import get_device_capability, is_cuda
if TYPE_CHECKING:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
try:
from vllm import _custom_ops as ops
@@ -216,13 +217,13 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
)[0]
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
def check_moe_marlin_supports_layer(layer: FusedMoE, group_size: int) -> bool:
hidden_size = layer.hidden_size
intermediate_size_per_partition = layer.intermediate_size_per_partition
# apply_router_weight_on_input is not supported for moe marlin
supports_router_weight = not layer.apply_router_weight_on_input
supports_router_weight = not layer.moe_runner_config.apply_router_weight_on_input
# moe marlin requires the activation to be silu
supports_activation = layer.activation == "silu"
supports_activation = layer.moe_runner_config.activation == "silu"
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
# down: (n, k) = (hidden_size, intermediate_size_per_partition)

View File

@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
@@ -30,10 +30,11 @@ from sglang.srt.layers.quantization.utils import (
requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import is_cuda, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
if is_cuda():
@@ -422,12 +423,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -436,15 +432,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
inplace=inplace,
activation=activation,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=True,
per_channel_quant=False, # ModelOpt uses per-tensor quantization
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
no_combine=no_combine,
)
@@ -741,8 +735,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
@property
def enable_flashinfer_cutlass_moe(self) -> bool:
from sglang.srt.layers.moe import get_moe_runner_backend
"""Access the global enable_flashinfer_cutlass_moe setting."""
return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False)
return get_moe_runner_backend().is_flashinfer_cutlass()
def create_weights(
self,
@@ -1160,21 +1156,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
def apply(
self,
layer: torch.nn.Module,
layer: FusedMoE,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
ep_rank: Optional[int] = None,
ep_size: Optional[int] = None,
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
# Check if this is a FlashInferFP4MoE layer that should handle its own forward
if hasattr(layer, "gemm1_weights_fp4_shuffled"):
@@ -1183,7 +1172,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
if self.enable_flashinfer_cutlass_moe:
assert (
not apply_router_weight_on_input
not moe_runner_config.apply_router_weight_on_input
), "apply_router_weight_on_input is not supported for Flashinfer"
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
# and fp4 quantized weights loaded from the checkpoint
@@ -1205,14 +1194,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
layer.w2_blockscale_swizzled.view(torch.int32),
layer.g2_alphas,
],
ep_size=ep_size,
ep_rank=ep_rank,
tp_size=tp_size,
tp_rank=tp_rank,
ep_size=layer.moe_ep_size,
ep_rank=layer.moe_ep_rank,
tp_size=layer.moe_tp_size,
tp_rank=layer.moe_tp_rank,
tune_max_num_tokens=next_power_of_2(x.shape[0]),
)[0]
if routed_scaling_factor is not None:
output *= routed_scaling_factor
if moe_runner_config.routed_scaling_factor is not None:
output *= moe_runner_config.routed_scaling_factor
return output
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
@@ -1231,8 +1220,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights,
topk_ids=topk_ids,
params=layer.cutlass_moe_params,
apply_router_weight_on_input=apply_router_weight_on_input,
apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
).to(x.dtype)
if routed_scaling_factor is not None:
output *= routed_scaling_factor
if moe_runner_config.routed_scaling_factor is not None:
output *= moe_runner_config.routed_scaling_factor
return output

View File

@@ -22,6 +22,7 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
@@ -353,17 +354,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
# avoid circular import
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."
assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
@@ -373,8 +371,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer.w13_qweight,
layer.w2_qweight,
topk_output=topk_output,
inplace=inplace,
apply_router_weight_on_input=apply_router_weight_on_input,
moe_runner_config=moe_runner_config,
use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8,
w1_scale=layer.w13_scales,
@@ -382,8 +379,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size],
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
@staticmethod
@@ -486,16 +481,16 @@ class MoeWNA16Method(FusedMoEMethodBase):
)
if "w13_qzeros" in weight_name:
tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
tp_rank
]
tensor = loaded_weight.view(
layer.moe_tp_size, -1, loaded_weight.size(1)
)[tp_rank]
if shard_id == "w1":
param.data[expert_id, : shard_size // 2] = tensor
else:
param.data[expert_id, shard_size // 2 :] = tensor
elif "w2_qzeros" in weight_name:
param.data[expert_id] = loaded_weight.view(
loaded_weight.size(0), layer.tp_size, -1
loaded_weight.size(0), layer.moe_tp_size, -1
)[:, tp_rank]
else:
weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)

View File

@@ -16,14 +16,13 @@
from __future__ import annotations
import importlib.util
import logging
from typing import TYPE_CHECKING, List, Optional
import torch
import triton.language as tl
from torch.nn.parameter import Parameter
from sglang.srt.layers.moe.utils import get_moe_runner_backend
from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase,
QuantizationConfig,
@@ -31,7 +30,6 @@ from sglang.srt.layers.quantization.base_config import (
)
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
@@ -60,6 +58,7 @@ if is_flashinfer_available():
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
OCP_MX_BLOCK_SIZE = 32
@@ -218,15 +217,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self,
prefix: str,
):
from sglang.srt.managers.schedule_batch import global_server_args_dict
super().__init__()
self.prefix = prefix
self.topk_indices_dtype = None
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
self.with_bias = False
self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"]
self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None
@@ -348,6 +345,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
logger,
f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
)
# TODO: these values are hardcoded for now, we need to get them from the model
layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False,
@@ -573,14 +571,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
if self.use_flashinfer:
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
@@ -637,9 +628,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
topk_output=topk_output,
activation=activation,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
moe_runner_config=moe_runner_config,
)
else:
return self.triton_kernel_moe_forward(
@@ -647,6 +636,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
)
else:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -656,13 +646,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import importlib
from typing import TYPE_CHECKING, Callable, List, Optional
import importlib.util
from typing import TYPE_CHECKING, List, Optional
import torch
import torch.nn.functional as F
@@ -24,7 +24,7 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
@@ -221,31 +221,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
kwargs = {}
if activation_alpha is not None:
kwargs["activation_alpha"] = activation_alpha
if swiglu_limit is not None:
kwargs["swiglu_limit"] = swiglu_limit
return self.forward(
x=x,
layer=layer,
topk_output=topk_output,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
**kwargs,
moe_runner_config=moe_runner_config,
)
def forward_cuda(
@@ -253,18 +236,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
if self.use_triton_kernels:
if self.with_bias:
assert self.triton_kernel_moe_with_bias_forward is not None
return self.triton_kernel_moe_with_bias_forward(
hidden_states=x,
w1=layer.w13_weight,
@@ -272,24 +249,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
topk_output=topk_output,
activation=activation,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
moe_runner_config=moe_runner_config,
w1_pcg=None,
w2_pcg=None,
)
else:
assert self.triton_kernel_moe_forward is not None
return self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
moe_runner_config=moe_runner_config,
)
else:
if _use_aiter:
assert not no_combine, "unsupported"
assert not moe_runner_config.no_combine, "unsupported"
topk_weights, topk_ids, _ = topk_output
if apply_router_weight_on_input:
if moe_runner_config.apply_router_weight_on_input:
assert (
topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
@@ -309,7 +286,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids,
activation=(
ActivationType.Silu
if activation == "silu"
if moe_runner_config.activation == "silu"
else ActivationType.Gelu
),
)
@@ -325,13 +302,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
b1=getattr(layer, "w13_weight_bias", None),
b2=getattr(layer, "w2_weight_bias", None),
topk_output=topk_output,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
moe_runner_config=moe_runner_config,
)
def forward_cpu(
@@ -339,21 +310,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
assert activation == "silu", f"activation = {activation} is not supported."
assert (
moe_runner_config.activation == "silu"
), f"activation = {moe_runner_config.activation} is not supported."
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
if (
use_intel_amx_backend(layer)
and not moe_runner_config.apply_router_weight_on_input
):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
moe_runner_config.apply_router_weight_on_input, topk_weights, x
)
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
@@ -378,11 +349,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer,
x,
topk_output,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
moe_runner_config,
)
def forward_npu(
@@ -390,12 +357,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
@@ -403,11 +365,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer,
x,
topk_output,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
moe_runner_config,
)
def forward_tpu(self, *args, **kwargs) -> torch.Tensor:

View File

@@ -18,7 +18,9 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.topk import StandardTopKOutput
ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -280,11 +282,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
self,
layer: EPMoE,
x: torch.Tensor,
topk_output: TopKOutput,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
routed_scaling_factor: Optional[float] = None,
**kwargs,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
# TODO(ch-wan): move it out of this class
@@ -324,6 +323,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
layer.w13_input_scale,
layer.w2_input_scale,
)
if routed_scaling_factor is not None:
output *= routed_scaling_factor
if moe_runner_config.routed_scaling_factor is not None:
output *= moe_runner_config.routed_scaling_factor
return output

View File

@@ -26,7 +26,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
_is_fp8_fnuz = is_fp8_fnuz()
@@ -269,13 +270,8 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -284,15 +280,11 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
inplace=inplace,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
moe_runner_config=moe_runner_config,
use_fp8_w8a8=True,
per_channel_quant=True,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)

View File

@@ -49,6 +49,7 @@ from sglang.srt.utils import (
)
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput
_is_cuda = is_cuda()
@@ -487,12 +488,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -501,7 +497,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
moe_runner_config.apply_router_weight_on_input, topk_weights, x
)
return torch.ops.sgl_kernel.fused_experts_cpu(
x,
@@ -525,17 +521,13 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
layer.w13_weight,
layer.w2_weight,
topk_output=topk_output,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
moe_runner_config=moe_runner_config,
use_int8_w8a8=True,
per_channel_quant=True,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
@@ -982,7 +974,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
layer,
x,
topk_output: TopKOutput,
**kwargs,
moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor:
topk_weights, topk_ids, _ = topk_output

View File

@@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.moe import is_tbo_enabled
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
@@ -84,17 +85,10 @@ GLOBAL_SERVER_ARGS_KEYS = [
"device",
"disable_chunked_prefix_cache",
"disable_radix_cache",
"enable_two_batch_overlap",
"tbo_token_distribution_threshold",
"enable_dp_lm_head",
"moe_a2a_backend",
"deepep_mode",
"enable_flashinfer_cutlass_moe",
"enable_flashinfer_trtllm_moe",
"enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size",
"ep_dispatch_algorithm",
"deepep_config",
"ep_num_redundant_experts",
"enable_nan_detection",
"flashinfer_mla_disable_ragged",
@@ -107,8 +101,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
"triton_attention_reduce_in_fp32",
"num_reserved_decode_tokens",
"weight_loader_disable_mmap",
"enable_triton_kernel_moe",
"enable_flashinfer_mxfp4_moe",
"enable_multimodal",
"enable_symm_mem",
"quantization",

View File

@@ -64,7 +64,7 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.io_struct import (
AbortReq,
CloseSessionReqInput,
@@ -245,6 +245,9 @@ class Scheduler(
)
)
# Init model config
self.model_config = ModelConfig.from_server_args(server_args)
# Init inter-process communication
context = zmq.Context(2)
self.idle_sleeper = None
@@ -292,6 +295,9 @@ class Scheduler(
# Init tokenizer
self.init_tokenizer()
# Init moe config
self.init_moe_config()
# Set reasoning_parser and think_end_id if --reasoning_parser is enabled
if self.server_args.reasoning_parser and self.tokenizer:
reasoning_parser = ReasoningParser(
@@ -538,8 +544,6 @@ class Scheduler(
def init_tokenizer(self):
server_args = self.server_args
self.model_config = ModelConfig.from_server_args(server_args)
self.is_generation = self.model_config.is_generation
if server_args.skip_tokenizer_init:
@@ -761,6 +765,10 @@ class Scheduler(
# The prefill requests that are in the middle of kv sending
self.disagg_prefill_inflight_queue: List[Req] = []
def init_moe_config(self):
if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
initialize_moe_config(self.server_args)
@DynamicGradMode()
def event_loop_normal(self):
"""A normal scheduler loop."""
@@ -1823,11 +1831,6 @@ class Scheduler(
disable_cuda_graph=self.server_args.disable_cuda_graph,
spec_algorithm=self.spec_algorithm,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
enable_deepep_moe=MoeA2ABackend(
self.server_args.moe_a2a_backend
).is_deepep(),
deepep_mode=DeepEPMode(self.server_args.deepep_mode),
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
)
@@ -1922,9 +1925,6 @@ class Scheduler(
disable_cuda_graph: bool,
spec_algorithm,
speculative_num_draft_tokens,
enable_two_batch_overlap: bool,
enable_deepep_moe: bool,
deepep_mode: DeepEPMode,
require_mlp_tp_gather: bool,
disable_overlap_schedule: bool,
):
@@ -1972,9 +1972,6 @@ class Scheduler(
is_extend_in_batch,
*tbo_preparer.prepare_all_gather(
local_batch,
deepep_mode,
enable_deepep_moe,
enable_two_batch_overlap,
),
],
dtype=torch.int64,

View File

@@ -60,7 +60,6 @@ from sglang.srt.layers.dp_attention import (
initialize_dp_attention,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.layers.quantization import (
deep_gemm_wrapper,
monkey_patch_isinstance_for_vllm_base_layer,
@@ -219,8 +218,6 @@ class ModelRunner:
# TODO it is indeed not a "server args"
"use_mla_backend": self.use_mla_backend,
"speculative_algorithm": self.spec_algorithm,
"moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
"deepep_mode": DeepEPMode(server_args.deepep_mode),
}
)

View File

@@ -32,7 +32,9 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
@@ -104,6 +106,11 @@ class DbrxExperts(nn.Module):
self.params_dtype = params_dtype
self.router = DbrxRouter(config, self.params_dtype)
self.topk = TopK(
self.top_k,
renormalize=True,
)
self.moe_runner_config = MoeRunnerConfig(inplace=True)
self.ws = nn.Parameter(
torch.empty(
self.num_total_experts,
@@ -169,14 +176,13 @@ class DbrxExperts(nn.Module):
hidden_states = hidden_states.view(-1, self.d_model)
# router_logits: (num_tokens, n_experts)
router_logits = self.router(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = fused_moe(
hidden_states,
self.ws,
self.w2s,
router_logits,
self.top_k,
renormalize=True,
inplace=True,
topk_output,
self.moe_runner_config,
)
if self.tp_size > 1:
@@ -293,7 +299,7 @@ class DbrxFusedNormAttention(nn.Module):
position_ids: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
hidden_states = self.norm_1(hidden_states)
x = self.attn(

View File

@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
@@ -180,7 +181,7 @@ class DeepseekMoE(nn.Module):
w1=self.w1,
w2=self.w2,
topk_output=topk_output,
inplace=True,
moe_runner_config=MoeRunnerConfig(inplace=True),
)
if self.config.n_shared_experts is not None:

View File

@@ -50,7 +50,6 @@ from sglang.srt.layers.communicator import (
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
@@ -61,9 +60,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
@@ -336,30 +336,6 @@ class DeepseekV2MoE(nn.Module):
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
**(
dict(
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
)
if should_use_flashinfer_trtllm_moe()
else {}
),
)
self.shared_experts_is_int8 = False
@@ -377,7 +353,7 @@ class DeepseekV2MoE(nn.Module):
prefix=add_prefix("shared_experts", prefix),
**(
dict(tp_rank=0, tp_size=1)
if global_server_args_dict["moe_a2a_backend"].is_deepep()
if get_moe_a2a_backend().is_deepep()
else {}
),
)
@@ -407,7 +383,7 @@ class DeepseekV2MoE(nn.Module):
self.top_k = config.num_experts_per_tok
if global_server_args_dict["moe_a2a_backend"].is_deepep():
if get_moe_a2a_backend().is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
@@ -431,12 +407,12 @@ class DeepseekV2MoE(nn.Module):
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=global_server_args_dict["deepep_mode"],
deepep_mode=get_deepep_mode(),
async_finish=True,
return_recv_hook=True,
)
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
def get_moe_weights(self):
return [
@@ -484,13 +460,7 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states}
# FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
# Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
if should_use_flashinfer_trtllm_moe():
kwargs["topk_output"] = (self.topk, router_logits)
else:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(**kwargs)
if not _is_cuda:
@@ -520,13 +490,7 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states}
# FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
# Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
if should_use_flashinfer_trtllm_moe():
kwargs["topk_output"] = (self.topk, router_logits)
else:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(**kwargs)
if not _is_cuda and not _use_aiter:
@@ -2478,17 +2442,15 @@ class DeepseekV2ForCausalLM(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
)
if self.quant_config and self.quant_config.get_name() == "w4afp8":
expert_params_mapping += (
get_moe_impl_class().make_expert_input_scale_params_mapping(
num_experts=self.config.n_routed_experts
)
expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
num_experts=self.config.n_routed_experts
)
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None

View File

@@ -31,13 +31,13 @@ from sglang.srt.layers.communicator import enable_moe_dense_fully_dp
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP
@@ -361,7 +361,7 @@ class Ernie4_5_ForCausalLM(nn.Module):
class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",

View File

@@ -39,7 +39,6 @@ from sglang.srt.layers.communicator import (
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
@@ -51,9 +50,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz,
@@ -76,10 +76,7 @@ from sglang.srt.models.deepseek_v2 import (
DeepseekV2Model,
DeepseekV2MoE,
)
from sglang.srt.two_batch_overlap import (
MaybeTboDeepEPDispatcher,
model_forward_maybe_tbo,
)
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import (
BumpAllocator,
LazyValue,
@@ -414,19 +411,15 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
)
self.topk = (
TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
if not should_use_flashinfer_trtllm_moe()
else None
self.topk = TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
self.experts = get_moe_impl_class()(
@@ -441,31 +434,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
**(
dict(
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
)
if should_use_flashinfer_trtllm_moe()
else {}
),
)
self.shared_experts_is_int8 = False
@@ -496,7 +464,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.top_k = config.num_experts_per_tok
if global_server_args_dict["moe_a2a_backend"].is_deepep():
if get_moe_a2a_backend().is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
@@ -520,12 +488,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=global_server_args_dict["deepep_mode"],
deepep_mode=get_deepep_mode(),
async_finish=True,
return_recv_hook=True,
)
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
def forward_normal_dual_stream(
self,
@@ -542,10 +510,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["router_logits"] = router_logits
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(**kwargs)
if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor
@@ -588,10 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["router_logits"] = router_logits
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(**kwargs)
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
@@ -761,8 +723,6 @@ class Glm4MoeModel(DeepseekV2Model):
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.dp_size = get_local_attention_dp_size()
class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
@@ -789,7 +749,6 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.dp_size = get_local_attention_dp_size()
self._routed_experts_weights_of_layer = LazyValue(
lambda: {
@@ -953,7 +912,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",

View File

@@ -8,19 +8,11 @@ from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
)
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
@@ -49,7 +41,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
config.moe_layer_freq = 1
self.config = config
self.tp_size = get_tensor_model_parallel_world_size()
self.dp_size = get_local_attention_dp_size()
self.quant_config = quant_config
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
self.num_fused_shared_experts = (
@@ -232,7 +223,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",

View File

@@ -40,7 +40,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
@@ -50,9 +49,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
from sglang.srt.layers.radix_attention import RadixAttention
@@ -110,16 +110,13 @@ class GptOssSparseMoeBlock(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.layer_id = layer_id
self.activation = config.hidden_act
self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
self.swiglu_limit = config.swiglu_limit
self.gemm1_alpha = getattr(config, "hidden_act_alpha", 1.702)
self.gemm1_clamp_limit = config.swiglu_limit
if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
self.topk = None
else:
self.topk = TopK(
top_k=config.num_experts_per_tok,
renormalize=True,
)
self.topk = TopK(
top_k=config.num_experts_per_tok,
renormalize=True,
)
self.top_k = config.num_experts_per_tok
experts_type = get_moe_impl_class()
@@ -129,11 +126,9 @@ class GptOssSparseMoeBlock(nn.Module):
quant_config.get_name() if quant_config is not None else None
)
extra_kwargs = {
"enable_flashinfer_cutlass_moe": global_server_args_dict[
"enable_flashinfer_cutlass_moe"
],
# for moe gate_up_proj and down_proj and their bias loading
"use_weight_loader_fused": quant_config_name != "mxfp4",
"use_weight_loader_fused": quant_config_name
!= "mxfp4"
}
self.experts = experts_type(
num_experts=config.num_local_experts
@@ -144,15 +139,10 @@ class GptOssSparseMoeBlock(nn.Module):
intermediate_size=config.intermediate_size,
quant_config=quant_config,
activation=self.activation,
activation_alpha=self.activation_alpha,
swiglu_limit=self.swiglu_limit,
gemm1_alpha=self.gemm1_alpha,
gemm1_clamp_limit=self.gemm1_clamp_limit,
with_bias=True,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
**extra_kwargs,
)
@@ -171,7 +161,7 @@ class GptOssSparseMoeBlock(nn.Module):
forward_batch: Optional[ForwardBatch] = None,
should_allreduce_fusion: bool = False,
) -> torch.Tensor:
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
if not get_moe_a2a_backend().is_deepep():
return self.forward_normal(hidden_states, should_allreduce_fusion)
else:
raise Exception("forward_deepep branch not implemented yet")
@@ -189,17 +179,10 @@ class GptOssSparseMoeBlock(nn.Module):
should_allreduce_fusion: bool = False,
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.router(hidden_states)
kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["topk_output"] = (self.top_k, router_logits)
final_hidden_states = self.experts(**kwargs)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
if self.tp_size > 1 and not should_allreduce_fusion:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
@@ -436,7 +419,6 @@ class GptOssDecoderLayer(nn.Module):
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()
# GptOss all layers are sparse and have no nextn now
self.is_layer_sparse = True
@@ -1060,7 +1042,7 @@ class GptOssForCausalLM(nn.Module):
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
expert_params_mapping = FusedMoE.make_expert_params_mapping_fused(
ckpt_gate_up_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",

View File

@@ -76,7 +76,6 @@ class GraniteMoeMoE(nn.Module):
params_dtype=params_dtype,
reduce_results=True,
quant_config=quant_config,
tp_size=tp_size,
prefix=f"{prefix}.experts",
)

View File

@@ -135,7 +135,6 @@ class Grok1MoE(nn.Module):
intermediate_size=intermediate_size,
params_dtype=params_dtype,
quant_config=quant_config,
tp_size=tp_size,
activation="gelu",
**kwargs,
)

View File

@@ -6,6 +6,7 @@ from transformers import PretrainedConfig
from sglang.srt.distributed import parallel_state
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
@@ -254,7 +255,7 @@ class InternS1ForConditionalGeneration(nn.Module):
]
expert_params_mapping = []
if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",

View File

@@ -12,7 +12,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
from sglang.srt.distributed import parallel_state
from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
@@ -616,7 +616,7 @@ class InternVLChatModel(nn.Module):
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",

View File

@@ -31,7 +31,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
@@ -364,7 +363,6 @@ class Llama4DecoderLayer(nn.Module):
rope_theta = config.rope_theta
rope_scaling = config.rope_scaling
max_position_embeddings = config.max_position_embeddings
self.local_dp_size = get_local_attention_dp_size()
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()

View File

@@ -37,7 +37,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda

View File

@@ -47,7 +47,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers
@@ -104,7 +103,6 @@ class MixtralMoE(nn.Module):
intermediate_size=intermediate_size,
params_dtype=params_dtype,
quant_config=quant_config,
tp_size=tp_size,
prefix=add_prefix("experts", prefix),
)

View File

@@ -89,7 +89,6 @@ class OlmoeMoE(nn.Module):
intermediate_size=intermediate_size,
reduce_results=True,
quant_config=quant_config,
tp_size=tp_size,
layer_id=layer_id,
prefix=add_prefix("experts", prefix),
)

View File

@@ -17,8 +17,6 @@
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
import logging
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
@@ -31,10 +29,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import (
ExpertDistributionRecorder,
get_global_expert_distribution_recorder,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import (
@@ -45,7 +40,6 @@ from sglang.srt.layers.communicator import (
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
@@ -55,8 +49,8 @@ from sglang.srt.layers.linear import (
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -149,14 +143,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
prefix=add_prefix("experts", prefix),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
)
self.gate = ReplicatedLinear(
@@ -340,7 +326,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()
# Qwen2MoE all layers are sparse and have no nextn now
self.is_layer_sparse = True

View File

@@ -28,50 +28,35 @@ from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
parallel_state,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
Qwen3MoeConfig = None
@@ -112,19 +97,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
)
self.gate = ReplicatedLinear(
@@ -135,7 +107,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
prefix=add_prefix("gate", prefix),
)
if global_server_args_dict["moe_a2a_backend"].is_deepep():
if get_moe_a2a_backend().is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = (
@@ -150,7 +122,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
use_reduce_scatter: bool = False,
) -> torch.Tensor:
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
if not get_moe_a2a_backend().is_deepep():
return self.forward_normal(hidden_states, use_reduce_scatter)
else:
return self.forward_deepep(hidden_states, forward_batch)
@@ -491,7 +463,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()
# Qwen3MoE all layers are sparse and have no nextn now
self.is_layer_sparse = True
@@ -778,7 +749,7 @@ class Qwen3MoeForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",

View File

@@ -38,6 +38,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
@@ -150,7 +151,7 @@ class Step3TextMoEMLP(nn.Module):
prefix=add_prefix("gate", prefix),
)
if global_server_args_dict["moe_a2a_backend"].is_deepep():
if get_moe_a2a_backend().is_deepep():
raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

View File

@@ -33,7 +33,9 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
@@ -121,6 +123,7 @@ class XverseMoE(nn.Module):
]
)
self.pack_params()
self.moe_runner_config = MoeRunnerConfig(inplace=True)
self.router = ReplicatedLinear(
config.hidden_size,
@@ -129,6 +132,10 @@ class XverseMoE(nn.Module):
quant_config=None,
prefix=add_prefix("router", prefix),
)
self.topk = TopK(
top_k=self.top_k,
renormalize=getattr(self.config, "norm_topk_prob", False),
)
if config.num_shared_experts is not None:
intermediate_size = config.intermediate_size * config.num_shared_experts
@@ -167,14 +174,13 @@ class XverseMoE(nn.Module):
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.router(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = fused_moe(
hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=getattr(self.config, "norm_topk_prob", False),
inplace=True,
topk_output,
self.moe_runner_config,
)
if self.config.num_shared_experts is not None:

View File

@@ -37,6 +37,7 @@ from sglang.srt.utils import (
is_hip,
is_port_available,
is_remote_url,
is_triton_kernels_available,
is_valid_ipv6_address,
nullable_str,
)
@@ -175,9 +176,15 @@ class ServerArgs:
# Expert parallelism
ep_size: int = 1
moe_a2a_backend: Optional[Literal["deepep"]] = None
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
moe_a2a_backend: Literal["none", "deepep"] = "none"
moe_runner_backend: Literal[
"auto",
"triton",
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
] = "auto"
enable_flashinfer_allreduce_fusion: bool = False
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
ep_num_redundant_experts: int = 0
@@ -250,8 +257,6 @@ class ServerArgs:
disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False
enable_return_hidden_states: bool = False
enable_triton_kernel_moe: bool = False
enable_flashinfer_mxfp4_moe: bool = False
scheduler_recv_interval: int = 1
# Debug tensor dumps
@@ -282,6 +287,9 @@ class ServerArgs:
# Deprecated arguments
enable_ep_moe: bool = False
enable_deepep_moe: bool = False
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
enable_triton_kernel_moe: bool = False
def __post_init__(self):
# Check deprecated arguments
@@ -298,6 +306,21 @@ class ServerArgs:
print_deprecated_warning(
"NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
)
if self.enable_triton_kernel_moe:
self.moe_runner_backend = "triton_kernel"
print_deprecated_warning(
"NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
)
if self.enable_flashinfer_cutlass_moe:
self.moe_runner_backend = "flashinfer_cutlass"
print_deprecated_warning(
"NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
)
if self.enable_flashinfer_trtllm_moe:
self.moe_runner_backend = "flashinfer_trtllm"
print_deprecated_warning(
"NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
)
# Set missing default values
if self.tokenizer_path is None:
@@ -517,7 +540,7 @@ class ServerArgs:
), "Please enable dp attention when setting enable_dp_lm_head. "
# MoE kernel
if self.enable_flashinfer_cutlass_moe:
if self.moe_runner_backend == "flashinfer_cutlass":
assert (
self.quantization == "modelopt_fp4"
), "modelopt_fp4 quantization is required for Flashinfer MOE"
@@ -527,7 +550,7 @@ class ServerArgs:
self.tp_size,
], "The expert parallel size must be 1 or the same as the tensor parallel size"
if self.enable_flashinfer_trtllm_moe:
if self.moe_runner_backend == "flashinfer_trtllm":
if not self.disable_shared_experts_fusion:
self.disable_shared_experts_fusion = True
logger.warning(
@@ -556,7 +579,7 @@ class ServerArgs:
self.ep_dispatch_algorithm = "static"
if self.enable_eplb:
assert self.ep_size > 1 or self.moe_a2a_backend is not None
assert self.ep_size > 1
if self.enable_expert_distribution_metrics and (
self.expert_distribution_recorder_mode is None
@@ -1446,19 +1469,22 @@ class ServerArgs:
parser.add_argument(
"--moe-a2a-backend",
type=str,
choices=["deepep"],
choices=["none", "deepep"],
default=ServerArgs.moe_a2a_backend,
help="Choose the backend for MoE A2A.",
)
parser.add_argument(
"--enable-flashinfer-cutlass-moe",
action="store_true",
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action="store_true",
help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
"--moe-runner-backend",
type=str,
choices=[
"auto",
"triton",
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
],
default=ServerArgs.moe_runner_backend,
help="Choose the runner backend for MoE.",
)
parser.add_argument(
"--enable-flashinfer-allreduce-fusion",
@@ -1825,11 +1851,6 @@ class ServerArgs:
action="store_true",
help="Enable returning hidden states with responses.",
)
parser.add_argument(
"--enable-triton-kernel-moe",
action="store_true",
help="Use triton moe grouped gemm kernel.",
)
parser.add_argument(
"--enable-flashinfer-mxfp4-moe",
action="store_true",
@@ -1965,6 +1986,21 @@ class ServerArgs:
action="store_true",
help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
)
parser.add_argument(
"--enable-flashinfer-cutlass-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
)
parser.add_argument(
"--enable-triton-kernel-moe",
action="store_true",
help="(Deprecated) Use triton moe grouped gemm kernel.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
@@ -2143,18 +2179,21 @@ class ServerArgs:
)
if is_sm100_supported() and is_mxfp4_quant_format:
self.enable_flashinfer_mxfp4_moe = True
self.enable_triton_kernel_moe = False
self.moe_runner_backend = "flashinfer_mxfp4"
logger.warning(
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
)
else:
if self.enable_triton_kernel_moe:
if self.moe_runner_backend == "triton_kernel":
assert (
self.ep_size == 1
), "Triton kernel MoE is only supported when ep_size == 1"
if not self.enable_triton_kernel_moe and self.ep_size == 1:
self.enable_triton_kernel_moe = True
if (
self.moe_runner_backend == "auto"
and self.ep_size == 1
and is_triton_kernels_available()
):
self.moe_runner_backend = "triton_kernel"
logger.warning(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel."
)

View File

@@ -14,8 +14,13 @@ from sglang.srt.layers.communicator import (
CommunicateSummableTensorPairFn,
ScatterMode,
)
from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
get_tbo_token_distribution_threshold,
is_tbo_enabled,
)
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import (
@@ -83,7 +88,7 @@ def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
left_sum = sum(extend_lens[:vanilla_split_seq_index])
overall_sum = sum(extend_lens)
threshold = global_server_args_dict["tbo_token_distribution_threshold"]
threshold = get_tbo_token_distribution_threshold()
assert threshold <= 0.5, f"{threshold=}"
return left_sum < overall_sum * threshold or left_sum > overall_sum * (
1 - threshold
@@ -299,7 +304,7 @@ class TboCudaGraphRunnerPlugin:
self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
if not global_server_args_dict["enable_two_batch_overlap"]:
if not is_tbo_enabled():
return
token_num_per_seq = get_token_num_per_seq(
forward_mode=batch.forward_mode, spec_info=batch.spec_info
@@ -353,10 +358,12 @@ class TboDPAttentionPreparer:
def prepare_all_gather(
self,
local_batch: ScheduleBatch,
deepep_mode: DeepEPMode,
enable_deepep_moe: bool,
enable_two_batch_overlap: bool,
):
deepep_mode = get_deepep_mode()
enable_deepep_moe = get_moe_a2a_backend().is_deepep()
enable_two_batch_overlap = is_tbo_enabled()
self.enable_two_batch_overlap = enable_two_batch_overlap
if local_batch is not None:
@@ -384,7 +391,7 @@ class TboDPAttentionPreparer:
and not local_batch.forward_mode.is_target_verify()
)
and enable_deepep_moe
and (resolved_deepep_mode == DeepEPMode.LOW_LATENCY)
and (resolved_deepep_mode.is_low_latency())
)
else:
self.local_tbo_split_seq_index = 0
@@ -657,6 +664,7 @@ class TboForwardBatchPreparer:
"req_to_token_pool",
"token_to_kv_pool",
"can_run_dp_cuda_graph",
"dp_padding_mode",
"global_forward_mode",
"spec_algorithm",
"capture_hidden_mode",
@@ -701,7 +709,6 @@ class TboForwardBatchPreparer:
tbo_children=None,
global_num_tokens_gpu=None,
global_num_tokens_cpu=None,
dp_padding_mode=None,
global_dp_buffer_len=global_dp_buffer_len,
global_num_tokens_for_logprob_gpu=None,
global_num_tokens_for_logprob_cpu=None,
@@ -955,9 +962,7 @@ def _model_forward_tbo_merge_outputs(output_a, output_b):
class MaybeTboDeepEPDispatcher:
def __init__(self, **kwargs):
num_inner_dispatchers = (
2 if global_server_args_dict["enable_two_batch_overlap"] else 1
)
num_inner_dispatchers = 2 if is_tbo_enabled() else 1
self._inners = [
DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
]

View File

@@ -2413,7 +2413,7 @@ def require_mlp_tp_gather(server_args):
return True
elif not server_args.enable_dp_lm_head:
return True
elif server_args.moe_a2a_backend is None:
elif server_args.moe_a2a_backend == "none":
return True
else:
return (
@@ -2429,7 +2429,7 @@ def require_attn_tp_gather(server_args):
Check if the input of attention is scattered.
"""
assert server_args.moe_dense_tp_size in [1, None]
if server_args.moe_a2a_backend is not None or server_args.moe_dense_tp_size == 1:
if server_args.moe_a2a_backend != "none" or server_args.moe_dense_tp_size == 1:
if server_args.enable_dp_attention:
return server_args.dp_size < server_args.tp_size
else:

View File

@@ -6,7 +6,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.layers.quantization.fp8_kernel import (
per_tensor_quant_mla_fp8,
per_token_group_quant_fp8,
@@ -498,11 +498,13 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
score = torch.randn((M, E), dtype=dtype)
with torch.inference_mode():
ref_out = torch_w8a8_block_fp8_moe(
a, w1, w2, w1_s, w2_s, score, topk, block_size
)
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
renormalize=False,
topk_config=TopKConfig(top_k=topk, renormalize=False),
)
out = fused_moe(
a,
@@ -514,9 +516,6 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
w2_scale=w2_s,
block_shape=block_size,
)
ref_out = torch_w8a8_block_fp8_moe(
a, w1, w2, w1_s, w2_s, score, topk, block_size
)
self.assertTrue(
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))

View File

@@ -12,7 +12,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
run_moe_ep_preproess,
silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.test.test_utils import CustomTestCase
@@ -22,35 +22,26 @@ def ep_moe(
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
topk_config: TopKConfig,
# ep config
num_experts: int = 256,
fp8_dtype: torch.types = torch.float8_e4m3fn,
num_experts_per_partition: int = 128,
start_expert_id: int = 0,
end_expert_id: int = 127,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False,
w1_scale_inv: Optional[torch.Tensor] = None,
w2_scale_inv: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
):
use_blockwise_fp8 = block_shape is not None
topk_weights, topk_ids, _ = select_experts(
top_k = topk_config.top_k
topk_output = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
# correction_bias=correction_bias, #skip this in test
custom_routing_function=custom_routing_function,
topk_config=topk_config,
)
topk_weights, topk_ids, _ = topk_output
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
@@ -294,14 +285,18 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase):
start_id = cur_rank * num_experts_per_partition
end_id = start_id + num_experts_per_partition - 1
topk_config = TopKConfig(
top_k=topk,
renormalize=False,
)
with torch.inference_mode():
out = ep_moe(
hidden_states=a,
w1=w1,
w2=w2,
router_logits=score,
top_k=topk,
renormalize=False,
topk_config=topk_config,
use_fp8_w8a8=True,
w1_scale_inv=w1_s,
w2_scale_inv=w2_s,
@@ -316,8 +311,7 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase):
w1=w1_ref,
w2=w2_ref,
router_logits=score,
top_k=topk,
renormalize=False,
topk_config=topk_config,
use_fp8_w8a8=False,
w1_scale_inv=None,
w2_scale_inv=None,

View File

@@ -6,7 +6,7 @@ import pytest
import torch
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
@@ -100,11 +100,12 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
s_strides2 = c_strides2
score = torch.randn((M, E), dtype=dtype, device=device)
topk_weights, topk_ids, _ = select_experts(
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
topk_config=TopKConfig(top_k=topk, renormalize=False),
)
topk_weights, topk_ids, _ = topk_output
expert_map = torch.arange(E, dtype=torch.int32, device=device)
expert_map[local_e:] = E

View File

@@ -9,7 +9,7 @@ from sgl_kernel import scaled_fp4_quant
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
if torch.cuda.get_device_capability() < (10, 0):
pytest.skip(
@@ -163,11 +163,12 @@ def check_moe(
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = select_experts(
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
topk_config=TopKConfig(top_k=topk, renormalize=False),
)
topk_weights, topk_ids, _ = topk_output
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)