[6/N] MoE Refactor: Cleanup MoE-related configs (#8849)

This commit is contained in:
Cheng Wan
2025-08-14 21:14:53 -07:00
committed by GitHub
parent 584e1ab2d0
commit 295895120d
69 changed files with 956 additions and 1037 deletions

View File

@@ -11,6 +11,7 @@ import triton
from ray.experimental.tqdm_ray import tqdm from ray.experimental.tqdm_ray import tqdm
from transformers import AutoConfig from transformers import AutoConfig
from sglang.srt.layers.moe.fused_moe_triton import override_config
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_moe, fused_moe,
get_config_dtype_str, get_config_dtype_str,
@@ -18,7 +19,8 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
get_default_config, get_default_config,
get_moe_configs, get_moe_configs,
) )
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
_is_hip = is_hip() _is_hip = is_hip()
@@ -117,17 +119,23 @@ def benchmark_config(
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn) w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
topk_output = select_experts(x, input_gating, topk, renormalize=True) topk_config = TopKConfig(
top_k=topk,
renormalize=True,
)
topk_output = select_experts(x, input_gating, topk_config)
def prepare(i: int): def prepare(i: int):
input_gating = gating_output[i] input_gating = gating_output[i]
new_topk_output = select_experts(x, input_gating, topk, renormalize=True) new_topk_output = select_experts(x, input_gating, topk_config)
topk_output.topk_weights.copy_(new_topk_output.topk_weights) topk_output.topk_weights.copy_(new_topk_output.topk_weights)
topk_output.topk_ids.copy_(new_topk_output.topk_ids) topk_output.topk_ids.copy_(new_topk_output.topk_ids)
topk_output.router_logits.copy_(new_topk_output.router_logits) topk_output.router_logits.copy_(new_topk_output.router_logits)
def run(): def run():
from sglang.srt.layers.moe.fused_moe_triton import override_config moe_runner_config = MoeRunnerConfig(
inplace=True,
)
with override_config(config): with override_config(config):
fused_moe( fused_moe(
@@ -135,7 +143,7 @@ def benchmark_config(
w1, w1,
w2, w2,
topk_output, topk_output,
inplace=True, moe_runner_config=moe_runner_config,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,

View File

@@ -213,12 +213,11 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| Arguments | Description | Defaults | | Arguments | Description | Defaults |
|-----------|-------------|----------| |-----------|-------------|----------|
| `--ep-size` | The expert parallelism size. | 1 | | `--ep-size` | The expert parallelism size. | 1 |
| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | None | | `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none |
| `--enable-flashinfer-cutlass-moe` | Enabling Flashinfer Cutlass MoE implementation for high throughput. | False | | `--moe-runner-backend` | Select the runner backend for MoE. | 'triton' |
| `--enable-flashinfer-trtllm-moe` | Enabling Flashinfer Trtllm MoE implementation for low latency. | False |
| `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto | | `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto |
| `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 | | `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 |
| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in expert parallel. | None | | `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in EPLB. | None |
| `--init-expert-location` | Initial location of EP experts. | trivial | | `--init-expert-location` | Initial location of EP experts. | trivial |
| `--enable-eplb` | Enable EPLB algorithm. | False | | `--enable-eplb` | Enable EPLB algorithm. | False |
| `--eplb-algorithm` | Chosen EPLB algorithm. | auto | | `--eplb-algorithm` | Chosen EPLB algorithm. | auto |
@@ -280,7 +279,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--disable-chunked-prefix-cache` | Disable chunked prefix cache. | False | | `--disable-chunked-prefix-cache` | Disable chunked prefix cache. | False |
| `--disable-fast-image-processor` | Disable fast image processor. | False | | `--disable-fast-image-processor` | Disable fast image processor. | False |
| `--enable-return-hidden-states` | Enable returning hidden states. | False | | `--enable-return-hidden-states` | Enable returning hidden states. | False |
| `--enable-triton-kernel-moe` | Enable Triton kernel for MoE. | False |
## Debug tensor dumps ## Debug tensor dumps

View File

@@ -61,7 +61,6 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed.parallel_state import destroy_distributed_environment from sglang.srt.distributed.parallel_state import destroy_distributed_environment
from sglang.srt.entrypoints.engine import _set_envs_and_config from sglang.srt.entrypoints.engine import _set_envs_and_config
from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.scheduler import Scheduler from sglang.srt.managers.scheduler import Scheduler
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -300,11 +299,6 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
disable_cuda_graph=model_runner.server_args.disable_cuda_graph, disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
spec_algorithm=SpeculativeAlgorithm.NONE, spec_algorithm=SpeculativeAlgorithm.NONE,
speculative_num_draft_tokens=None, speculative_num_draft_tokens=None,
enable_two_batch_overlap=model_runner.server_args.enable_two_batch_overlap,
enable_deepep_moe=MoeA2ABackend(
model_runner.server_args.moe_a2a_backend
).is_deepep(),
deepep_mode=DeepEPMode(model_runner.server_args.deepep_mode),
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args), require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule, disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
) )

View File

@@ -25,7 +25,6 @@ import torch
import torch.distributed import torch.distributed
from sglang.srt.eplb.expert_location import ExpertLocationMetadata from sglang.srt.eplb.expert_location import ExpertLocationMetadata
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import Withable, get_bool_env_var from sglang.srt.utils import Withable, get_bool_env_var
@@ -288,14 +287,14 @@ class _SinglePassGatherer(ABC):
) )
if server_args.expert_distribution_recorder_mode == "stat_approx": if server_args.expert_distribution_recorder_mode == "stat_approx":
if server_args.moe_a2a_backend is not None and ( if server_args.moe_a2a_backend != "none" and (
server_args.deepep_mode == "normal" server_args.deepep_mode == "normal"
): ):
return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank) return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
else: else:
raise NotImplementedError raise NotImplementedError
if server_args.moe_a2a_backend is not None: if server_args.moe_a2a_backend != "none":
if server_args.deepep_mode == "normal": if server_args.deepep_mode == "normal":
return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank) return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
elif server_args.deepep_mode == "low_latency": elif server_args.deepep_mode == "low_latency":

View File

@@ -17,7 +17,7 @@ from enum import Enum, auto
from functools import partial from functools import partial
from typing import Dict, Optional from typing import Dict, Optional
import torch.distributed import torch
from sglang.srt.distributed import ( from sglang.srt.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
@@ -35,6 +35,7 @@ from sglang.srt.layers.dp_attention import (
get_global_dp_buffer, get_global_dp_buffer,
get_local_dp_buffer, get_local_dp_buffer,
) )
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -111,7 +112,7 @@ class LayerScatterModes:
if context.is_layer_sparse: if context.is_layer_sparse:
return ( return (
ScatterMode.SCATTERED ScatterMode.SCATTERED
if not global_server_args_dict["moe_a2a_backend"].is_standard() if not get_moe_a2a_backend().is_none()
else ScatterMode.FULL else ScatterMode.FULL
) )
else: else:

View File

@@ -0,0 +1,29 @@
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.utils import (
DeepEPMode,
MoeA2ABackend,
MoeRunnerBackend,
get_deepep_config,
get_deepep_mode,
get_moe_a2a_backend,
get_moe_runner_backend,
get_tbo_token_distribution_threshold,
initialize_moe_config,
is_tbo_enabled,
should_use_flashinfer_trtllm_moe,
)
__all__ = [
"DeepEPMode",
"MoeA2ABackend",
"MoeRunnerConfig",
"MoeRunnerBackend",
"initialize_moe_config",
"get_moe_a2a_backend",
"get_moe_runner_backend",
"get_deepep_mode",
"should_use_flashinfer_trtllm_moe",
"is_tbo_enabled",
"get_tbo_token_distribution_threshold",
"get_deepep_config",
]

View File

@@ -1,11 +1,17 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional, Union
import torch import torch
from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.kernels import ( from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather, ep_gather,
ep_scatter, ep_scatter,
@@ -16,14 +22,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
) )
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8 import ( from sglang.srt.layers.quantization.fp8 import Fp8Config
Fp8Config,
Fp8MoEMethod,
get_tile_tokens_dim,
)
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz, is_fp8_fnuz,
sglang_per_token_group_quant_fp8, sglang_per_token_group_quant_fp8,
@@ -89,12 +90,11 @@ class EPMoE(FusedMoE):
num_fused_shared_experts: int = 0, num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "", prefix: str = "",
activation: str = "silu", activation: str = "silu",
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None, gemm1_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None, gemm1_clamp_limit: Optional[float] = None,
with_bias: bool = False, with_bias: bool = False,
): ):
super().__init__( super().__init__(
@@ -106,13 +106,12 @@ class EPMoE(FusedMoE):
top_k=top_k, top_k=top_k,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config, quant_config=quant_config,
tp_size=tp_size,
prefix=prefix, prefix=prefix,
activation=activation, activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input, # apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha, gemm1_alpha=gemm1_alpha,
swiglu_limit=swiglu_limit, gemm1_clamp_limit=gemm1_clamp_limit,
with_bias=with_bias, with_bias=with_bias,
) )
@@ -163,7 +162,8 @@ class EPMoE(FusedMoE):
) )
assert self.quant_method is not None assert self.quant_method is not None
assert self.activation == "silu" assert self.moe_runner_config.activation == "silu"
hidden_states_shape = hidden_states.shape hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device hidden_states_device = hidden_states.device
@@ -327,8 +327,8 @@ class EPMoE(FusedMoE):
m_max * self.start_expert_id, m_max * self.start_expert_id,
BLOCK_SIZE=512, BLOCK_SIZE=512,
) )
if self.routed_scaling_factor is not None: if self.moe_runner_config.routed_scaling_factor is not None:
output *= self.routed_scaling_factor output *= self.moe_runner_config.routed_scaling_factor
return output return output
@@ -349,11 +349,9 @@ class DeepEPMoE(EPMoE):
num_fused_shared_experts: int = 0, num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "", prefix: str = "",
activation: str = "silu", activation: str = "silu",
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
deepep_mode: DeepEPMode = DeepEPMode.AUTO,
): ):
super().__init__( super().__init__(
num_experts=num_experts, num_experts=num_experts,
@@ -364,12 +362,11 @@ class DeepEPMoE(EPMoE):
num_fused_shared_experts=num_fused_shared_experts, num_fused_shared_experts=num_fused_shared_experts,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config, quant_config=quant_config,
tp_size=tp_size,
prefix=prefix, prefix=prefix,
activation=activation, activation=activation,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
) )
self.deepep_mode = deepep_mode self.deepep_mode = get_deepep_mode()
# TODO: move to the beginning of the file # TODO: move to the beginning of the file
from sglang.srt.distributed.parallel_state import get_tp_group from sglang.srt.distributed.parallel_state import get_tp_group
@@ -383,7 +380,7 @@ class DeepEPMoE(EPMoE):
num_local_experts=self.num_local_experts, num_local_experts=self.num_local_experts,
hidden_size=hidden_size, hidden_size=hidden_size,
params_dtype=params_dtype, params_dtype=params_dtype,
deepep_mode=deepep_mode, deepep_mode=self.deepep_mode,
async_finish=True, # TODO async_finish=True, # TODO
return_recv_hook=True, return_recv_hook=True,
) )
@@ -458,15 +455,19 @@ class DeepEPMoE(EPMoE):
) )
def moe_impl(self, dispatch_output: DispatchOutput): def moe_impl(self, dispatch_output: DispatchOutput):
from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker
if _use_aiter: if _use_aiter:
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
# in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
return self.forward_aiter(dispatch_output) return self.forward_aiter(dispatch_output)
if _is_npu: if _is_npu:
assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output)
return self.forward_npu(dispatch_output) return self.forward_npu(dispatch_output)
if dispatch_output.format.is_deepep_normal(): if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8 assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_contiguous(dispatch_output) return self.forward_deepgemm_contiguous(dispatch_output)
elif dispatch_output.format.is_deepep_ll(): elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8 assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
return self.forward_deepgemm_masked(dispatch_output) return self.forward_deepgemm_masked(dispatch_output)
else: else:
@@ -490,7 +491,7 @@ class DeepEPMoE(EPMoE):
def forward_aiter( def forward_aiter(
self, self,
dispatch_output: DeepEPNormalOutput, dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput],
): ):
hidden_states, topk_idx, topk_weights = ( hidden_states, topk_idx, topk_weights = (
dispatch_output.hidden_states, dispatch_output.hidden_states,
@@ -516,7 +517,7 @@ class DeepEPMoE(EPMoE):
quant_type=QuantType.per_128x128, quant_type=QuantType.per_128x128,
activation=( activation=(
ActivationType.Silu ActivationType.Silu
if self.activation == "silu" if self.moe_runner_config.activation == "silu"
else ActivationType.Gelu else ActivationType.Gelu
), ),
expert_mask=self.expert_mask, expert_mask=self.expert_mask,
@@ -531,7 +532,7 @@ class DeepEPMoE(EPMoE):
) )
hidden_states_fp8, hidden_states_scale = hidden_states_fp8 hidden_states_fp8, hidden_states_scale = hidden_states_fp8
assert self.quant_method is not None assert self.quant_method is not None
assert self.activation == "silu" assert self.moe_runner_config.activation == "silu"
if num_recv_tokens_per_expert is None: if num_recv_tokens_per_expert is None:
return hidden_states_fp8.bfloat16() return hidden_states_fp8.bfloat16()
all_tokens = sum(num_recv_tokens_per_expert) all_tokens = sum(num_recv_tokens_per_expert)
@@ -652,7 +653,7 @@ class DeepEPMoE(EPMoE):
): ):
hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None assert self.quant_method is not None
assert self.activation == "silu" assert self.moe_runner_config.activation == "silu"
# GroupGemm-0 # GroupGemm-0
num_groups, m, k = hidden_states_fp8[0].size() num_groups, m, k = hidden_states_fp8[0].size()
@@ -783,12 +784,12 @@ class DeepEPMoE(EPMoE):
def get_moe_impl_class(): def get_moe_impl_class():
if global_server_args_dict["moe_a2a_backend"].is_deepep(): if get_moe_a2a_backend().is_deepep():
return DeepEPMoE return DeepEPMoE
# NEW: Direct FP4 detection (bypasses EP requirements) # NEW: Direct FP4 detection (bypasses EP requirements)
# Check for FP4 quantization with TRTLLM flag, regardless of EP # Check for FP4 quantization with TRTLLM flag, regardless of EP
if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False): if get_moe_runner_backend().is_flashinfer_trtllm():
try: try:
# Check the quantization argument directly # Check the quantization argument directly
quantization = global_server_args_dict.get("quantization") quantization = global_server_args_dict.get("quantization")
@@ -803,7 +804,7 @@ def get_moe_impl_class():
if should_use_flashinfer_trtllm_moe(): if should_use_flashinfer_trtllm_moe():
return FlashInferFusedMoE return FlashInferFusedMoE
if global_server_args_dict["enable_flashinfer_cutlass_moe"]: if get_moe_runner_backend().is_flashinfer_cutlass():
return FusedMoE return FusedMoE
if get_moe_expert_parallel_world_size() > 1: if get_moe_expert_parallel_world_size() > 1:
return EPMoE return EPMoE

View File

@@ -3,28 +3,22 @@ Torch-native implementation for FusedMoE. This is used for torch.compile.
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204 It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
""" """
from typing import Callable, Optional
import torch import torch
from torch.nn import functional as F from torch.nn import functional as F
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
def fused_moe_forward_native( def fused_moe_forward_native(
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: StandardTopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if apply_router_weight_on_input: if moe_runner_config.apply_router_weight_on_input:
raise NotImplementedError() raise NotImplementedError()
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output
@@ -33,12 +27,12 @@ def fused_moe_forward_native(
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = layer.w2_weight[topk_ids] w2_weights = layer.w2_weight[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights) x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
if activation == "silu": if moe_runner_config.activation == "silu":
x1 = F.silu(x1) x1 = F.silu(x1)
elif activation == "gelu": elif moe_runner_config.activation == "gelu":
x1 = F.gelu(x1) x1 = F.gelu(x1)
else: else:
raise ValueError(f"Unsupported activation: {activation=}") raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
@@ -47,16 +41,11 @@ def fused_moe_forward_native(
def moe_forward_native( def moe_forward_native(
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: StandardTopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if apply_router_weight_on_input: if moe_runner_config.apply_router_weight_on_input:
raise NotImplementedError() raise NotImplementedError()
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output
@@ -72,12 +61,12 @@ def moe_forward_native(
sorted_tokens = x[idxs // topk_ids.shape[1]] sorted_tokens = x[idxs // topk_ids.shape[1]]
tokens_per_expert = tokens_per_expert.cpu().numpy() tokens_per_expert = tokens_per_expert.cpu().numpy()
if activation == "silu": if moe_runner_config.activation == "silu":
act = SiluAndMul() act = SiluAndMul()
elif activation == "gelu": elif moe_runner_config.activation == "gelu":
act = GeluAndMul() act = GeluAndMul()
else: else:
raise ValueError(f"Unsupported activation: {activation=}") raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}")
outputs = [] outputs = []
start_idx = 0 start_idx = 0

View File

@@ -2,17 +2,20 @@
"""Fused MoE kernel.""" """Fused MoE kernel."""
from __future__ import annotations
import functools import functools
import json import json
import logging import logging
import os import os
from typing import Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8, per_token_group_quant_fp8,
scaled_fp8_quant, scaled_fp8_quant,
@@ -1025,8 +1028,8 @@ def inplace_fused_experts(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None, gemm1_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None, gemm1_limit: Optional[float] = None,
) -> None: ) -> None:
fused_experts_impl( fused_experts_impl(
hidden_states, hidden_states,
@@ -1053,8 +1056,8 @@ def inplace_fused_experts(
block_shape, block_shape,
False, False,
routed_scaling_factor, routed_scaling_factor,
activation_alpha, gemm1_alpha,
swiglu_limit, gemm1_limit,
) )
@@ -1081,8 +1084,8 @@ def inplace_fused_experts_fake(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None, gemm1_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None, gemm1_limit: Optional[float] = None,
) -> None: ) -> None:
pass pass
@@ -1119,8 +1122,8 @@ def outplace_fused_experts(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
no_combine: bool = False, no_combine: bool = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None, gemm1_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None, gemm1_limit: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return fused_experts_impl( return fused_experts_impl(
hidden_states, hidden_states,
@@ -1147,8 +1150,8 @@ def outplace_fused_experts(
block_shape, block_shape,
no_combine=no_combine, no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha, gemm1_alpha=gemm1_alpha,
swiglu_limit=swiglu_limit, gemm1_limit=gemm1_limit,
) )
@@ -1176,8 +1179,8 @@ def outplace_fused_experts_fake(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
no_combine: bool = False, no_combine: bool = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None, gemm1_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None, gemm1_limit: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
@@ -1194,12 +1197,10 @@ def fused_experts(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_output: TopKOutput, topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig,
b1: Optional[torch.Tensor] = None, b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None, b2: Optional[torch.Tensor] = None,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
@@ -1212,14 +1213,10 @@ def fused_experts(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
): ):
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output
if inplace: if moe_runner_config.inplace:
assert not no_combine, "no combine + inplace makes no sense" assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
torch.ops.sglang.inplace_fused_experts( torch.ops.sglang.inplace_fused_experts(
hidden_states, hidden_states,
w1, w1,
@@ -1228,8 +1225,8 @@ def fused_experts(
topk_ids, topk_ids,
b1, b1,
b2, b2,
activation, moe_runner_config.activation,
apply_router_weight_on_input, moe_runner_config.apply_router_weight_on_input,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a8,
use_int8_w8a16, use_int8_w8a16,
@@ -1242,9 +1239,9 @@ def fused_experts(
a1_scale, a1_scale,
a2_scale, a2_scale,
block_shape, block_shape,
routed_scaling_factor, moe_runner_config.routed_scaling_factor,
activation_alpha, moe_runner_config.gemm1_alpha,
swiglu_limit, moe_runner_config.gemm1_clamp_limit,
) )
return hidden_states return hidden_states
else: else:
@@ -1256,8 +1253,8 @@ def fused_experts(
topk_ids, topk_ids,
b1, b1,
b2, b2,
activation, moe_runner_config.activation,
apply_router_weight_on_input, moe_runner_config.apply_router_weight_on_input,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a8, use_int8_w8a8,
use_int8_w8a16, use_int8_w8a16,
@@ -1270,10 +1267,10 @@ def fused_experts(
a1_scale, a1_scale,
a2_scale, a2_scale,
block_shape, block_shape,
no_combine=no_combine, no_combine=moe_runner_config.no_combine,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=moe_runner_config.routed_scaling_factor,
activation_alpha=activation_alpha, gemm1_alpha=moe_runner_config.gemm1_alpha,
swiglu_limit=swiglu_limit, gemm1_limit=moe_runner_config.gemm1_clamp_limit,
) )
@@ -1370,11 +1367,11 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
@torch.compile @torch.compile
def swiglu_with_alpha_and_limit(x, alpha, limit): def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit):
gate, up = x[..., ::2], x[..., 1::2] gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(min=None, max=limit) gate = gate.clamp(min=None, max=gemm1_limit)
up = up.clamp(min=-limit, max=limit) up = up.clamp(min=-gemm1_limit, max=gemm1_limit)
return gate * torch.sigmoid(gate * alpha) * (up + 1) return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1)
def fused_experts_impl( def fused_experts_impl(
@@ -1402,8 +1399,8 @@ def fused_experts_impl(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
no_combine: bool = False, no_combine: bool = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None, gemm1_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None, gemm1_limit: Optional[float] = None,
): ):
padded_size = padding_size padded_size = padding_size
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter: if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
@@ -1533,12 +1530,12 @@ def fused_experts_impl(
block_shape=block_shape, block_shape=block_shape,
) )
if activation == "silu": if activation == "silu":
if activation_alpha is not None: if gemm1_alpha is not None:
assert swiglu_limit is not None assert gemm1_limit is not None
intermediate_cache2 = swiglu_with_alpha_and_limit( intermediate_cache2 = swiglu_with_alpha_and_limit(
intermediate_cache1.view(-1, N), intermediate_cache1.view(-1, N),
activation_alpha, gemm1_alpha,
swiglu_limit, gemm1_limit,
) )
elif _is_cuda: elif _is_cuda:
silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
@@ -1547,10 +1544,8 @@ def fused_experts_impl(
intermediate_cache2, intermediate_cache1.view(-1, N) intermediate_cache2, intermediate_cache1.view(-1, N)
) )
elif activation == "gelu": elif activation == "gelu":
assert ( assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu"
activation_alpha is None assert gemm1_limit is None, "gemm1_limit is not supported for gelu"
), "activation_alpha is not supported for gelu"
assert swiglu_limit is None, "swiglu_limit is not supported for gelu"
if _is_cuda: if _is_cuda:
gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
else: else:
@@ -1641,12 +1636,10 @@ def fused_moe(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_output: TopKOutput, topk_output: StandardTopKOutput,
moe_runner_config: MoeRunnerConfig = MoeRunnerConfig(),
b1: Optional[torch.Tensor] = None, b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None, b2: Optional[torch.Tensor] = None,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
@@ -1659,10 +1652,6 @@ def fused_moe(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1672,11 +1661,10 @@ def fused_moe(
- hidden_states (torch.Tensor): The input tensor to the MoE layer. - hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights. - w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights. - w2 (torch.Tensor): The second set of expert weights.
- topk_output (TopKOutput): The top-k output of the experts. - topk_output (StandardTopKOutput): The top-k output of the experts.
- moe_runner_config (MoeRunnerConfig): The configuration for the MoE runner.
- b1 (Optional[torch.Tensor]): Optional bias for w1. - b1 (Optional[torch.Tensor]): Optional bias for w1.
- b2 (Optional[torch.Tensor]): Optional bias for w2. - b2 (Optional[torch.Tensor]): Optional bias for w2.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False. products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
@@ -1696,9 +1684,9 @@ def fused_moe(
a2. a2.
- block_shape: (Optional[List[int]]): Optional block size for block-wise - block_shape: (Optional[List[int]]): Optional block size for block-wise
quantization. quantization.
- activation_alpha (Optional[float]): Optional alpha for the activation - gemm1_alpha (Optional[float]): Optional gemm1_alpha for the activation
function. function.
- swiglu_limit (Optional[float]): Optional limit for the swiglu activation - gemm1_limit (Optional[float]): Optional gemm1_limit for the swiglu activation
function. function.
Returns: Returns:
@@ -1710,11 +1698,9 @@ def fused_moe(
w1, w1,
w2, w2,
topk_output, topk_output,
moe_runner_config=moe_runner_config,
b1=b1, b1=b1,
b2=b2, b2=b2,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
@@ -1727,8 +1713,4 @@ def fused_moe(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
) )

View File

@@ -1,10 +1,6 @@
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
import datetime
import glob
import logging import logging
import os
import sys
from enum import Enum from enum import Enum
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@@ -22,8 +18,12 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory, use_symmetric_memory,
) )
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe.topk import StandardTopKOutput from sglang.srt.layers.moe import (
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe MoeRunnerConfig,
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
@@ -126,7 +126,6 @@ class FusedMoE(torch.nn.Module):
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False, reduce_results: bool = False,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "", prefix: str = "",
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
@@ -134,9 +133,8 @@ class FusedMoE(torch.nn.Module):
inplace: bool = True, inplace: bool = True,
no_combine: bool = False, no_combine: bool = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
enable_flashinfer_cutlass_moe: Optional[bool] = False, gemm1_alpha: Optional[float] = None,
activation_alpha: Optional[float] = None, gemm1_clamp_limit: Optional[float] = None,
swiglu_limit: Optional[float] = None,
use_weight_loader_fused: bool = False, use_weight_loader_fused: bool = False,
with_bias=False, with_bias=False,
): ):
@@ -153,9 +151,17 @@ class FusedMoE(torch.nn.Module):
self.expert_map_cpu = None self.expert_map_cpu = None
self.expert_map_gpu = None self.expert_map_gpu = None
# For activation self.moe_runner_config = MoeRunnerConfig(
self.activation_alpha = activation_alpha activation=activation,
self.swiglu_limit = swiglu_limit apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
gemm1_alpha=gemm1_alpha,
gemm1_clamp_limit=gemm1_clamp_limit,
)
enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass()
if enable_flashinfer_cutlass_moe and quant_config is None: if enable_flashinfer_cutlass_moe and quant_config is None:
logger.warning("Disable flashinfer MoE when quantization config is None.") logger.warning("Disable flashinfer MoE when quantization config is None.")
@@ -184,20 +190,12 @@ class FusedMoE(torch.nn.Module):
* self.num_local_experts * self.num_local_experts
] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu") ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
self.routed_scaling_factor = routed_scaling_factor
assert intermediate_size % self.moe_tp_size == 0 assert intermediate_size % self.moe_tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size
self.reduce_results = reduce_results self.reduce_results = reduce_results
self.activation = activation
self.apply_router_weight_on_input = apply_router_weight_on_input
self.use_presharded_weights = use_presharded_weights self.use_presharded_weights = use_presharded_weights
self.inplace = inplace
self.no_combine = no_combine
self.use_triton_kernels = (
not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
)
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
if quant_config is None: if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod( self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
self.use_triton_kernels self.use_triton_kernels
@@ -207,14 +205,12 @@ class FusedMoE(torch.nn.Module):
assert self.quant_method is not None assert self.quant_method is not None
self.quant_config = quant_config self.quant_config = quant_config
self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get( self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4()
"enable_flashinfer_mxfp4_moe", False
)
# TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
if ( if (
self.quant_config is not None self.quant_config is not None
and self.quant_config.get_name() == "mxfp4" and self.quant_config.get_name() == "mxfp4"
and self.use_enable_flashinfer_mxfp4_moe and self.use_flashinfer_mxfp4_moe
): ):
hidden_size = round_up(hidden_size, 256) hidden_size = round_up(hidden_size, 256)
self.quant_method.create_weights( self.quant_method.create_weights(
@@ -794,7 +790,7 @@ class FusedMoE(torch.nn.Module):
f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded." f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
) )
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput): def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
origin_hidden_states_dim = hidden_states.shape[-1] origin_hidden_states_dim = hidden_states.shape[-1]
assert self.quant_method is not None assert self.quant_method is not None
@@ -803,40 +799,22 @@ class FusedMoE(torch.nn.Module):
# If we are in EP mode, we need to move the expert map to GPU. # If we are in EP mode, we need to move the expert map to GPU.
self.expert_map_gpu = self.expert_map_cpu.to(device="cuda") self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
if self.expert_map_gpu is not None and isinstance( if self.expert_map_gpu is not None:
topk_output, StandardTopKOutput if TopKOutputChecker.format_is_standard(topk_output):
):
topk_output = topk_output._replace( topk_output = topk_output._replace(
topk_ids=self.expert_map_gpu[topk_output.topk_ids] topk_ids=self.expert_map_gpu[topk_output.topk_ids]
) )
elif TopKOutputChecker.format_is_triton_kernel(topk_output):
raise NotImplementedError()
# Matrix multiply. # Matrix multiply.
with use_symmetric_memory(get_tp_group()) as sm: with use_symmetric_memory(get_tp_group()) as sm:
kwargs = {}
if self.activation_alpha is not None:
kwargs["activation_alpha"] = self.activation_alpha
if self.swiglu_limit is not None:
kwargs["swiglu_limit"] = self.swiglu_limit
final_hidden_states = self.quant_method.apply( final_hidden_states = self.quant_method.apply(
layer=self, layer=self,
x=hidden_states, x=hidden_states,
topk_output=topk_output, topk_output=topk_output,
activation=self.activation, moe_runner_config=self.moe_runner_config,
apply_router_weight_on_input=self.apply_router_weight_on_input,
routed_scaling_factor=self.routed_scaling_factor,
**(
dict(
tp_rank=self.moe_tp_rank,
tp_size=self.moe_tp_size,
ep_rank=self.moe_ep_rank,
ep_size=self.moe_ep_size,
)
if self.quant_method.__class__.__name__
== "ModelOptNvFp4FusedMoEMethod"
else {}
),
**kwargs,
) )
sm.tag(final_hidden_states) sm.tag(final_hidden_states)
@@ -944,24 +922,10 @@ class FusedMoE(torch.nn.Module):
class FlashInferFusedMoE(FusedMoE): class FlashInferFusedMoE(FusedMoE):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
renormalize = kwargs.pop("renormalize", True)
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
num_expert_group = kwargs.pop("num_expert_group", None)
topk_group = kwargs.pop("topk_group", None)
correction_bias = kwargs.pop("correction_bias", None)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.renormalize = renormalize
self.num_fused_shared_experts = num_fused_shared_experts
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe() self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
def forward(self, hidden_states: torch.Tensor, topk_output: tuple): def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
assert self.use_flashinfer_trtllm_moe assert self.use_flashinfer_trtllm_moe
assert ( assert (
self.activation == "silu" self.activation == "silu"
@@ -974,20 +938,14 @@ class FlashInferFusedMoE(FusedMoE):
self.num_fused_shared_experts == 0 self.num_fused_shared_experts == 0
), "Fused shared experts are not supported for flashinfer blockscale fp8 moe" ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
# TRTLLM mode expects (TopK_config, router_logits) tuple assert TopKOutputChecker.format_is_bypassed(topk_output)
if not isinstance(topk_output, tuple) or len(topk_output) != 2:
raise ValueError(
f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
)
_, router_logits = topk_output
# Matrix multiply. # Matrix multiply.
final_hidden_states = self.quant_method.apply_with_router_logits( final_hidden_states = self.quant_method.apply_with_router_logits(
layer=self, layer=self,
x=hidden_states, x=hidden_states,
router_logits=router_logits, topk_output=topk_output,
activation=self.activation, moe_runner_config=self.moe_runner_config,
routed_scaling_factor=self.routed_scaling_factor,
) )
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1): if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
@@ -1000,28 +958,8 @@ class FlashInferFP4MoE(FusedMoE):
"""FP4 TRTLLM MoE implementation using FlashInfer.""" """FP4 TRTLLM MoE implementation using FlashInfer."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# Extract DeepSeek-specific parameters
renormalize = kwargs.pop("renormalize", True)
num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
use_grouped_topk = kwargs.pop("use_grouped_topk", False)
num_expert_group = kwargs.pop("num_expert_group", None)
topk_group = kwargs.pop("topk_group", None)
correction_bias = kwargs.pop("correction_bias", None)
# Extract additional TopK parameters that were previously extracted in forward
routed_scaling_factor = kwargs.pop("routed_scaling_factor", None)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Store DeepSeek parameters
self.renormalize = renormalize
self.num_fused_shared_experts = num_fused_shared_experts
self.use_grouped_topk = use_grouped_topk
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor
# --------------------------------------------------------------------- # ---------------------------------------------------------------------
# Helper: quantize hidden states to FP4 each forward pass # Helper: quantize hidden states to FP4 each forward pass
# --------------------------------------------------------------------- # ---------------------------------------------------------------------
@@ -1052,21 +990,17 @@ class FlashInferFP4MoE(FusedMoE):
return hs_fp4, hs_sf return hs_fp4, hs_sf
def forward(self, hidden_states: torch.Tensor, topk_output): def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
"""Forward pass using FP4 TRTLLM kernel. """Forward pass using FP4 TRTLLM kernel.
Args: Args:
hidden_states: Input tensor hidden_states: Input tensor
topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode topk_output: TopKOutput object with Bypassed format
""" """
assert TopKOutputChecker.format_is_bypassed(topk_output)
# TRTLLM mode expects (TopK_config, router_logits) tuple router_logits = topk_output.router_logits
if not isinstance(topk_output, tuple) or len(topk_output) != 2: topk_config = topk_output.topk_config
raise ValueError(
f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}"
)
_, router_logits = topk_output
hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states) hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states)
@@ -1074,7 +1008,7 @@ class FlashInferFP4MoE(FusedMoE):
result = trtllm_fp4_block_scale_moe( result = trtllm_fp4_block_scale_moe(
routing_logits=router_logits, routing_logits=router_logits,
routing_bias=self.correction_bias.to(hidden_states.dtype), routing_bias=topk_config.correction_bias.to(hidden_states.dtype),
hidden_states=hs_fp4, hidden_states=hs_fp4,
hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(), hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(),
gemm1_weights=self.gemm1_weights_fp4_shuffled.data, gemm1_weights=self.gemm1_weights_fp4_shuffled.data,
@@ -1094,15 +1028,15 @@ class FlashInferFP4MoE(FusedMoE):
output1_scale_gate_scalar=self.g1_alphas.data, output1_scale_gate_scalar=self.g1_alphas.data,
output2_scale_scalar=self.g2_alphas.data, output2_scale_scalar=self.g2_alphas.data,
num_experts=self.num_experts, num_experts=self.num_experts,
top_k=self.top_k, top_k=topk_config.top_k,
n_group=self.num_expert_group, n_group=topk_config.num_expert_group,
topk_group=self.topk_group, topk_group=topk_config.topk_group,
intermediate_size=self.intermediate_size_per_partition, intermediate_size=self.intermediate_size_per_partition,
local_expert_offset=self.moe_ep_rank * self.num_local_experts, local_expert_offset=self.moe_ep_rank * self.num_local_experts,
local_num_experts=self.num_local_experts, local_num_experts=self.num_local_experts,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
tile_tokens_dim=_get_tile_tokens_dim( tile_tokens_dim=_get_tile_tokens_dim(
hidden_states.shape[0], self.top_k, self.num_local_experts hidden_states.shape[0], topk_config.top_k, self.num_local_experts
), ),
routing_method_type=RoutingMethodType.DeepSeekV3, routing_method_type=RoutingMethodType.DeepSeekV3,
do_finalize=True, do_finalize=True,

View File

@@ -18,6 +18,7 @@ from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
from triton_kernels.swiglu import swiglu_fn from triton_kernels.swiglu import swiglu_fn
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
@@ -55,8 +56,7 @@ def triton_kernel_moe_forward(
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
inplace: bool = False, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
@@ -69,7 +69,10 @@ def triton_kernel_moe_forward(
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert topk_output.format.is_triton_kernel() from sglang.srt.layers.moe.topk import TopKOutputChecker
assert TopKOutputChecker.format_is_triton_kernel(topk_output)
routing_data, gather_idx, scatter_idx = topk_output routing_data, gather_idx, scatter_idx = topk_output
return triton_kernel_fused_experts( return triton_kernel_fused_experts(
@@ -79,8 +82,8 @@ def triton_kernel_moe_forward(
routing_data, routing_data,
gather_idx, gather_idx,
scatter_idx, scatter_idx,
inplace=inplace, inplace=False, # triton kernel doesn't support inplace
activation=activation, activation=moe_runner_config.activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
@@ -192,8 +195,7 @@ def triton_kernel_moe_with_bias_forward(
w2_pcg, w2_pcg,
b2: torch.Tensor, b2: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
inplace: bool = False, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
per_channel_quant: bool = False, per_channel_quant: bool = False,
global_num_experts: int = -1, global_num_experts: int = -1,
@@ -203,10 +205,11 @@ def triton_kernel_moe_with_bias_forward(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[int] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert topk_output.format.is_triton_kernel() from sglang.srt.layers.moe.topk import TopKOutputChecker
assert TopKOutputChecker.format_is_triton_kernel(topk_output)
routing_data, gather_idx, scatter_idx = topk_output routing_data, gather_idx, scatter_idx = topk_output
return triton_kernel_fused_experts_with_bias( return triton_kernel_fused_experts_with_bias(
@@ -220,8 +223,8 @@ def triton_kernel_moe_with_bias_forward(
routing_data=routing_data, routing_data=routing_data,
gather_indx=gather_idx, gather_indx=gather_idx,
scatter_indx=scatter_idx, scatter_indx=scatter_idx,
inplace=inplace, inplace=False, # triton kernel doesn't support inplace
activation=activation, activation=moe_runner_config.activation,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts, global_num_experts=global_num_experts,
@@ -231,8 +234,8 @@ def triton_kernel_moe_with_bias_forward(
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape, block_shape=block_shape,
activation_alpha=activation_alpha, gemm1_alpha=moe_runner_config.gemm1_alpha,
swiglu_limit=swiglu_limit, gemm1_clamp_limit=moe_runner_config.gemm1_clamp_limit,
) )
@@ -258,10 +261,9 @@ def triton_kernel_fused_experts_with_bias(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None, block_shape: Optional[list[int]] = None,
activation_alpha: Optional[float] = None, gemm1_alpha: Optional[float] = None,
swiglu_limit: Optional[int] = None, gemm1_clamp_limit: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported" assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
assert per_channel_quant == False, "per_channel_quant is not supported" assert per_channel_quant == False, "per_channel_quant is not supported"
assert expert_map == None, "expert_map is not supported" assert expert_map == None, "expert_map is not supported"
@@ -307,7 +309,7 @@ def triton_kernel_fused_experts_with_bias(
act = FusedActivation( act = FusedActivation(
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
(activation_alpha, swiglu_limit), (gemm1_alpha, gemm1_clamp_limit),
2, 2,
) )

View File

@@ -0,0 +1,3 @@
from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
__all__ = ["MoeRunnerConfig"]

View File

@@ -0,0 +1,13 @@
from dataclasses import dataclass
from typing import Optional
@dataclass
class MoeRunnerConfig:
activation: str = "silu"
apply_router_weight_on_input: bool = False
inplace: bool = True
no_combine: bool = False
routed_scaling_factor: Optional[float] = None
gemm1_alpha: Optional[float] = None
gemm1_clamp_limit: Optional[float] = None

View File

@@ -2,20 +2,26 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
BaseDispatcher, BaseDispatcher,
BaseDispatcherConfig, BaseDispatcherConfig,
DispatchOutput, DispatchOutput,
DispatchOutputChecker,
DispatchOutputFormat, DispatchOutputFormat,
) )
from sglang.srt.layers.moe.token_dispatcher.deepep import ( from sglang.srt.layers.moe.token_dispatcher.deepep import (
AscendDeepEPLLOutput,
DeepEPConfig, DeepEPConfig,
DeepEPDispatcher, DeepEPDispatcher,
DeepEPLLOutput, DeepEPLLOutput,
DeepEPNormalOutput, DeepEPNormalOutput,
) )
from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput
__all__ = [ __all__ = [
"AscendDeepEPLLOutput",
"BaseDispatcher", "BaseDispatcher",
"BaseDispatcherConfig", "BaseDispatcherConfig",
"DispatchOutput", "DispatchOutput",
"DispatchOutputFormat", "DispatchOutputFormat",
"DispatchOutputChecker",
"StandardDispatchOutput",
"DeepEPConfig", "DeepEPConfig",
"DeepEPDispatcher", "DeepEPDispatcher",
"DeepEPNormalOutput", "DeepEPNormalOutput",

View File

@@ -2,35 +2,76 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum, auto from enum import Enum, auto
from typing import Protocol, runtime_checkable from typing import TYPE_CHECKING, Protocol, TypeGuard, Union, runtime_checkable
import torch import torch
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
AscendDeepEPLLOutput,
DeepEPLLOutput,
DeepEPNormalOutput,
StandardDispatchOutput,
)
class MoEA2ABackend(Enum):
none = "none"
deepep = "deepep"
def is_none(self): class DispatchOutputChecker:
return self == MoEA2ABackend.none
def is_deepep(self): @staticmethod
return self == MoEA2ABackend.deepep def format_is_standard(
dispatch_output: DispatchOutput,
) -> TypeGuard[StandardDispatchOutput]:
return dispatch_output.format.is_standard()
@staticmethod
def format_is_deepep_normal(
dispatch_output: DispatchOutput,
) -> TypeGuard[DeepEPNormalOutput]:
return dispatch_output.format.is_deepep_normal()
@staticmethod
def format_is_deepep_ll(
dispatch_output: DispatchOutput,
) -> TypeGuard[DeepEPLLOutput]:
return dispatch_output.format.is_deepep_ll()
@staticmethod
def format_is_deepep(
dispatch_output: DispatchOutput,
) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
return dispatch_output.format.is_deepep()
@staticmethod
def format_is_ascent_ll(
dispatch_output: DispatchOutput,
) -> TypeGuard[AscendDeepEPLLOutput]:
return dispatch_output.format.is_ascent_ll()
class DispatchOutputFormat(Enum): class DispatchOutputFormat(Enum):
standard = auto()
deepep_normal = auto() STANDARD = auto()
deepep_ll = auto() DEEPEP_NORMAL = auto()
DEEPEP_LL = auto()
ASCENT_LL = auto()
def is_standard(self) -> bool: def is_standard(self) -> bool:
return self == DispatchOutputFormat.standard return self == DispatchOutputFormat.STANDARD
def is_deepep_normal(self) -> bool: def is_deepep_normal(self) -> bool:
return self == DispatchOutputFormat.deepep_normal return self == DispatchOutputFormat.DEEPEP_NORMAL
def is_deepep_ll(self) -> bool: def is_deepep_ll(self) -> bool:
return self == DispatchOutputFormat.deepep_ll return self == DispatchOutputFormat.DEEPEP_LL
def is_deepep(self) -> bool:
return self in [
DispatchOutputFormat.DEEPEP_NORMAL,
DispatchOutputFormat.DEEPEP_LL,
]
def is_ascent_ll(self) -> bool:
return self == DispatchOutputFormat.ASCENT_LL
@runtime_checkable @runtime_checkable

View File

@@ -2,27 +2,17 @@ from __future__ import annotations
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union
TYPE_CHECKING,
List,
NamedTuple,
Optional,
Protocol,
Tuple,
Union,
runtime_checkable,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.moe import DeepEPMode, get_deepep_config, is_tbo_enabled
from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import ( from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
BaseDispatcher, BaseDispatcher,
BaseDispatcherConfig, BaseDispatcherConfig,
DispatchOutput, DispatchOutput,
DispatchOutputFormat, DispatchOutputFormat,
) )
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import ( from sglang.srt.utils import (
get_bool_env_var, get_bool_env_var,
get_int_env_var, get_int_env_var,
@@ -72,7 +62,7 @@ class DeepEPNormalOutput(NamedTuple):
@property @property
def format(self) -> DispatchOutputFormat: def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_normal return DispatchOutputFormat.DEEPEP_NORMAL
class DeepEPLLOutput(NamedTuple): class DeepEPLLOutput(NamedTuple):
@@ -86,7 +76,7 @@ class DeepEPLLOutput(NamedTuple):
@property @property
def format(self) -> DispatchOutputFormat: def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_ll return DispatchOutputFormat.DEEPEP_LL
class AscendDeepEPLLOutput(NamedTuple): class AscendDeepEPLLOutput(NamedTuple):
@@ -101,7 +91,7 @@ class AscendDeepEPLLOutput(NamedTuple):
@property @property
def format(self) -> DispatchOutputFormat: def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.deepep_ll return DispatchOutputFormat.ASCENT_LL
assert isinstance(DeepEPNormalOutput, DispatchOutput) assert isinstance(DeepEPNormalOutput, DispatchOutput)
@@ -128,8 +118,8 @@ class DeepEPBuffer:
hidden_size: int, hidden_size: int,
param_bytes: int, param_bytes: int,
deepep_mode: DeepEPMode, deepep_mode: DeepEPMode,
num_max_dispatch_tokens_per_rank: int = None, num_max_dispatch_tokens_per_rank: int = -1,
num_experts: int = None, num_experts: int = -1,
): ):
if cls._buffer is not None: if cls._buffer is not None:
return cls._buffer return cls._buffer
@@ -156,8 +146,8 @@ class DeepEPBuffer:
num_rdma_bytes, num_rdma_bytes,
) )
if deepep_mode.enable_low_latency(): if deepep_mode.enable_low_latency():
assert num_max_dispatch_tokens_per_rank is not None assert num_max_dispatch_tokens_per_rank != -1
assert num_experts is not None and num_experts % group.size() == 0 assert num_experts != -1 and num_experts % group.size() == 0
num_rdma_bytes = max( num_rdma_bytes = max(
Buffer.get_low_latency_rdma_size_hint( Buffer.get_low_latency_rdma_size_hint(
num_max_dispatch_tokens_per_rank, num_max_dispatch_tokens_per_rank,
@@ -181,7 +171,7 @@ class DeepEPBuffer:
).multi_processor_count ).multi_processor_count
if ( if (
(deepep_mode != DeepEPMode.LOW_LATENCY) (deepep_mode != DeepEPMode.LOW_LATENCY)
and not global_server_args_dict["enable_two_batch_overlap"] and not is_tbo_enabled()
and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2) and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
): ):
logger.warning( logger.warning(
@@ -226,7 +216,7 @@ class DeepEPConfig(BaseDispatcherConfig):
_instance = None _instance = None
def __init__(self): def __init__(self):
config_str = global_server_args_dict["deepep_config"] config_str = get_deepep_config()
if config_str: if config_str:
config_parsed = load_json_config(config_str) config_parsed = load_json_config(config_str)
if torch.distributed.get_rank() == 0: if torch.distributed.get_rank() == 0:

View File

@@ -13,7 +13,7 @@ class StandardDispatchOutput(NamedTuple):
@property @property
def format(self) -> DispatchOutputFormat: def format(self) -> DispatchOutputFormat:
return DispatchOutputFormat.standard return DispatchOutputFormat.STANDARD
assert isinstance(StandardDispatchOutput, DispatchOutput) assert isinstance(StandardDispatchOutput, DispatchOutput)

View File

@@ -14,9 +14,18 @@
from __future__ import annotations from __future__ import annotations
import logging
import math import math
from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable from typing import (
Callable,
NamedTuple,
Optional,
Protocol,
TypeGuard,
runtime_checkable,
)
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -28,7 +37,10 @@ from sglang.srt.eplb.expert_location_dispatch import (
ExpertLocationDispatchInfo, ExpertLocationDispatchInfo,
topk_ids_logical_to_physical, topk_ids_logical_to_physical,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.layers.moe import (
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.utils import ( from sglang.srt.utils import (
cpu_has_amx_support, cpu_has_amx_support,
get_bool_env_var, get_bool_env_var,
@@ -43,6 +55,7 @@ try:
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
except ImportError: except ImportError:
pass pass
logger = logging.getLogger(__name__)
_is_cuda = is_cuda() _is_cuda = is_cuda()
@@ -65,13 +78,48 @@ if _use_aiter:
if _is_npu: if _is_npu:
import torch_npu import torch_npu
# -------------------------------- TopKConfig ---------------------------------------
@dataclass
class TopKConfig:
top_k: int
use_grouped_topk: bool = False
topk_group: int = 0
num_expert_group: int = 0
renormalize: bool = True
num_fused_shared_experts: int = 0
custom_routing_function: Optional[Callable] = None
correction_bias: Optional[torch.Tensor] = None
torch_native: bool = False
routed_scaling_factor: Optional[float] = None
apply_routed_scaling_factor_on_output: bool = False
# -------------------------------- TopKOutput --------------------------------------- # -------------------------------- TopKOutput ---------------------------------------
class TopKOutputChecker:
@staticmethod
def format_is_standard(topk_output: TopKOutput) -> TypeGuard[StandardTopKOutput]:
return topk_output.format.is_standard()
@staticmethod
def format_is_triton_kernel(
topk_output: TopKOutput,
) -> TypeGuard[TritonKernelTopKOutput]:
return topk_output.format.is_triton_kernel()
@staticmethod
def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
return topk_output.format.is_bypassed()
class TopKOutputFormat(Enum): class TopKOutputFormat(Enum):
STANDARD = auto() STANDARD = auto()
TRITON_KERNEL = auto() TRITON_KERNEL = auto()
BYPASSED = auto()
def is_standard(self) -> bool: def is_standard(self) -> bool:
return self == TopKOutputFormat.STANDARD return self == TopKOutputFormat.STANDARD
@@ -79,6 +127,9 @@ class TopKOutputFormat(Enum):
def is_triton_kernel(self) -> bool: def is_triton_kernel(self) -> bool:
return self == TopKOutputFormat.TRITON_KERNEL return self == TopKOutputFormat.TRITON_KERNEL
def is_bypassed(self) -> bool:
return self == TopKOutputFormat.BYPASSED
@runtime_checkable @runtime_checkable
class TopKOutput(Protocol): class TopKOutput(Protocol):
@@ -114,6 +165,20 @@ class TritonKernelTopKOutput(NamedTuple):
return TopKOutputFormat.TRITON_KERNEL return TopKOutputFormat.TRITON_KERNEL
class BypassedTopKOutput(NamedTuple):
"""Bypassed top-k output format."""
hidden_states: torch.Tensor
router_logits: torch.Tensor
topk_config: TopKConfig
num_token_non_padded: Optional[torch.Tensor] = None
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None
@property
def format(self) -> TopKOutputFormat:
return TopKOutputFormat.BYPASSED
# -------------------------------- TopK --------------------------------------- # -------------------------------- TopK ---------------------------------------
@@ -124,8 +189,8 @@ class TopK(CustomOp):
top_k: int, top_k: int,
*, *,
use_grouped_topk: bool = False, use_grouped_topk: bool = False,
topk_group: Optional[int] = None, topk_group: int = 0,
num_expert_group: Optional[int] = None, num_expert_group: int = 0,
renormalize: bool = True, renormalize: bool = True,
num_fused_shared_experts: int = 0, num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
@@ -136,19 +201,23 @@ class TopK(CustomOp):
# NOTE: scoring_func is not used for now, but we keep it for future use # NOTE: scoring_func is not used for now, but we keep it for future use
# see https://github.com/sgl-project/sglang/pull/4505 for more details # see https://github.com/sgl-project/sglang/pull/4505 for more details
super().__init__() super().__init__()
if use_grouped_topk: if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None assert num_expert_group is not None and topk_group is not None
self.top_k = top_k
self.use_grouped_topk = use_grouped_topk
self.renormalize = renormalize
self.topk_group = topk_group
self.num_expert_group = num_expert_group
self.num_fused_shared_experts = num_fused_shared_experts
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"] self.topk_config = TopKConfig(
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
def forward_native( def forward_native(
self, self,
@@ -158,20 +227,11 @@ class TopK(CustomOp):
num_token_non_padded: Optional[torch.Tensor] = None, num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput: ) -> TopKOutput:
torch_native = True self.topk_config.torch_native = True
return select_experts( return select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
top_k=self.top_k, topk_config=self.topk_config,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded, num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info, expert_location_dispatch_info=expert_location_dispatch_info,
) )
@@ -187,24 +247,28 @@ class TopK(CustomOp):
if self.use_triton_kernels: if self.use_triton_kernels:
# renormalize=True is equivalent to sm_first=False # renormalize=True is equivalent to sm_first=False
routing_data, gather_idx, scatter_idx = routing( routing_data, gather_idx, scatter_idx = routing(
router_logits, self.top_k, sm_first=not self.renormalize router_logits,
self.topk_config.top_k,
sm_first=not self.topk_config.renormalize,
) )
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx) return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
elif (
should_use_flashinfer_trtllm_moe()
or get_moe_runner_backend().is_flashinfer_mxfp4()
):
return BypassedTopKOutput(
hidden_states=hidden_states,
router_logits=router_logits,
topk_config=self.topk_config,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
else: else:
torch_native = False self.topk_config.torch_native = False
return select_experts( return select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
top_k=self.top_k, topk_config=self.topk_config,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded, num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info, expert_location_dispatch_info=expert_location_dispatch_info,
) )
@@ -220,15 +284,7 @@ class TopK(CustomOp):
return select_experts( return select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
top_k=self.top_k, topk_config=self.topk_config,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded, num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info, expert_location_dispatch_info=expert_location_dispatch_info,
) )
@@ -244,35 +300,29 @@ class TopK(CustomOp):
global_num_experts = router_logits.shape[-1] global_num_experts = router_logits.shape[-1]
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256: if global_num_experts == 256 and self.topk_config.renormalize is False:
routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
router_logits = router_logits.to(torch.float32) router_logits = router_logits.to(torch.float32)
return torch_npu.npu_moe_gating_top_k( return torch_npu.npu_moe_gating_top_k(
router_logits, router_logits,
k=self.top_k, k=self.topk_config.top_k,
bias=self.correction_bias.to(torch.float32), bias=self.topk_config.correction_bias.to(torch.float32),
k_group=self.topk_group, k_group=self.topk_config.topk_group,
group_count=self.num_expert_group, group_count=self.topk_config.num_expert_group,
group_select_mode=1, group_select_mode=1,
renorm=0, renorm=0,
norm_type=1, norm_type=1,
routed_scaling_factor=1, routed_scaling_factor=routed_scaling_factor,
eps=float(1e-20), eps=float(1e-20),
) )
else: else:
torch_native = True self.topk_config.torch_native = True
return select_experts( return select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
top_k=self.top_k, topk_config=self.topk_config,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded, num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info, expert_location_dispatch_info=expert_location_dispatch_info,
) )
@@ -670,20 +720,23 @@ else:
def select_experts( def select_experts(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, topk_config: TopKConfig,
*, *,
use_grouped_topk: bool = False,
renormalize: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
torch_native: bool = False,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None, num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput: ) -> StandardTopKOutput:
top_k = topk_config.top_k
use_grouped_topk = topk_config.use_grouped_topk
topk_group = topk_config.topk_group
num_expert_group = topk_config.num_expert_group
renormalize = topk_config.renormalize
num_fused_shared_experts = topk_config.num_fused_shared_experts
custom_routing_function = topk_config.custom_routing_function
correction_bias = topk_config.correction_bias
torch_native = topk_config.torch_native
routed_scaling_factor = topk_config.routed_scaling_factor
router_logits, correction_bias = ( router_logits, correction_bias = (
expert_location_dispatch.transform_select_experts_inputs( expert_location_dispatch.transform_select_experts_inputs(
router_logits=router_logits, router_logits=router_logits,

View File

@@ -1,55 +1,80 @@
from __future__ import annotations
import importlib.util import importlib.util
from enum import Enum from enum import Enum
from functools import lru_cache from functools import lru_cache
from typing import TYPE_CHECKING, Optional
from packaging import version as pkg_version from packaging import version as pkg_version
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import logger
if TYPE_CHECKING:
@lru_cache(maxsize=1) from sglang.srt.server_args import ServerArgs
def should_use_flashinfer_trtllm_moe():
result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
not importlib.util.find_spec("flashinfer")
or pkg_version.parse(__import__("flashinfer").__version__)
>= pkg_version.parse("0.2.9rc1")
)
return result
class MoeA2ABackend(Enum): class MoeA2ABackend(Enum):
STANDARD = ("standard", "none") NONE = "none"
DEEPEP = "deepep" DEEPEP = "deepep"
@classmethod @classmethod
def _missing_(cls, value): def _missing_(cls, value):
if value is None: if value is None:
return cls.STANDARD return cls.NONE
for member in cls: for member in cls:
if value in member.value: if value == member.value:
return member return member
raise ValueError(f"No {cls.__name__} member for value {value}") raise ValueError(f"No {cls.__name__} member for value {value}")
def is_none(self):
return self == MoeA2ABackend.NONE
def is_deepep(self): def is_deepep(self):
return self == MoeA2ABackend.DEEPEP return self == MoeA2ABackend.DEEPEP
def is_standard(self):
return self == MoeA2ABackend.STANDARD class MoeRunnerBackend(Enum):
AUTO = "auto"
TRITON = "triton"
TRITON_KERNEL = "triton_kernel"
FLASHINFER = "flashinfer_trtllm"
FLASHINFER_CUTLASS = "flashinfer_cutlass"
FLASHINFER_MXFP4 = "flashinfer_mxfp4"
def is_auto(self):
return self == MoeRunnerBackend.AUTO
def is_triton(self):
return self == MoeRunnerBackend.TRITON
def is_triton_kernel(self):
return self == MoeRunnerBackend.TRITON_KERNEL
def is_flashinfer_trtllm(self):
return self == MoeRunnerBackend.FLASHINFER
def is_flashinfer_cutlass(self):
return self == MoeRunnerBackend.FLASHINFER_CUTLASS
def is_flashinfer_mxfp4(self):
return self == MoeRunnerBackend.FLASHINFER_MXFP4
class DeepEPMode(Enum): class DeepEPMode(Enum):
NORMAL = "normal" NORMAL = "normal"
LOW_LATENCY = "low_latency" LOW_LATENCY = "low_latency"
AUTO = "auto" AUTO = "auto"
def enable_normal(self): def enable_normal(self) -> bool:
return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO] return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
def enable_low_latency(self): def enable_low_latency(self) -> bool:
return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO] return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
def resolve(self, is_extend_in_batch: bool): def resolve(self, is_extend_in_batch: bool) -> DeepEPMode:
if self != DeepEPMode.AUTO: if self != DeepEPMode.AUTO:
return self return self
@@ -57,3 +82,96 @@ class DeepEPMode(Enum):
return DeepEPMode.NORMAL return DeepEPMode.NORMAL
else: else:
return DeepEPMode.LOW_LATENCY return DeepEPMode.LOW_LATENCY
def is_normal(self) -> bool:
return self == DeepEPMode.NORMAL
def is_low_latency(self) -> bool:
return self == DeepEPMode.LOW_LATENCY
def is_auto(self) -> bool:
return self == DeepEPMode.AUTO
MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
DEEPEP_MODE: Optional[DeepEPMode] = None
IS_TBO_ENABLED: Optional[bool] = None
TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
DEEPEP_CONFIG: Optional[str] = None
def initialize_moe_config(server_args: ServerArgs):
global MOE_A2A_BACKEND
global MOE_RUNNER_BACKEND
global DEEPEP_MODE
global DEEPEP_CONFIG
global IS_TBO_ENABLED
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
DEEPEP_CONFIG = server_args.deepep_config or ""
IS_TBO_ENABLED = server_args.enable_two_batch_overlap
TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
def get_moe_a2a_backend() -> MoeA2ABackend:
global MOE_A2A_BACKEND
if MOE_A2A_BACKEND is None:
logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
MOE_A2A_BACKEND = MoeA2ABackend(None)
return MOE_A2A_BACKEND
def get_moe_runner_backend() -> MoeRunnerBackend:
global MOE_RUNNER_BACKEND
if MOE_RUNNER_BACKEND is None:
logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
return MOE_RUNNER_BACKEND
def get_deepep_mode() -> DeepEPMode:
global DEEPEP_MODE
if DEEPEP_MODE is None:
logger.warning("DEEPEP_MODE is not initialized, using auto mode")
DEEPEP_MODE = DeepEPMode("auto")
return DEEPEP_MODE
def get_deepep_config() -> str:
global DEEPEP_CONFIG
if DEEPEP_CONFIG is None:
logger.warning("DEEPEP_CONFIG is not initialized, using default config")
DEEPEP_CONFIG = ""
return DEEPEP_CONFIG
def is_tbo_enabled() -> bool:
global IS_TBO_ENABLED
if IS_TBO_ENABLED is None:
logger.warning("IS_TBO_ENABLED is not initialized, using False")
IS_TBO_ENABLED = False
return IS_TBO_ENABLED
def get_tbo_token_distribution_threshold() -> float:
global TBO_TOKEN_DISTRIBUTION_THRESHOLD
if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
logger.warning(
"TBO_TOKEN_DISTRIBUTION_THRESHOLD is not initialized, using 0.48"
)
TBO_TOKEN_DISTRIBUTION_THRESHOLD = 0.48
return TBO_TOKEN_DISTRIBUTION_THRESHOLD
@lru_cache(maxsize=1)
def should_use_flashinfer_trtllm_moe():
result = get_moe_runner_backend().is_flashinfer_trtllm() and (
not importlib.util.find_spec("flashinfer")
or pkg_version.parse(__import__("flashinfer").__version__)
>= pkg_version.parse("0.2.9rc1")
)
return result

View File

@@ -33,7 +33,8 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
from sglang.srt.utils import is_cuda, is_hip from sglang.srt.utils import is_cuda, is_hip
@@ -739,13 +740,12 @@ class AWQMoEMethod(FusedMoEMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: StandardTopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
assert (
assert activation == "silu", "Only SiLU activation is supported." moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
# The input must currently be float16 # The input must currently be float16
orig_dtype = x.dtype orig_dtype = x.dtype

View File

@@ -9,6 +9,7 @@ import torch
from torch import nn from torch import nn
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
@@ -100,12 +101,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError

View File

@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch import torch
from torch.nn import Module from torch.nn import Module
@@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -348,12 +349,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -363,15 +359,11 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_output=topk_output, topk_output=topk_output,
inplace=inplace, moe_runner_config=moe_runner_config,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True, use_int8_w8a8=True,
w1_scale=(layer.w13_weight_scale_inv), w1_scale=(layer.w13_weight_scale_inv),
w2_scale=(layer.w2_weight_scale_inv), w2_scale=(layer.w2_weight_scale_inv),
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size, block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
) )

View File

@@ -23,6 +23,7 @@ from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
CompressedTensorsConfig, CompressedTensorsConfig,
@@ -269,12 +270,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton import fused_experts from sglang.srt.layers.moe.fused_moe_triton import fused_experts
@@ -283,8 +279,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_output=topk_output, topk_output=topk_output,
inplace=inplace, moe_runner_config=moe_runner_config,
activation=activation,
use_fp8_w8a8=True, use_fp8_w8a8=True,
per_channel_quant=self.weight_quant.strategy per_channel_quant=self.weight_quant.strategy
== QuantizationStrategy.CHANNEL, == QuantizationStrategy.CHANNEL,
@@ -292,8 +287,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
) )
@@ -601,12 +594,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported." assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
topk_weights, topk_ids, router_logits = topk_output topk_weights, topk_ids, router_logits = topk_output

View File

@@ -41,6 +41,7 @@ from sglang.srt.utils import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -220,22 +221,10 @@ class MxFp4LinearMethod(LinearMethodBase):
return out return out
class MxFp4MoEMethod: class MxFp4MoEMethod(FusedMoEMethodBase):
def __new__(cls, *args, **kwargs):
if not hasattr(cls, "_initialized"): def __init__(self, quant_config: Mxfp4Config):
original_init = cls.__init__ self.quant_config = quant_config
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
@staticmethod @staticmethod
def get_moe_method( def get_moe_method(
@@ -364,12 +353,7 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output
@@ -383,7 +367,9 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod):
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
activation=( activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu ActivationType.Silu
if moe_runner_config.activation == "silu"
else ActivationType.Gelu
), ),
doweight_stage1=False, doweight_stage1=False,
) )
@@ -497,12 +483,7 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output
@@ -516,7 +497,9 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod):
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
activation=( activation=(
ActivationType.Silu if activation == "silu" else ActivationType.Gelu ActivationType.Silu
if moe_runner_config.activation == "silu"
else ActivationType.Gelu
), ),
doweight_stage1=False, doweight_stage1=False,
) )

View File

@@ -79,6 +79,7 @@ from sglang.srt.utils import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
@@ -982,12 +983,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -996,7 +992,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu( x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x moe_runner_config.apply_router_weight_on_input, topk_weights, x
) )
return torch.ops.sgl_kernel.fused_experts_cpu( return torch.ops.sgl_kernel.fused_experts_cpu(
@@ -1021,8 +1017,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer, layer,
x, x,
topk_output, topk_output,
activation, moe_runner_config.activation,
no_combine, moe_runner_config.no_combine,
) )
if ret is not None: if ret is not None:
return ret return ret
@@ -1060,8 +1056,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
use_fp8_blockscale=True, use_fp8_blockscale=True,
) )
# TODO: Fuse into select_experts # TODO: Fuse into select_experts
if routed_scaling_factor is not None: if moe_runner_config.routed_scaling_factor is not None:
output *= routed_scaling_factor output *= moe_runner_config.routed_scaling_factor
return output return output
# Expert fusion with FP8 quantization # Expert fusion with FP8 quantization
return fused_experts( return fused_experts(
@@ -1069,9 +1065,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_output=topk_output, topk_output=topk_output,
inplace=inplace and not no_combine, moe_runner_config=moe_runner_config,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=True, use_fp8_w8a8=True,
w1_scale=( w1_scale=(
layer.w13_weight_scale_inv layer.w13_weight_scale_inv
@@ -1084,26 +1078,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size, block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
) )
def apply_with_router_logits( def apply_with_router_logits(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
router_logits: torch.Tensor, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
activation = moe_runner_config.activation
routed_scaling_factor = moe_runner_config.routed_scaling_factor
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
from sglang.srt.layers.moe.topk import TopKOutputChecker
assert TopKOutputChecker.format_is_bypassed(topk_output)
router_logits = topk_output.router_logits
topk_config = topk_output.topk_config
assert ( assert (
activation == "silu" activation == "silu"
), "Only silu is supported for flashinfer blockscale fp8 moe" ), "Only silu is supported for flashinfer blockscale fp8 moe"
a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1]) a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
# NOTE: scales of hidden states have to be transposed! # NOTE: scales of hidden states have to be transposed!
a_sf_t = a_sf.t().contiguous() a_sf_t = a_sf.t().contiguous()
from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
return trtllm_fp8_block_scale_moe( return trtllm_fp8_block_scale_moe(
routing_logits=router_logits.to(torch.float32), routing_logits=router_logits.to(torch.float32),
@@ -1115,9 +1115,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
gemm2_weights=layer.w2_weight, gemm2_weights=layer.w2_weight,
gemm2_weights_scale=layer.w2_weight_scale_inv, gemm2_weights_scale=layer.w2_weight_scale_inv,
num_experts=layer.num_experts, num_experts=layer.num_experts,
top_k=layer.top_k, top_k=topk_config.top_k,
n_group=layer.num_expert_group, n_group=topk_config.num_expert_group,
topk_group=layer.topk_group, topk_group=topk_config.topk_group,
intermediate_size=layer.w2_weight.shape[2], intermediate_size=layer.w2_weight.shape[2],
local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
local_num_experts=layer.num_local_experts, local_num_experts=layer.num_local_experts,

View File

@@ -113,6 +113,7 @@ def normalize_e4m3fn_to_e4m3fnuz(
return weight, weight_scale, input_scale return weight, weight_scale, input_scale
# TODO(ch-wan): define these backends in --moe-runner-backend
def cutlass_block_fp8_supported() -> bool: def cutlass_block_fp8_supported() -> bool:
if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"): if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"):
return False return False

View File

@@ -44,6 +44,7 @@ from sglang.srt.layers.quantization.utils import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.utils import is_cuda from sglang.srt.utils import is_cuda
@@ -1056,13 +1057,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
# Delay the import to avoid circular dependency # Delay the import to avoid circular dependency
assert activation == "silu", "Only SiLU activation is supported." assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
# The input must currently be float16 # The input must currently be float16
orig_dtype = x.dtype orig_dtype = x.dtype

View File

@@ -28,6 +28,7 @@ from sglang.srt.utils import get_device_capability, is_cuda
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
try: try:
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
@@ -216,13 +217,13 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
)[0] )[0]
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: def check_moe_marlin_supports_layer(layer: FusedMoE, group_size: int) -> bool:
hidden_size = layer.hidden_size hidden_size = layer.hidden_size
intermediate_size_per_partition = layer.intermediate_size_per_partition intermediate_size_per_partition = layer.intermediate_size_per_partition
# apply_router_weight_on_input is not supported for moe marlin # apply_router_weight_on_input is not supported for moe marlin
supports_router_weight = not layer.apply_router_weight_on_input supports_router_weight = not layer.moe_runner_config.apply_router_weight_on_input
# moe marlin requires the activation to be silu # moe marlin requires the activation to be silu
supports_activation = layer.activation == "silu" supports_activation = layer.moe_runner_config.activation == "silu"
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size) # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
# down: (n, k) = (hidden_size, intermediate_size_per_partition) # down: (n, k) = (hidden_size, intermediate_size_per_partition)

View File

@@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase, FusedMoEMethodBase,
@@ -30,10 +30,11 @@ from sglang.srt.layers.quantization.utils import (
requantize_with_max_scale, requantize_with_max_scale,
) )
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import is_cuda, next_power_of_2 from sglang.srt.utils import is_cuda, next_power_of_2
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
if is_cuda(): if is_cuda():
@@ -422,12 +423,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -436,15 +432,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_output=topk_output, topk_output=topk_output,
inplace=inplace, moe_runner_config=moe_runner_config,
activation=activation,
use_fp8_w8a8=True, use_fp8_w8a8=True,
per_channel_quant=False, # ModelOpt uses per-tensor quantization per_channel_quant=False, # ModelOpt uses per-tensor quantization
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale, w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
no_combine=no_combine,
) )
@@ -741,8 +735,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
@property @property
def enable_flashinfer_cutlass_moe(self) -> bool: def enable_flashinfer_cutlass_moe(self) -> bool:
from sglang.srt.layers.moe import get_moe_runner_backend
"""Access the global enable_flashinfer_cutlass_moe setting.""" """Access the global enable_flashinfer_cutlass_moe setting."""
return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False) return get_moe_runner_backend().is_flashinfer_cutlass()
def create_weights( def create_weights(
self, self,
@@ -1160,21 +1156,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: FusedMoE,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
ep_rank: Optional[int] = None,
ep_size: Optional[int] = None,
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported." assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
# Check if this is a FlashInferFP4MoE layer that should handle its own forward # Check if this is a FlashInferFP4MoE layer that should handle its own forward
if hasattr(layer, "gemm1_weights_fp4_shuffled"): if hasattr(layer, "gemm1_weights_fp4_shuffled"):
@@ -1183,7 +1172,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
if self.enable_flashinfer_cutlass_moe: if self.enable_flashinfer_cutlass_moe:
assert ( assert (
not apply_router_weight_on_input not moe_runner_config.apply_router_weight_on_input
), "apply_router_weight_on_input is not supported for Flashinfer" ), "apply_router_weight_on_input is not supported for Flashinfer"
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
# and fp4 quantized weights loaded from the checkpoint # and fp4 quantized weights loaded from the checkpoint
@@ -1205,14 +1194,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
layer.w2_blockscale_swizzled.view(torch.int32), layer.w2_blockscale_swizzled.view(torch.int32),
layer.g2_alphas, layer.g2_alphas,
], ],
ep_size=ep_size, ep_size=layer.moe_ep_size,
ep_rank=ep_rank, ep_rank=layer.moe_ep_rank,
tp_size=tp_size, tp_size=layer.moe_tp_size,
tp_rank=tp_rank, tp_rank=layer.moe_tp_rank,
tune_max_num_tokens=next_power_of_2(x.shape[0]), tune_max_num_tokens=next_power_of_2(x.shape[0]),
)[0] )[0]
if routed_scaling_factor is not None: if moe_runner_config.routed_scaling_factor is not None:
output *= routed_scaling_factor output *= moe_runner_config.routed_scaling_factor
return output return output
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
@@ -1231,8 +1220,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
params=layer.cutlass_moe_params, params=layer.cutlass_moe_params,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
).to(x.dtype) ).to(x.dtype)
if routed_scaling_factor is not None: if moe_runner_config.routed_scaling_factor is not None:
output *= routed_scaling_factor output *= moe_runner_config.routed_scaling_factor
return output return output

View File

@@ -22,6 +22,7 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
@@ -353,17 +354,14 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# avoid circular import # avoid circular import
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported." assert (
moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
weight_bits = self.quant_config.weight_bits weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp has_zp = self.quant_config.has_zp
@@ -373,8 +371,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
layer.w13_qweight, layer.w13_qweight,
layer.w2_qweight, layer.w2_qweight,
topk_output=topk_output, topk_output=topk_output,
inplace=inplace, moe_runner_config=moe_runner_config,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int4_w4a16=weight_bits == 4, use_int4_w4a16=weight_bits == 4,
use_int8_w8a16=weight_bits == 8, use_int8_w8a16=weight_bits == 8,
w1_scale=layer.w13_scales, w1_scale=layer.w13_scales,
@@ -382,8 +379,6 @@ class MoeWNA16Method(FusedMoEMethodBase):
w1_zp=layer.w13_qzeros if has_zp else None, w1_zp=layer.w13_qzeros if has_zp else None,
w2_zp=layer.w2_qzeros if has_zp else None, w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size], block_shape=[0, layer.group_size],
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
) )
@staticmethod @staticmethod
@@ -486,16 +481,16 @@ class MoeWNA16Method(FusedMoEMethodBase):
) )
if "w13_qzeros" in weight_name: if "w13_qzeros" in weight_name:
tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[ tensor = loaded_weight.view(
tp_rank layer.moe_tp_size, -1, loaded_weight.size(1)
] )[tp_rank]
if shard_id == "w1": if shard_id == "w1":
param.data[expert_id, : shard_size // 2] = tensor param.data[expert_id, : shard_size // 2] = tensor
else: else:
param.data[expert_id, shard_size // 2 :] = tensor param.data[expert_id, shard_size // 2 :] = tensor
elif "w2_qzeros" in weight_name: elif "w2_qzeros" in weight_name:
param.data[expert_id] = loaded_weight.view( param.data[expert_id] = loaded_weight.view(
loaded_weight.size(0), layer.tp_size, -1 loaded_weight.size(0), layer.moe_tp_size, -1
)[:, tp_rank] )[:, tp_rank]
else: else:
weight_loader(param, loaded_weight, weight_name, shard_id, expert_id) weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)

View File

@@ -16,14 +16,13 @@
from __future__ import annotations from __future__ import annotations
import importlib.util
import logging import logging
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional
import torch import torch
import triton.language as tl
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from sglang.srt.layers.moe.utils import get_moe_runner_backend
from sglang.srt.layers.quantization.base_config import ( from sglang.srt.layers.quantization.base_config import (
FusedMoEMethodBase, FusedMoEMethodBase,
QuantizationConfig, QuantizationConfig,
@@ -31,7 +30,6 @@ from sglang.srt.layers.quantization.base_config import (
) )
from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.utils import ( from sglang.srt.utils import (
direct_register_custom_op, direct_register_custom_op,
get_bool_env_var, get_bool_env_var,
@@ -60,6 +58,7 @@ if is_flashinfer_available():
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
OCP_MX_BLOCK_SIZE = 32 OCP_MX_BLOCK_SIZE = 32
@@ -218,15 +217,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
self, self,
prefix: str, prefix: str,
): ):
from sglang.srt.managers.schedule_batch import global_server_args_dict
super().__init__() super().__init__()
self.prefix = prefix self.prefix = prefix
self.topk_indices_dtype = None self.topk_indices_dtype = None
self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"] self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
self.with_bias = False self.with_bias = False
self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"] self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
self.triton_kernel_moe_forward = None self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None self.triton_kernel_moe_with_bias_forward = None
@@ -348,6 +345,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
logger, logger,
f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...", f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...",
) )
# TODO: these values are hardcoded for now, we need to get them from the model
layer.gemm1_alpha = Parameter( layer.gemm1_alpha = Parameter(
torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(), torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
requires_grad=False, requires_grad=False,
@@ -573,14 +571,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if self.use_flashinfer: if self.use_flashinfer:
# Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
@@ -637,9 +628,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
b1=layer.w13_weight_bias, b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias, b2=layer.w2_weight_bias,
topk_output=topk_output, topk_output=topk_output,
activation=activation, moe_runner_config=moe_runner_config,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
) )
else: else:
return self.triton_kernel_moe_forward( return self.triton_kernel_moe_forward(
@@ -647,6 +636,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
topk_output=topk_output, topk_output=topk_output,
moe_runner_config=moe_runner_config,
) )
else: else:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -656,13 +646,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
topk_output=topk_output, topk_output=topk_output,
moe_runner_config=moe_runner_config,
b1=layer.w13_weight_bias, b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias, b2=layer.w2_weight_bias,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
) )

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import importlib import importlib.util
from typing import TYPE_CHECKING, Callable, List, Optional from typing import TYPE_CHECKING, List, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -24,7 +24,7 @@ from sglang.srt.utils import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
@@ -221,31 +221,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
kwargs = {}
if activation_alpha is not None:
kwargs["activation_alpha"] = activation_alpha
if swiglu_limit is not None:
kwargs["swiglu_limit"] = swiglu_limit
return self.forward( return self.forward(
x=x, x=x,
layer=layer, layer=layer,
topk_output=topk_output, topk_output=topk_output,
activation=activation, moe_runner_config=moe_runner_config,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
**kwargs,
) )
def forward_cuda( def forward_cuda(
@@ -253,18 +236,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if self.use_triton_kernels: if self.use_triton_kernels:
if self.with_bias: if self.with_bias:
assert self.triton_kernel_moe_with_bias_forward is not None
return self.triton_kernel_moe_with_bias_forward( return self.triton_kernel_moe_with_bias_forward(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
@@ -272,24 +249,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
b1=layer.w13_weight_bias, b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias, b2=layer.w2_weight_bias,
topk_output=topk_output, topk_output=topk_output,
activation=activation, moe_runner_config=moe_runner_config,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
w1_pcg=None, w1_pcg=None,
w2_pcg=None, w2_pcg=None,
) )
else: else:
assert self.triton_kernel_moe_forward is not None
return self.triton_kernel_moe_forward( return self.triton_kernel_moe_forward(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
topk_output=topk_output, topk_output=topk_output,
moe_runner_config=moe_runner_config,
) )
else: else:
if _use_aiter: if _use_aiter:
assert not no_combine, "unsupported" assert not moe_runner_config.no_combine, "unsupported"
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output
if apply_router_weight_on_input: if moe_runner_config.apply_router_weight_on_input:
assert ( assert (
topk_weights.dim() == 2 topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)" ), "`topk_weights` should be in shape (num_tokens, topk)"
@@ -309,7 +286,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_ids, topk_ids,
activation=( activation=(
ActivationType.Silu ActivationType.Silu
if activation == "silu" if moe_runner_config.activation == "silu"
else ActivationType.Gelu else ActivationType.Gelu
), ),
) )
@@ -325,13 +302,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
b1=getattr(layer, "w13_weight_bias", None), b1=getattr(layer, "w13_weight_bias", None),
b2=getattr(layer, "w2_weight_bias", None), b2=getattr(layer, "w2_weight_bias", None),
topk_output=topk_output, topk_output=topk_output,
inplace=inplace and not no_combine, moe_runner_config=moe_runner_config,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
) )
def forward_cpu( def forward_cpu(
@@ -339,21 +310,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert activation == "silu", f"activation = {activation} is not supported." assert (
moe_runner_config.activation == "silu"
), f"activation = {moe_runner_config.activation} is not supported."
if use_intel_amx_backend(layer) and not apply_router_weight_on_input: if (
use_intel_amx_backend(layer)
and not moe_runner_config.apply_router_weight_on_input
):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu( x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x moe_runner_config.apply_router_weight_on_input, topk_weights, x
) )
return torch.ops.sgl_kernel.fused_experts_cpu( return torch.ops.sgl_kernel.fused_experts_cpu(
x, x,
@@ -378,11 +349,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer, layer,
x, x,
topk_output, topk_output,
activation=activation, moe_runner_config,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
) )
def forward_npu( def forward_npu(
@@ -390,12 +357,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
@@ -403,11 +365,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer, layer,
x, x,
topk_output, topk_output,
activation=activation, moe_runner_config,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
) )
def forward_tpu(self, *args, **kwargs) -> torch.Tensor: def forward_tpu(self, *args, **kwargs) -> torch.Tensor:

View File

@@ -18,7 +18,9 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.topk import StandardTopKOutput
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -280,11 +282,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
self, self,
layer: EPMoE, layer: EPMoE,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: StandardTopKOutput,
activation: str = "silu", moe_runner_config: MoeRunnerConfig,
apply_router_weight_on_input: bool = False,
routed_scaling_factor: Optional[float] = None,
**kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
# TODO(ch-wan): move it out of this class # TODO(ch-wan): move it out of this class
@@ -324,6 +323,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
layer.w13_input_scale, layer.w13_input_scale,
layer.w2_input_scale, layer.w2_input_scale,
) )
if routed_scaling_factor is not None: if moe_runner_config.routed_scaling_factor is not None:
output *= routed_scaling_factor output *= moe_runner_config.routed_scaling_factor
return output return output

View File

@@ -26,7 +26,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import set_weight_attrs
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import StandardTopKOutput
_is_fp8_fnuz = is_fp8_fnuz() _is_fp8_fnuz = is_fp8_fnuz()
@@ -269,13 +270,8 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: StandardTopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -284,15 +280,11 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_output=topk_output, topk_output=topk_output,
inplace=inplace, moe_runner_config=moe_runner_config,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
use_fp8_w8a8=True, use_fp8_w8a8=True,
per_channel_quant=True, per_channel_quant=True,
w1_scale=(layer.w13_weight_scale), w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale), w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
) )

View File

@@ -49,6 +49,7 @@ from sglang.srt.utils import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.moe.topk import TopKOutput
_is_cuda = is_cuda() _is_cuda = is_cuda()
@@ -487,12 +488,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
topk_output: TopKOutput, topk_output: TopKOutput,
*, moe_runner_config: MoeRunnerConfig,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor: ) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -501,7 +497,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu( x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x moe_runner_config.apply_router_weight_on_input, topk_weights, x
) )
return torch.ops.sgl_kernel.fused_experts_cpu( return torch.ops.sgl_kernel.fused_experts_cpu(
x, x,
@@ -525,17 +521,13 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
layer.w13_weight, layer.w13_weight,
layer.w2_weight, layer.w2_weight,
topk_output=topk_output, topk_output=topk_output,
inplace=inplace, moe_runner_config=moe_runner_config,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True, use_int8_w8a8=True,
per_channel_quant=True, per_channel_quant=True,
w1_scale=(layer.w13_weight_scale), w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale), w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
) )
@@ -982,7 +974,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
layer, layer,
x, x,
topk_output: TopKOutput, topk_output: TopKOutput,
**kwargs, moe_runner_config: MoeRunnerConfig,
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids, _ = topk_output topk_weights, topk_ids, _ = topk_output

View File

@@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
ScheduleBatchDisaggregationDecodeMixin, ScheduleBatchDisaggregationDecodeMixin,
) )
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.moe import is_tbo_enabled
from sglang.srt.mem_cache.allocator import ( from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator, BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator, SWATokenToKVPoolAllocator,
@@ -84,17 +85,10 @@ GLOBAL_SERVER_ARGS_KEYS = [
"device", "device",
"disable_chunked_prefix_cache", "disable_chunked_prefix_cache",
"disable_radix_cache", "disable_radix_cache",
"enable_two_batch_overlap",
"tbo_token_distribution_threshold",
"enable_dp_lm_head", "enable_dp_lm_head",
"moe_a2a_backend",
"deepep_mode",
"enable_flashinfer_cutlass_moe",
"enable_flashinfer_trtllm_moe",
"enable_flashinfer_allreduce_fusion", "enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size", "moe_dense_tp_size",
"ep_dispatch_algorithm", "ep_dispatch_algorithm",
"deepep_config",
"ep_num_redundant_experts", "ep_num_redundant_experts",
"enable_nan_detection", "enable_nan_detection",
"flashinfer_mla_disable_ragged", "flashinfer_mla_disable_ragged",
@@ -107,8 +101,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
"triton_attention_reduce_in_fp32", "triton_attention_reduce_in_fp32",
"num_reserved_decode_tokens", "num_reserved_decode_tokens",
"weight_loader_disable_mmap", "weight_loader_disable_mmap",
"enable_triton_kernel_moe",
"enable_flashinfer_mxfp4_moe",
"enable_multimodal", "enable_multimodal",
"enable_symm_mem", "enable_symm_mem",
"quantization", "quantization",

View File

@@ -64,7 +64,7 @@ from sglang.srt.hf_transformers_utils import (
) )
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend from sglang.srt.layers.moe import initialize_moe_config
from sglang.srt.managers.io_struct import ( from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
CloseSessionReqInput, CloseSessionReqInput,
@@ -245,6 +245,9 @@ class Scheduler(
) )
) )
# Init model config
self.model_config = ModelConfig.from_server_args(server_args)
# Init inter-process communication # Init inter-process communication
context = zmq.Context(2) context = zmq.Context(2)
self.idle_sleeper = None self.idle_sleeper = None
@@ -292,6 +295,9 @@ class Scheduler(
# Init tokenizer # Init tokenizer
self.init_tokenizer() self.init_tokenizer()
# Init moe config
self.init_moe_config()
# Set reasoning_parser and think_end_id if --reasoning_parser is enabled # Set reasoning_parser and think_end_id if --reasoning_parser is enabled
if self.server_args.reasoning_parser and self.tokenizer: if self.server_args.reasoning_parser and self.tokenizer:
reasoning_parser = ReasoningParser( reasoning_parser = ReasoningParser(
@@ -538,8 +544,6 @@ class Scheduler(
def init_tokenizer(self): def init_tokenizer(self):
server_args = self.server_args server_args = self.server_args
self.model_config = ModelConfig.from_server_args(server_args)
self.is_generation = self.model_config.is_generation self.is_generation = self.model_config.is_generation
if server_args.skip_tokenizer_init: if server_args.skip_tokenizer_init:
@@ -761,6 +765,10 @@ class Scheduler(
# The prefill requests that are in the middle of kv sending # The prefill requests that are in the middle of kv sending
self.disagg_prefill_inflight_queue: List[Req] = [] self.disagg_prefill_inflight_queue: List[Req] = []
def init_moe_config(self):
if hasattr(self.model_config.hf_config, "num_experts_per_tok"):
initialize_moe_config(self.server_args)
@DynamicGradMode() @DynamicGradMode()
def event_loop_normal(self): def event_loop_normal(self):
"""A normal scheduler loop.""" """A normal scheduler loop."""
@@ -1823,11 +1831,6 @@ class Scheduler(
disable_cuda_graph=self.server_args.disable_cuda_graph, disable_cuda_graph=self.server_args.disable_cuda_graph,
spec_algorithm=self.spec_algorithm, spec_algorithm=self.spec_algorithm,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens, speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
enable_deepep_moe=MoeA2ABackend(
self.server_args.moe_a2a_backend
).is_deepep(),
deepep_mode=DeepEPMode(self.server_args.deepep_mode),
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args), require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
disable_overlap_schedule=self.server_args.disable_overlap_schedule, disable_overlap_schedule=self.server_args.disable_overlap_schedule,
) )
@@ -1922,9 +1925,6 @@ class Scheduler(
disable_cuda_graph: bool, disable_cuda_graph: bool,
spec_algorithm, spec_algorithm,
speculative_num_draft_tokens, speculative_num_draft_tokens,
enable_two_batch_overlap: bool,
enable_deepep_moe: bool,
deepep_mode: DeepEPMode,
require_mlp_tp_gather: bool, require_mlp_tp_gather: bool,
disable_overlap_schedule: bool, disable_overlap_schedule: bool,
): ):
@@ -1972,9 +1972,6 @@ class Scheduler(
is_extend_in_batch, is_extend_in_batch,
*tbo_preparer.prepare_all_gather( *tbo_preparer.prepare_all_gather(
local_batch, local_batch,
deepep_mode,
enable_deepep_moe,
enable_two_batch_overlap,
), ),
], ],
dtype=torch.int64, dtype=torch.int64,

View File

@@ -60,7 +60,6 @@ from sglang.srt.layers.dp_attention import (
initialize_dp_attention, initialize_dp_attention,
) )
from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
from sglang.srt.layers.quantization import ( from sglang.srt.layers.quantization import (
deep_gemm_wrapper, deep_gemm_wrapper,
monkey_patch_isinstance_for_vllm_base_layer, monkey_patch_isinstance_for_vllm_base_layer,
@@ -219,8 +218,6 @@ class ModelRunner:
# TODO it is indeed not a "server args" # TODO it is indeed not a "server args"
"use_mla_backend": self.use_mla_backend, "use_mla_backend": self.use_mla_backend,
"speculative_algorithm": self.spec_algorithm, "speculative_algorithm": self.spec_algorithm,
"moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
"deepep_mode": DeepEPMode(server_args.deepep_mode),
} }
) )

View File

@@ -32,7 +32,9 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.rotary_embedding import get_rope
@@ -104,6 +106,11 @@ class DbrxExperts(nn.Module):
self.params_dtype = params_dtype self.params_dtype = params_dtype
self.router = DbrxRouter(config, self.params_dtype) self.router = DbrxRouter(config, self.params_dtype)
self.topk = TopK(
self.top_k,
renormalize=True,
)
self.moe_runner_config = MoeRunnerConfig(inplace=True)
self.ws = nn.Parameter( self.ws = nn.Parameter(
torch.empty( torch.empty(
self.num_total_experts, self.num_total_experts,
@@ -169,14 +176,13 @@ class DbrxExperts(nn.Module):
hidden_states = hidden_states.view(-1, self.d_model) hidden_states = hidden_states.view(-1, self.d_model)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits = self.router(hidden_states) router_logits = self.router(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = fused_moe( final_hidden_states = fused_moe(
hidden_states, hidden_states,
self.ws, self.ws,
self.w2s, self.w2s,
router_logits, topk_output,
self.top_k, self.moe_runner_config,
renormalize=True,
inplace=True,
) )
if self.tp_size > 1: if self.tp_size > 1:
@@ -293,7 +299,7 @@ class DbrxFusedNormAttention(nn.Module):
position_ids: torch.Tensor, position_ids: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
forward_batch: ForwardBatch, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> Tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states residual = hidden_states
hidden_states = self.norm_1(hidden_states) hidden_states = self.norm_1(hidden_states)
x = self.attn( x = self.attn(

View File

@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
@@ -180,7 +181,7 @@ class DeepseekMoE(nn.Module):
w1=self.w1, w1=self.w1,
w2=self.w2, w2=self.w2,
topk_output=topk_output, topk_output=topk_output,
inplace=True, moe_runner_config=MoeRunnerConfig(inplace=True),
) )
if self.config.n_shared_experts is not None: if self.config.n_shared_experts is not None:

View File

@@ -50,7 +50,6 @@ from sglang.srt.layers.communicator import (
from sglang.srt.layers.dp_attention import ( from sglang.srt.layers.dp_attention import (
get_attention_tp_rank, get_attention_tp_rank,
get_attention_tp_size, get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled, is_dp_attention_enabled,
) )
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
@@ -61,9 +60,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
@@ -336,30 +336,6 @@ class DeepseekV2MoE(nn.Module):
quant_config=quant_config, quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix), prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
**(
dict(
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
)
if should_use_flashinfer_trtllm_moe()
else {}
),
) )
self.shared_experts_is_int8 = False self.shared_experts_is_int8 = False
@@ -377,7 +353,7 @@ class DeepseekV2MoE(nn.Module):
prefix=add_prefix("shared_experts", prefix), prefix=add_prefix("shared_experts", prefix),
**( **(
dict(tp_rank=0, tp_size=1) dict(tp_rank=0, tp_size=1)
if global_server_args_dict["moe_a2a_backend"].is_deepep() if get_moe_a2a_backend().is_deepep()
else {} else {}
), ),
) )
@@ -407,7 +383,7 @@ class DeepseekV2MoE(nn.Module):
self.top_k = config.num_experts_per_tok self.top_k = config.num_experts_per_tok
if global_server_args_dict["moe_a2a_backend"].is_deepep(): if get_moe_a2a_backend().is_deepep():
# TODO: we will support tp < ep in the future # TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size() self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = ( self.num_experts = (
@@ -431,12 +407,12 @@ class DeepseekV2MoE(nn.Module):
num_local_experts=config.n_routed_experts // self.tp_size, num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
params_dtype=config.torch_dtype, params_dtype=config.torch_dtype,
deepep_mode=global_server_args_dict["deepep_mode"], deepep_mode=get_deepep_mode(),
async_finish=True, async_finish=True,
return_recv_hook=True, return_recv_hook=True,
) )
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep() self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
def get_moe_weights(self): def get_moe_weights(self):
return [ return [
@@ -484,12 +460,6 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states) router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states} kwargs = {"hidden_states": hidden_states}
# FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
# Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
if should_use_flashinfer_trtllm_moe():
kwargs["topk_output"] = (self.topk, router_logits)
else:
kwargs["topk_output"] = self.topk(hidden_states, router_logits) kwargs["topk_output"] = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(**kwargs) final_hidden_states = self.experts(**kwargs)
@@ -520,12 +490,6 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states) router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states} kwargs = {"hidden_states": hidden_states}
# FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple
# Regular FusedMoE (CUTLASS path) expects StandardTopKOutput
if should_use_flashinfer_trtllm_moe():
kwargs["topk_output"] = (self.topk, router_logits)
else:
kwargs["topk_output"] = self.topk(hidden_states, router_logits) kwargs["topk_output"] = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(**kwargs) final_hidden_states = self.experts(**kwargs)
@@ -2478,18 +2442,16 @@ class DeepseekV2ForCausalLM(nn.Module):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
) )
if self.quant_config and self.quant_config.get_name() == "w4afp8": if self.quant_config and self.quant_config.get_name() == "w4afp8":
expert_params_mapping += ( expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
get_moe_impl_class().make_expert_input_scale_params_mapping(
num_experts=self.config.n_routed_experts num_experts=self.config.n_routed_experts
) )
)
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and ( fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (

View File

@@ -31,13 +31,13 @@ from sglang.srt.layers.communicator import enable_moe_dense_fully_dp
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP
@@ -361,7 +361,7 @@ class Ernie4_5_ForCausalLM(nn.Module):
class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM): class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",

View File

@@ -39,7 +39,6 @@ from sglang.srt.layers.communicator import (
from sglang.srt.layers.dp_attention import ( from sglang.srt.layers.dp_attention import (
get_attention_tp_rank, get_attention_tp_rank,
get_attention_tp_size, get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled, is_dp_attention_enabled,
) )
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
@@ -51,9 +50,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
is_fp8_fnuz, is_fp8_fnuz,
@@ -76,10 +76,7 @@ from sglang.srt.models.deepseek_v2 import (
DeepseekV2Model, DeepseekV2Model,
DeepseekV2MoE, DeepseekV2MoE,
) )
from sglang.srt.two_batch_overlap import ( from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
MaybeTboDeepEPDispatcher,
model_forward_maybe_tbo,
)
from sglang.srt.utils import ( from sglang.srt.utils import (
BumpAllocator, BumpAllocator,
LazyValue, LazyValue,
@@ -414,8 +411,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
) )
self.topk = ( self.topk = TopK(
TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts, top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob, renormalize=config.norm_topk_prob,
use_grouped_topk=True, use_grouped_topk=True,
@@ -425,9 +421,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
correction_bias=self.gate.e_score_correction_bias, correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
) )
if not should_use_flashinfer_trtllm_moe()
else None
)
self.experts = get_moe_impl_class()( self.experts = get_moe_impl_class()(
num_experts=config.n_routed_experts num_experts=config.n_routed_experts
@@ -441,31 +434,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
quant_config=quant_config, quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix), prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
**(
dict(
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
)
if should_use_flashinfer_trtllm_moe()
else {}
),
) )
self.shared_experts_is_int8 = False self.shared_experts_is_int8 = False
@@ -496,7 +464,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self.top_k = config.num_experts_per_tok self.top_k = config.num_experts_per_tok
if global_server_args_dict["moe_a2a_backend"].is_deepep(): if get_moe_a2a_backend().is_deepep():
# TODO: we will support tp < ep in the future # TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size() self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = ( self.num_experts = (
@@ -520,12 +488,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
num_local_experts=config.n_routed_experts // self.tp_size, num_local_experts=config.n_routed_experts // self.tp_size,
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
params_dtype=config.torch_dtype, params_dtype=config.torch_dtype,
deepep_mode=global_server_args_dict["deepep_mode"], deepep_mode=get_deepep_mode(),
async_finish=True, async_finish=True,
return_recv_hook=True, return_recv_hook=True,
) )
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep() self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
def forward_normal_dual_stream( def forward_normal_dual_stream(
self, self,
@@ -542,10 +510,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states) router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states} kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits) kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["router_logits"] = router_logits
final_hidden_states = self.experts(**kwargs) final_hidden_states = self.experts(**kwargs)
if not _is_cuda: if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor final_hidden_states *= self.routed_scaling_factor
@@ -588,10 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states) router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states} kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits) kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["router_logits"] = router_logits
final_hidden_states = self.experts(**kwargs) final_hidden_states = self.experts(**kwargs)
if not _is_cuda and not _use_aiter: if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here # fused in biased_grouped_topk so we can skip here
@@ -761,8 +723,6 @@ class Glm4MoeModel(DeepseekV2Model):
) )
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.dp_size = get_local_attention_dp_size()
class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
@@ -789,7 +749,6 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.dp_size = get_local_attention_dp_size()
self._routed_experts_weights_of_layer = LazyValue( self._routed_experts_weights_of_layer = LazyValue(
lambda: { lambda: {
@@ -953,7 +912,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",

View File

@@ -8,19 +8,11 @@ from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig
from sglang.srt.distributed import ( from sglang.srt.distributed import (
get_moe_expert_parallel_world_size, get_moe_expert_parallel_world_size,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
) )
from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
@@ -49,7 +41,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
config.moe_layer_freq = 1 config.moe_layer_freq = 1
self.config = config self.config = config
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.dp_size = get_local_attention_dp_size()
self.quant_config = quant_config self.quant_config = quant_config
self.determine_num_fused_shared_experts("Glm4MoeForCausalLM") self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
self.num_fused_shared_experts = ( self.num_fused_shared_experts = (
@@ -232,7 +223,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
# Params for weights, fp8 weight scales, fp8 activation scales # Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id) # (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",

View File

@@ -40,7 +40,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import ( from sglang.srt.layers.dp_attention import (
get_attention_tp_rank, get_attention_tp_rank,
get_attention_tp_size, get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled, is_dp_attention_enabled,
) )
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
@@ -50,9 +49,10 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
@@ -110,12 +110,9 @@ class GptOssSparseMoeBlock(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.layer_id = layer_id self.layer_id = layer_id
self.activation = config.hidden_act self.activation = config.hidden_act
self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702) self.gemm1_alpha = getattr(config, "hidden_act_alpha", 1.702)
self.swiglu_limit = config.swiglu_limit self.gemm1_clamp_limit = config.swiglu_limit
if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
self.topk = None
else:
self.topk = TopK( self.topk = TopK(
top_k=config.num_experts_per_tok, top_k=config.num_experts_per_tok,
renormalize=True, renormalize=True,
@@ -129,11 +126,9 @@ class GptOssSparseMoeBlock(nn.Module):
quant_config.get_name() if quant_config is not None else None quant_config.get_name() if quant_config is not None else None
) )
extra_kwargs = { extra_kwargs = {
"enable_flashinfer_cutlass_moe": global_server_args_dict[
"enable_flashinfer_cutlass_moe"
],
# for moe gate_up_proj and down_proj and their bias loading # for moe gate_up_proj and down_proj and their bias loading
"use_weight_loader_fused": quant_config_name != "mxfp4", "use_weight_loader_fused": quant_config_name
!= "mxfp4"
} }
self.experts = experts_type( self.experts = experts_type(
num_experts=config.num_local_experts num_experts=config.num_local_experts
@@ -144,15 +139,10 @@ class GptOssSparseMoeBlock(nn.Module):
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
quant_config=quant_config, quant_config=quant_config,
activation=self.activation, activation=self.activation,
activation_alpha=self.activation_alpha, gemm1_alpha=self.gemm1_alpha,
swiglu_limit=self.swiglu_limit, gemm1_clamp_limit=self.gemm1_clamp_limit,
with_bias=True, with_bias=True,
prefix=add_prefix("experts", prefix), prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
**extra_kwargs, **extra_kwargs,
) )
@@ -171,7 +161,7 @@ class GptOssSparseMoeBlock(nn.Module):
forward_batch: Optional[ForwardBatch] = None, forward_batch: Optional[ForwardBatch] = None,
should_allreduce_fusion: bool = False, should_allreduce_fusion: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if not global_server_args_dict["moe_a2a_backend"].is_deepep(): if not get_moe_a2a_backend().is_deepep():
return self.forward_normal(hidden_states, should_allreduce_fusion) return self.forward_normal(hidden_states, should_allreduce_fusion)
else: else:
raise Exception("forward_deepep branch not implemented yet") raise Exception("forward_deepep branch not implemented yet")
@@ -189,17 +179,10 @@ class GptOssSparseMoeBlock(nn.Module):
should_allreduce_fusion: bool = False, should_allreduce_fusion: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.router(hidden_states) router_logits, _ = self.router(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
kwargs = {"hidden_states": hidden_states} final_hidden_states = self.experts(hidden_states, topk_output)
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["topk_output"] = (self.top_k, router_logits)
final_hidden_states = self.experts(**kwargs)
if self.tp_size > 1 and not should_allreduce_fusion: if self.tp_size > 1 and not should_allreduce_fusion:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
@@ -436,7 +419,6 @@ class GptOssDecoderLayer(nn.Module):
self.attn_tp_size = get_attention_tp_size() self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank() self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()
# GptOss all layers are sparse and have no nextn now # GptOss all layers are sparse and have no nextn now
self.is_layer_sparse = True self.is_layer_sparse = True
@@ -1060,7 +1042,7 @@ class GptOssForCausalLM(nn.Module):
("qkv_proj", "k_proj", "k"), ("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"), ("qkv_proj", "v_proj", "v"),
] ]
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused( expert_params_mapping = FusedMoE.make_expert_params_mapping_fused(
ckpt_gate_up_proj_name="gate_up_proj", ckpt_gate_up_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_gate_up_proj_bias_name="gate_up_proj_bias", ckpt_gate_up_proj_bias_name="gate_up_proj_bias",

View File

@@ -76,7 +76,6 @@ class GraniteMoeMoE(nn.Module):
params_dtype=params_dtype, params_dtype=params_dtype,
reduce_results=True, reduce_results=True,
quant_config=quant_config, quant_config=quant_config,
tp_size=tp_size,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
) )

View File

@@ -135,7 +135,6 @@ class Grok1MoE(nn.Module):
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config, quant_config=quant_config,
tp_size=tp_size,
activation="gelu", activation="gelu",
**kwargs, **kwargs,
) )

View File

@@ -6,6 +6,7 @@ from transformers import PretrainedConfig
from sglang.srt.distributed import parallel_state from sglang.srt.distributed import parallel_state
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import ( from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs, MultiModalityDataPaddingPatternTokenPairs,
@@ -254,7 +255,7 @@ class InternS1ForConditionalGeneration(nn.Module):
] ]
expert_params_mapping = [] expert_params_mapping = []
if "Qwen3MoeForCausalLM" in self.config.text_config.architectures: if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",

View File

@@ -12,7 +12,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
from sglang.srt.distributed import parallel_state from sglang.srt.distributed import parallel_state
from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import ( from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs, MultiModalityDataPaddingPatternTokenPairs,
@@ -616,7 +616,7 @@ class InternVLChatModel(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",

View File

@@ -31,7 +31,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import ( from sglang.srt.layers.dp_attention import (
get_attention_tp_rank, get_attention_tp_rank,
get_attention_tp_size, get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled, is_dp_attention_enabled,
) )
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
@@ -364,7 +363,6 @@ class Llama4DecoderLayer(nn.Module):
rope_theta = config.rope_theta rope_theta = config.rope_theta
rope_scaling = config.rope_scaling rope_scaling = config.rope_scaling
max_position_embeddings = config.max_position_embeddings max_position_embeddings = config.max_position_embeddings
self.local_dp_size = get_local_attention_dp_size()
self.attn_tp_size = get_attention_tp_size() self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank() self.attn_tp_rank = get_attention_tp_rank()

View File

@@ -37,7 +37,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda from sglang.srt.utils import add_prefix, is_cuda

View File

@@ -47,7 +47,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
) )
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers from sglang.srt.utils import add_prefix, make_layers
@@ -104,7 +103,6 @@ class MixtralMoE(nn.Module):
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config, quant_config=quant_config,
tp_size=tp_size,
prefix=add_prefix("experts", prefix), prefix=add_prefix("experts", prefix),
) )

View File

@@ -89,7 +89,6 @@ class OlmoeMoE(nn.Module):
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
reduce_results=True, reduce_results=True,
quant_config=quant_config, quant_config=quant_config,
tp_size=tp_size,
layer_id=layer_id, layer_id=layer_id,
prefix=add_prefix("experts", prefix), prefix=add_prefix("experts", prefix),
) )

View File

@@ -17,8 +17,6 @@
"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
import logging import logging
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, Iterable, Optional, Tuple, Union from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch import torch
@@ -31,10 +29,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from sglang.srt.eplb.expert_distribution import ( from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
ExpertDistributionRecorder,
get_global_expert_distribution_recorder,
)
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import ( from sglang.srt.layers.communicator import (
@@ -45,7 +40,6 @@ from sglang.srt.layers.communicator import (
from sglang.srt.layers.dp_attention import ( from sglang.srt.layers.dp_attention import (
get_attention_tp_rank, get_attention_tp_rank,
get_attention_tp_size, get_attention_tp_size,
get_local_attention_dp_size,
is_dp_attention_enabled, is_dp_attention_enabled,
) )
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
@@ -55,8 +49,8 @@ from sglang.srt.layers.linear import (
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -149,14 +143,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("experts", prefix), prefix=add_prefix("experts", prefix),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
) )
self.gate = ReplicatedLinear( self.gate = ReplicatedLinear(
@@ -340,7 +326,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
self.attn_tp_size = get_attention_tp_size() self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank() self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()
# Qwen2MoE all layers are sparse and have no nextn now # Qwen2MoE all layers are sparse and have no nextn now
self.is_layer_sparse = True self.is_layer_sparse = True

View File

@@ -28,50 +28,35 @@ from sglang.srt.distributed import (
get_pp_group, get_pp_group,
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
parallel_state,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce, tensor_model_parallel_all_reduce,
) )
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import ( from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ( from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import get_layer_id from sglang.srt.layers.utils import get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel from sglang.srt.models.qwen2_moe import Qwen2MoeModel
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty
Qwen3MoeConfig = None Qwen3MoeConfig = None
@@ -112,19 +97,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
intermediate_size=config.moe_intermediate_size, intermediate_size=config.moe_intermediate_size,
quant_config=quant_config, quant_config=quant_config,
prefix=add_prefix("experts", prefix), prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=global_server_args_dict["deepep_mode"])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
# Additional args for FusedMoE
**(
dict(
enable_flashinfer_cutlass_moe=True,
)
if global_server_args_dict["enable_flashinfer_cutlass_moe"]
else {}
),
) )
self.gate = ReplicatedLinear( self.gate = ReplicatedLinear(
@@ -135,7 +107,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
prefix=add_prefix("gate", prefix), prefix=add_prefix("gate", prefix),
) )
if global_server_args_dict["moe_a2a_backend"].is_deepep(): if get_moe_a2a_backend().is_deepep():
# TODO: we will support tp < ep in the future # TODO: we will support tp < ep in the future
self.ep_size = get_moe_expert_parallel_world_size() self.ep_size = get_moe_expert_parallel_world_size()
self.num_experts = ( self.num_experts = (
@@ -150,7 +122,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
use_reduce_scatter: bool = False, use_reduce_scatter: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if not global_server_args_dict["moe_a2a_backend"].is_deepep(): if not get_moe_a2a_backend().is_deepep():
return self.forward_normal(hidden_states, use_reduce_scatter) return self.forward_normal(hidden_states, use_reduce_scatter)
else: else:
return self.forward_deepep(hidden_states, forward_batch) return self.forward_deepep(hidden_states, forward_batch)
@@ -491,7 +463,6 @@ class Qwen3MoeDecoderLayer(nn.Module):
self.attn_tp_size = get_attention_tp_size() self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank() self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()
# Qwen3MoE all layers are sparse and have no nextn now # Qwen3MoE all layers are sparse and have no nextn now
self.is_layer_sparse = True self.is_layer_sparse = True
@@ -778,7 +749,7 @@ class Qwen3MoeForCausalLM(nn.Module):
("gate_up_proj", "up_proj", 1), ("gate_up_proj", "up_proj", 1),
] ]
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj", ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj", ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj", ckpt_up_proj_name="up_proj",

View File

@@ -38,6 +38,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.moe.topk import TopK
@@ -150,7 +151,7 @@ class Step3TextMoEMLP(nn.Module):
prefix=add_prefix("gate", prefix), prefix=add_prefix("gate", prefix),
) )
if global_server_args_dict["moe_a2a_backend"].is_deepep(): if get_moe_a2a_backend().is_deepep():
raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.") raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

View File

@@ -33,7 +33,9 @@ from sglang.srt.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.rotary_embedding import get_rope
@@ -121,6 +123,7 @@ class XverseMoE(nn.Module):
] ]
) )
self.pack_params() self.pack_params()
self.moe_runner_config = MoeRunnerConfig(inplace=True)
self.router = ReplicatedLinear( self.router = ReplicatedLinear(
config.hidden_size, config.hidden_size,
@@ -129,6 +132,10 @@ class XverseMoE(nn.Module):
quant_config=None, quant_config=None,
prefix=add_prefix("router", prefix), prefix=add_prefix("router", prefix),
) )
self.topk = TopK(
top_k=self.top_k,
renormalize=getattr(self.config, "norm_topk_prob", False),
)
if config.num_shared_experts is not None: if config.num_shared_experts is not None:
intermediate_size = config.intermediate_size * config.num_shared_experts intermediate_size = config.intermediate_size * config.num_shared_experts
@@ -167,14 +174,13 @@ class XverseMoE(nn.Module):
shared_output = self.shared_experts(hidden_states) shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.router(hidden_states) router_logits, _ = self.router(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = fused_moe( final_hidden_states = fused_moe(
hidden_states, hidden_states,
self.w1, self.w1,
self.w2, self.w2,
router_logits, topk_output,
self.top_k, self.moe_runner_config,
renormalize=getattr(self.config, "norm_topk_prob", False),
inplace=True,
) )
if self.config.num_shared_experts is not None: if self.config.num_shared_experts is not None:

View File

@@ -37,6 +37,7 @@ from sglang.srt.utils import (
is_hip, is_hip,
is_port_available, is_port_available,
is_remote_url, is_remote_url,
is_triton_kernels_available,
is_valid_ipv6_address, is_valid_ipv6_address,
nullable_str, nullable_str,
) )
@@ -175,9 +176,15 @@ class ServerArgs:
# Expert parallelism # Expert parallelism
ep_size: int = 1 ep_size: int = 1
moe_a2a_backend: Optional[Literal["deepep"]] = None moe_a2a_backend: Literal["none", "deepep"] = "none"
enable_flashinfer_cutlass_moe: bool = False moe_runner_backend: Literal[
enable_flashinfer_trtllm_moe: bool = False "auto",
"triton",
"triton_kernel",
"flashinfer_trtllm",
"flashinfer_cutlass",
"flashinfer_mxfp4",
] = "auto"
enable_flashinfer_allreduce_fusion: bool = False enable_flashinfer_allreduce_fusion: bool = False
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto" deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
ep_num_redundant_experts: int = 0 ep_num_redundant_experts: int = 0
@@ -250,8 +257,6 @@ class ServerArgs:
disable_chunked_prefix_cache: bool = False disable_chunked_prefix_cache: bool = False
disable_fast_image_processor: bool = False disable_fast_image_processor: bool = False
enable_return_hidden_states: bool = False enable_return_hidden_states: bool = False
enable_triton_kernel_moe: bool = False
enable_flashinfer_mxfp4_moe: bool = False
scheduler_recv_interval: int = 1 scheduler_recv_interval: int = 1
# Debug tensor dumps # Debug tensor dumps
@@ -282,6 +287,9 @@ class ServerArgs:
# Deprecated arguments # Deprecated arguments
enable_ep_moe: bool = False enable_ep_moe: bool = False
enable_deepep_moe: bool = False enable_deepep_moe: bool = False
enable_flashinfer_cutlass_moe: bool = False
enable_flashinfer_trtllm_moe: bool = False
enable_triton_kernel_moe: bool = False
def __post_init__(self): def __post_init__(self):
# Check deprecated arguments # Check deprecated arguments
@@ -298,6 +306,21 @@ class ServerArgs:
print_deprecated_warning( print_deprecated_warning(
"NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead." "NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
) )
if self.enable_triton_kernel_moe:
self.moe_runner_backend = "triton_kernel"
print_deprecated_warning(
"NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
)
if self.enable_flashinfer_cutlass_moe:
self.moe_runner_backend = "flashinfer_cutlass"
print_deprecated_warning(
"NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
)
if self.enable_flashinfer_trtllm_moe:
self.moe_runner_backend = "flashinfer_trtllm"
print_deprecated_warning(
"NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
)
# Set missing default values # Set missing default values
if self.tokenizer_path is None: if self.tokenizer_path is None:
@@ -517,7 +540,7 @@ class ServerArgs:
), "Please enable dp attention when setting enable_dp_lm_head. " ), "Please enable dp attention when setting enable_dp_lm_head. "
# MoE kernel # MoE kernel
if self.enable_flashinfer_cutlass_moe: if self.moe_runner_backend == "flashinfer_cutlass":
assert ( assert (
self.quantization == "modelopt_fp4" self.quantization == "modelopt_fp4"
), "modelopt_fp4 quantization is required for Flashinfer MOE" ), "modelopt_fp4 quantization is required for Flashinfer MOE"
@@ -527,7 +550,7 @@ class ServerArgs:
self.tp_size, self.tp_size,
], "The expert parallel size must be 1 or the same as the tensor parallel size" ], "The expert parallel size must be 1 or the same as the tensor parallel size"
if self.enable_flashinfer_trtllm_moe: if self.moe_runner_backend == "flashinfer_trtllm":
if not self.disable_shared_experts_fusion: if not self.disable_shared_experts_fusion:
self.disable_shared_experts_fusion = True self.disable_shared_experts_fusion = True
logger.warning( logger.warning(
@@ -556,7 +579,7 @@ class ServerArgs:
self.ep_dispatch_algorithm = "static" self.ep_dispatch_algorithm = "static"
if self.enable_eplb: if self.enable_eplb:
assert self.ep_size > 1 or self.moe_a2a_backend is not None assert self.ep_size > 1
if self.enable_expert_distribution_metrics and ( if self.enable_expert_distribution_metrics and (
self.expert_distribution_recorder_mode is None self.expert_distribution_recorder_mode is None
@@ -1446,19 +1469,22 @@ class ServerArgs:
parser.add_argument( parser.add_argument(
"--moe-a2a-backend", "--moe-a2a-backend",
type=str, type=str,
choices=["deepep"], choices=["none", "deepep"],
default=ServerArgs.moe_a2a_backend, default=ServerArgs.moe_a2a_backend,
help="Choose the backend for MoE A2A.", help="Choose the backend for MoE A2A.",
) )
parser.add_argument( parser.add_argument(
"--enable-flashinfer-cutlass-moe", "--moe-runner-backend",
action="store_true", type=str,
help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP", choices=[
) "auto",
parser.add_argument( "triton",
"--enable-flashinfer-trtllm-moe", "triton_kernel",
action="store_true", "flashinfer_trtllm",
help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP", "flashinfer_cutlass",
],
default=ServerArgs.moe_runner_backend,
help="Choose the runner backend for MoE.",
) )
parser.add_argument( parser.add_argument(
"--enable-flashinfer-allreduce-fusion", "--enable-flashinfer-allreduce-fusion",
@@ -1825,11 +1851,6 @@ class ServerArgs:
action="store_true", action="store_true",
help="Enable returning hidden states with responses.", help="Enable returning hidden states with responses.",
) )
parser.add_argument(
"--enable-triton-kernel-moe",
action="store_true",
help="Use triton moe grouped gemm kernel.",
)
parser.add_argument( parser.add_argument(
"--enable-flashinfer-mxfp4-moe", "--enable-flashinfer-mxfp4-moe",
action="store_true", action="store_true",
@@ -1965,6 +1986,21 @@ class ServerArgs:
action="store_true", action="store_true",
help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.", help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
) )
parser.add_argument(
"--enable-flashinfer-cutlass-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
)
parser.add_argument(
"--enable-flashinfer-trtllm-moe",
action="store_true",
help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
)
parser.add_argument(
"--enable-triton-kernel-moe",
action="store_true",
help="(Deprecated) Use triton moe grouped gemm kernel.",
)
@classmethod @classmethod
def from_cli_args(cls, args: argparse.Namespace): def from_cli_args(cls, args: argparse.Namespace):
@@ -2143,18 +2179,21 @@ class ServerArgs:
) )
if is_sm100_supported() and is_mxfp4_quant_format: if is_sm100_supported() and is_mxfp4_quant_format:
self.enable_flashinfer_mxfp4_moe = True self.moe_runner_backend = "flashinfer_mxfp4"
self.enable_triton_kernel_moe = False
logger.warning( logger.warning(
"Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel." "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
) )
else: else:
if self.enable_triton_kernel_moe: if self.moe_runner_backend == "triton_kernel":
assert ( assert (
self.ep_size == 1 self.ep_size == 1
), "Triton kernel MoE is only supported when ep_size == 1" ), "Triton kernel MoE is only supported when ep_size == 1"
if not self.enable_triton_kernel_moe and self.ep_size == 1: if (
self.enable_triton_kernel_moe = True self.moe_runner_backend == "auto"
and self.ep_size == 1
and is_triton_kernels_available()
):
self.moe_runner_backend = "triton_kernel"
logger.warning( logger.warning(
"Detected GPT-OSS model, enabling triton_kernels MOE kernel." "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
) )

View File

@@ -14,8 +14,13 @@ from sglang.srt.layers.communicator import (
CommunicateSummableTensorPairFn, CommunicateSummableTensorPairFn,
ScatterMode, ScatterMode,
) )
from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
get_tbo_token_distribution_threshold,
is_tbo_enabled,
)
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ( from sglang.srt.model_executor.forward_batch_info import (
@@ -83,7 +88,7 @@ def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool:
vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens) vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
left_sum = sum(extend_lens[:vanilla_split_seq_index]) left_sum = sum(extend_lens[:vanilla_split_seq_index])
overall_sum = sum(extend_lens) overall_sum = sum(extend_lens)
threshold = global_server_args_dict["tbo_token_distribution_threshold"] threshold = get_tbo_token_distribution_threshold()
assert threshold <= 0.5, f"{threshold=}" assert threshold <= 0.5, f"{threshold=}"
return left_sum < overall_sum * threshold or left_sum > overall_sum * ( return left_sum < overall_sum * threshold or left_sum > overall_sum * (
1 - threshold 1 - threshold
@@ -299,7 +304,7 @@ class TboCudaGraphRunnerPlugin:
self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32) self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32)
def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int): def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
if not global_server_args_dict["enable_two_batch_overlap"]: if not is_tbo_enabled():
return return
token_num_per_seq = get_token_num_per_seq( token_num_per_seq = get_token_num_per_seq(
forward_mode=batch.forward_mode, spec_info=batch.spec_info forward_mode=batch.forward_mode, spec_info=batch.spec_info
@@ -353,10 +358,12 @@ class TboDPAttentionPreparer:
def prepare_all_gather( def prepare_all_gather(
self, self,
local_batch: ScheduleBatch, local_batch: ScheduleBatch,
deepep_mode: DeepEPMode,
enable_deepep_moe: bool,
enable_two_batch_overlap: bool,
): ):
deepep_mode = get_deepep_mode()
enable_deepep_moe = get_moe_a2a_backend().is_deepep()
enable_two_batch_overlap = is_tbo_enabled()
self.enable_two_batch_overlap = enable_two_batch_overlap self.enable_two_batch_overlap = enable_two_batch_overlap
if local_batch is not None: if local_batch is not None:
@@ -384,7 +391,7 @@ class TboDPAttentionPreparer:
and not local_batch.forward_mode.is_target_verify() and not local_batch.forward_mode.is_target_verify()
) )
and enable_deepep_moe and enable_deepep_moe
and (resolved_deepep_mode == DeepEPMode.LOW_LATENCY) and (resolved_deepep_mode.is_low_latency())
) )
else: else:
self.local_tbo_split_seq_index = 0 self.local_tbo_split_seq_index = 0
@@ -657,6 +664,7 @@ class TboForwardBatchPreparer:
"req_to_token_pool", "req_to_token_pool",
"token_to_kv_pool", "token_to_kv_pool",
"can_run_dp_cuda_graph", "can_run_dp_cuda_graph",
"dp_padding_mode",
"global_forward_mode", "global_forward_mode",
"spec_algorithm", "spec_algorithm",
"capture_hidden_mode", "capture_hidden_mode",
@@ -701,7 +709,6 @@ class TboForwardBatchPreparer:
tbo_children=None, tbo_children=None,
global_num_tokens_gpu=None, global_num_tokens_gpu=None,
global_num_tokens_cpu=None, global_num_tokens_cpu=None,
dp_padding_mode=None,
global_dp_buffer_len=global_dp_buffer_len, global_dp_buffer_len=global_dp_buffer_len,
global_num_tokens_for_logprob_gpu=None, global_num_tokens_for_logprob_gpu=None,
global_num_tokens_for_logprob_cpu=None, global_num_tokens_for_logprob_cpu=None,
@@ -955,9 +962,7 @@ def _model_forward_tbo_merge_outputs(output_a, output_b):
class MaybeTboDeepEPDispatcher: class MaybeTboDeepEPDispatcher:
def __init__(self, **kwargs): def __init__(self, **kwargs):
num_inner_dispatchers = ( num_inner_dispatchers = 2 if is_tbo_enabled() else 1
2 if global_server_args_dict["enable_two_batch_overlap"] else 1
)
self._inners = [ self._inners = [
DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers) DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
] ]

View File

@@ -2413,7 +2413,7 @@ def require_mlp_tp_gather(server_args):
return True return True
elif not server_args.enable_dp_lm_head: elif not server_args.enable_dp_lm_head:
return True return True
elif server_args.moe_a2a_backend is None: elif server_args.moe_a2a_backend == "none":
return True return True
else: else:
return ( return (
@@ -2429,7 +2429,7 @@ def require_attn_tp_gather(server_args):
Check if the input of attention is scattered. Check if the input of attention is scattered.
""" """
assert server_args.moe_dense_tp_size in [1, None] assert server_args.moe_dense_tp_size in [1, None]
if server_args.moe_a2a_backend is not None or server_args.moe_dense_tp_size == 1: if server_args.moe_a2a_backend != "none" or server_args.moe_dense_tp_size == 1:
if server_args.enable_dp_attention: if server_args.enable_dp_attention:
return server_args.dp_size < server_args.tp_size return server_args.dp_size < server_args.tp_size
else: else:

View File

@@ -6,7 +6,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
per_tensor_quant_mla_fp8, per_tensor_quant_mla_fp8,
per_token_group_quant_fp8, per_token_group_quant_fp8,
@@ -498,11 +498,13 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
score = torch.randn((M, E), dtype=dtype) score = torch.randn((M, E), dtype=dtype)
with torch.inference_mode(): with torch.inference_mode():
ref_out = torch_w8a8_block_fp8_moe(
a, w1, w2, w1_s, w2_s, score, topk, block_size
)
topk_output = select_experts( topk_output = select_experts(
hidden_states=a, hidden_states=a,
router_logits=score, router_logits=score,
top_k=topk, topk_config=TopKConfig(top_k=topk, renormalize=False),
renormalize=False,
) )
out = fused_moe( out = fused_moe(
a, a,
@@ -514,9 +516,6 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
w2_scale=w2_s, w2_scale=w2_s,
block_shape=block_size, block_shape=block_size,
) )
ref_out = torch_w8a8_block_fp8_moe(
a, w1, w2, w1_s, w2_s, score, topk, block_size
)
self.assertTrue( self.assertTrue(
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))

View File

@@ -12,7 +12,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
run_moe_ep_preproess, run_moe_ep_preproess,
silu_and_mul_triton_kernel, silu_and_mul_triton_kernel,
) )
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.test.test_utils import CustomTestCase from sglang.test.test_utils import CustomTestCase
@@ -22,35 +22,26 @@ def ep_moe(
w1: torch.Tensor, w1: torch.Tensor,
w2: torch.Tensor, w2: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, topk_config: TopKConfig,
renormalize: bool,
# ep config # ep config
num_experts: int = 256, num_experts: int = 256,
fp8_dtype: torch.types = torch.float8_e4m3fn, fp8_dtype: torch.types = torch.float8_e4m3fn,
num_experts_per_partition: int = 128, num_experts_per_partition: int = 128,
start_expert_id: int = 0, start_expert_id: int = 0,
end_expert_id: int = 127, end_expert_id: int = 127,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False, use_fp8_w8a8: bool = False,
w1_scale_inv: Optional[torch.Tensor] = None, w1_scale_inv: Optional[torch.Tensor] = None,
w2_scale_inv: Optional[torch.Tensor] = None, w2_scale_inv: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
): ):
use_blockwise_fp8 = block_shape is not None use_blockwise_fp8 = block_shape is not None
topk_weights, topk_ids, _ = select_experts( top_k = topk_config.top_k
topk_output = select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
top_k=top_k, topk_config=topk_config,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
# correction_bias=correction_bias, #skip this in test
custom_routing_function=custom_routing_function,
) )
topk_weights, topk_ids, _ = topk_output
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts) reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)
@@ -294,14 +285,18 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase):
start_id = cur_rank * num_experts_per_partition start_id = cur_rank * num_experts_per_partition
end_id = start_id + num_experts_per_partition - 1 end_id = start_id + num_experts_per_partition - 1
topk_config = TopKConfig(
top_k=topk,
renormalize=False,
)
with torch.inference_mode(): with torch.inference_mode():
out = ep_moe( out = ep_moe(
hidden_states=a, hidden_states=a,
w1=w1, w1=w1,
w2=w2, w2=w2,
router_logits=score, router_logits=score,
top_k=topk, topk_config=topk_config,
renormalize=False,
use_fp8_w8a8=True, use_fp8_w8a8=True,
w1_scale_inv=w1_s, w1_scale_inv=w1_s,
w2_scale_inv=w2_s, w2_scale_inv=w2_s,
@@ -316,8 +311,7 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase):
w1=w1_ref, w1=w1_ref,
w2=w2_ref, w2=w2_ref,
router_logits=score, router_logits=score,
top_k=topk, topk_config=topk_config,
renormalize=False,
use_fp8_w8a8=False, use_fp8_w8a8=False,
w1_scale_inv=None, w1_scale_inv=None,
w2_scale_inv=None, w2_scale_inv=None,

View File

@@ -6,7 +6,7 @@ import pytest
import torch import torch
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import TopKConfig, select_experts
def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor: def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor:
@@ -100,11 +100,12 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
s_strides2 = c_strides2 s_strides2 = c_strides2
score = torch.randn((M, E), dtype=dtype, device=device) score = torch.randn((M, E), dtype=dtype, device=device)
topk_weights, topk_ids, _ = select_experts( topk_output = select_experts(
hidden_states=a, hidden_states=a,
router_logits=score, router_logits=score,
top_k=topk, topk_config=TopKConfig(top_k=topk, renormalize=False),
) )
topk_weights, topk_ids, _ = topk_output
expert_map = torch.arange(E, dtype=torch.int32, device=device) expert_map = torch.arange(E, dtype=torch.int32, device=device)
expert_map[local_e:] = E expert_map[local_e:] = E

View File

@@ -9,7 +9,7 @@ from sgl_kernel import scaled_fp4_quant
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import TopKConfig, select_experts
if torch.cuda.get_device_capability() < (10, 0): if torch.cuda.get_device_capability() < (10, 0):
pytest.skip( pytest.skip(
@@ -163,11 +163,12 @@ def check_moe(
score = torch.randn((m, e), device="cuda", dtype=dtype) score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids, _ = select_experts( topk_output = select_experts(
hidden_states=a, hidden_states=a,
router_logits=score, router_logits=score,
top_k=topk, topk_config=TopKConfig(top_k=topk, renormalize=False),
) )
topk_weights, topk_ids, _ = topk_output
a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32) a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32) a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)

View File

@@ -5,7 +5,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.test.test_utils import CustomTestCase from sglang.test.test_utils import CustomTestCase
@@ -175,10 +175,13 @@ class TestW8A8BlockINT8FusedMoE(CustomTestCase):
topk_output = select_experts( topk_output = select_experts(
hidden_states=a, hidden_states=a,
router_logits=score, router_logits=score,
top_k=topk, topk_config=TopKConfig(top_k=topk, renormalize=False),
) )
with torch.inference_mode(): with torch.inference_mode():
ref_out = torch_w8a8_block_int8_moe(
a, w1, w2, w1_s, w2_s, score, topk, block_size
)
out = fused_moe( out = fused_moe(
a, a,
w1, w1,
@@ -189,9 +192,6 @@ class TestW8A8BlockINT8FusedMoE(CustomTestCase):
w2_scale=w2_s, w2_scale=w2_s,
block_shape=block_size, block_shape=block_size,
) )
ref_out = torch_w8a8_block_int8_moe(
a, w1, w2, w1_s, w2_s, score, topk, block_size
)
self.assertTrue( self.assertTrue(
torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32)))

View File

@@ -5,7 +5,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.test.test_utils import CustomTestCase from sglang.test.test_utils import CustomTestCase
@@ -118,7 +118,7 @@ class TestW8A8Int8FusedMoE(CustomTestCase):
topk_output = select_experts( topk_output = select_experts(
hidden_states=a, hidden_states=a,
router_logits=score, router_logits=score,
top_k=topk, topk_config=TopKConfig(top_k=topk, renormalize=False),
) )
out = fused_moe( out = fused_moe(
a, a,

View File

@@ -6,7 +6,7 @@ from tqdm import tqdm
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
@@ -136,19 +136,7 @@ class TestFusedMOE(CustomTestCase):
topk_output = select_experts( topk_output = select_experts(
hidden_states=a, hidden_states=a,
router_logits=score, router_logits=score,
top_k=topk, topk_config=TopKConfig(top_k=topk, renormalize=False),
)
sglang_output = fused_moe(
a,
w1,
w2,
topk_output,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
) )
torch_output = self.torch_naive_moe( torch_output = self.torch_naive_moe(
@@ -162,6 +150,18 @@ class TestFusedMOE(CustomTestCase):
a1_scale, a1_scale,
a2_scale, a2_scale,
) )
sglang_output = fused_moe(
a,
w1,
w2,
topk_output,
use_fp8_w8a8=True,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
torch.testing.assert_close( torch.testing.assert_close(
sglang_output, torch_output, rtol=rtol, atol=atol sglang_output, torch_output, rtol=rtol, atol=atol
) )
@@ -174,7 +174,7 @@ class TestFusedMOE(CustomTestCase):
topk_output = select_experts( topk_output = select_experts(
hidden_states=a, hidden_states=a,
router_logits=score, router_logits=score,
top_k=topk, topk_config=TopKConfig(top_k=topk, renormalize=False),
) )
triton_output = fused_moe(a, w1, w2, topk_output) triton_output = fused_moe(a, w1, w2, topk_output)

View File

@@ -5,7 +5,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.test.test_utils import CustomTestCase from sglang.test.test_utils import CustomTestCase
@@ -130,7 +130,7 @@ class TestW8A8FP8FusedMoE(CustomTestCase):
topk_output = select_experts( topk_output = select_experts(
hidden_states=a, hidden_states=a,
router_logits=score, router_logits=score,
top_k=topk, topk_config=TopKConfig(top_k=topk, renormalize=False),
) )
out = fused_moe( out = fused_moe(
a, a,

View File

@@ -5,7 +5,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.moe.topk import TopKConfig, select_experts
NUM_EXPERTS = [8, 64] NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6] TOP_KS = [2, 6]
@@ -223,7 +223,7 @@ def test_fused_moe_wn16(
topk_output = select_experts( topk_output = select_experts(
hidden_states=a, hidden_states=a,
router_logits=score, router_logits=score,
top_k=topk, topk_config=TopKConfig(top_k=topk),
) )
triton_output = fused_moe( triton_output = fused_moe(