diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py index 2af320d56..09caf9e9e 100644 --- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py @@ -11,6 +11,7 @@ import triton from ray.experimental.tqdm_ray import tqdm from transformers import AutoConfig +from sglang.srt.layers.moe.fused_moe_triton import override_config from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( fused_moe, get_config_dtype_str, @@ -18,7 +19,8 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( get_default_config, get_moe_configs, ) -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig +from sglang.srt.layers.moe.topk import TopKConfig, select_experts from sglang.srt.utils import is_hip _is_hip = is_hip() @@ -117,17 +119,23 @@ def benchmark_config( w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn) input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) - topk_output = select_experts(x, input_gating, topk, renormalize=True) + topk_config = TopKConfig( + top_k=topk, + renormalize=True, + ) + topk_output = select_experts(x, input_gating, topk_config) def prepare(i: int): input_gating = gating_output[i] - new_topk_output = select_experts(x, input_gating, topk, renormalize=True) + new_topk_output = select_experts(x, input_gating, topk_config) topk_output.topk_weights.copy_(new_topk_output.topk_weights) topk_output.topk_ids.copy_(new_topk_output.topk_ids) topk_output.router_logits.copy_(new_topk_output.router_logits) def run(): - from sglang.srt.layers.moe.fused_moe_triton import override_config + moe_runner_config = MoeRunnerConfig( + inplace=True, + ) with override_config(config): fused_moe( @@ -135,7 +143,7 @@ def benchmark_config( w1, w2, topk_output, - inplace=True, + moe_runner_config=moe_runner_config, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md index 3ce9ad469..c63b8a604 100644 --- a/docs/advanced_features/server_arguments.md +++ b/docs/advanced_features/server_arguments.md @@ -213,12 +213,11 @@ Please consult the documentation below and [server_args.py](https://github.com/s | Arguments | Description | Defaults | |-----------|-------------|----------| | `--ep-size` | The expert parallelism size. | 1 | -| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | None | -| `--enable-flashinfer-cutlass-moe` | Enabling Flashinfer Cutlass MoE implementation for high throughput. | False | -| `--enable-flashinfer-trtllm-moe` | Enabling Flashinfer Trtllm MoE implementation for low latency. | False | +| `--moe-a2a-backend` | Select the backend for all-to-all communication for expert parallelism. | none | +| `--moe-runner-backend` | Select the runner backend for MoE. | 'triton' | | `--deepep-mode` | Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch. | auto | | `--ep-num-redundant-experts` | Allocate this number of redundant experts in expert parallel. | 0 | -| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in expert parallel. | None | +| `--ep-dispatch-algorithm` | The algorithm to choose ranks for redundant experts in EPLB. | None | | `--init-expert-location` | Initial location of EP experts. | trivial | | `--enable-eplb` | Enable EPLB algorithm. | False | | `--eplb-algorithm` | Chosen EPLB algorithm. | auto | @@ -280,7 +279,6 @@ Please consult the documentation below and [server_args.py](https://github.com/s | `--disable-chunked-prefix-cache` | Disable chunked prefix cache. | False | | `--disable-fast-image-processor` | Disable fast image processor. | False | | `--enable-return-hidden-states` | Enable returning hidden states. | False | -| `--enable-triton-kernel-moe` | Enable Triton kernel for MoE. | False | ## Debug tensor dumps diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 8401e4708..aa43bb027 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -61,7 +61,6 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed.parallel_state import destroy_distributed_environment from sglang.srt.entrypoints.engine import _set_envs_and_config from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.managers.scheduler import Scheduler from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -300,11 +299,6 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner): disable_cuda_graph=model_runner.server_args.disable_cuda_graph, spec_algorithm=SpeculativeAlgorithm.NONE, speculative_num_draft_tokens=None, - enable_two_batch_overlap=model_runner.server_args.enable_two_batch_overlap, - enable_deepep_moe=MoeA2ABackend( - model_runner.server_args.moe_a2a_backend - ).is_deepep(), - deepep_mode=DeepEPMode(model_runner.server_args.deepep_mode), require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args), disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule, ) diff --git a/python/sglang/srt/eplb/expert_distribution.py b/python/sglang/srt/eplb/expert_distribution.py index c954394e6..c4a2c38f9 100644 --- a/python/sglang/srt/eplb/expert_distribution.py +++ b/python/sglang/srt/eplb/expert_distribution.py @@ -25,7 +25,6 @@ import torch import torch.distributed from sglang.srt.eplb.expert_location import ExpertLocationMetadata -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.server_args import ServerArgs from sglang.srt.utils import Withable, get_bool_env_var @@ -288,14 +287,14 @@ class _SinglePassGatherer(ABC): ) if server_args.expert_distribution_recorder_mode == "stat_approx": - if server_args.moe_a2a_backend is not None and ( + if server_args.moe_a2a_backend != "none" and ( server_args.deepep_mode == "normal" ): return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank) else: raise NotImplementedError - if server_args.moe_a2a_backend is not None: + if server_args.moe_a2a_backend != "none": if server_args.deepep_mode == "normal": return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank) elif server_args.deepep_mode == "low_latency": diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py index 624b06017..27a1721aa 100644 --- a/python/sglang/srt/layers/communicator.py +++ b/python/sglang/srt/layers/communicator.py @@ -17,7 +17,7 @@ from enum import Enum, auto from functools import partial from typing import Dict, Optional -import torch.distributed +import torch from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, @@ -35,6 +35,7 @@ from sglang.srt.layers.dp_attention import ( get_global_dp_buffer, get_local_dp_buffer, ) +from sglang.srt.layers.moe import get_moe_a2a_backend from sglang.srt.layers.utils import is_sm100_supported from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch @@ -111,7 +112,7 @@ class LayerScatterModes: if context.is_layer_sparse: return ( ScatterMode.SCATTERED - if not global_server_args_dict["moe_a2a_backend"].is_standard() + if not get_moe_a2a_backend().is_none() else ScatterMode.FULL ) else: diff --git a/python/sglang/srt/layers/moe/__init__.py b/python/sglang/srt/layers/moe/__init__.py new file mode 100644 index 000000000..88bdb5787 --- /dev/null +++ b/python/sglang/srt/layers/moe/__init__.py @@ -0,0 +1,29 @@ +from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig +from sglang.srt.layers.moe.utils import ( + DeepEPMode, + MoeA2ABackend, + MoeRunnerBackend, + get_deepep_config, + get_deepep_mode, + get_moe_a2a_backend, + get_moe_runner_backend, + get_tbo_token_distribution_threshold, + initialize_moe_config, + is_tbo_enabled, + should_use_flashinfer_trtllm_moe, +) + +__all__ = [ + "DeepEPMode", + "MoeA2ABackend", + "MoeRunnerConfig", + "MoeRunnerBackend", + "initialize_moe_config", + "get_moe_a2a_backend", + "get_moe_runner_backend", + "get_deepep_mode", + "should_use_flashinfer_trtllm_moe", + "is_tbo_enabled", + "get_tbo_token_distribution_threshold", + "get_deepep_config", +] diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 8e99d212d..32684c606 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -1,11 +1,17 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional, Union import torch from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size +from sglang.srt.layers.moe import ( + get_deepep_mode, + get_moe_a2a_backend, + get_moe_runner_backend, + should_use_flashinfer_trtllm_moe, +) from sglang.srt.layers.moe.ep_moe.kernels import ( ep_gather, ep_scatter, @@ -16,14 +22,9 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( ) from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE from sglang.srt.layers.moe.topk import TopKOutput -from sglang.srt.layers.moe.utils import DeepEPMode, should_use_flashinfer_trtllm_moe from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.layers.quantization.fp8 import ( - Fp8Config, - Fp8MoEMethod, - get_tile_tokens_dim, -) +from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.fp8_kernel import ( is_fp8_fnuz, sglang_per_token_group_quant_fp8, @@ -89,12 +90,11 @@ class EPMoE(FusedMoE): num_fused_shared_experts: int = 0, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, prefix: str = "", activation: str = "silu", routed_scaling_factor: Optional[float] = None, - activation_alpha: Optional[float] = None, - swiglu_limit: Optional[float] = None, + gemm1_alpha: Optional[float] = None, + gemm1_clamp_limit: Optional[float] = None, with_bias: bool = False, ): super().__init__( @@ -106,13 +106,12 @@ class EPMoE(FusedMoE): top_k=top_k, params_dtype=params_dtype, quant_config=quant_config, - tp_size=tp_size, prefix=prefix, activation=activation, # apply_router_weight_on_input=apply_router_weight_on_input, routed_scaling_factor=routed_scaling_factor, - activation_alpha=activation_alpha, - swiglu_limit=swiglu_limit, + gemm1_alpha=gemm1_alpha, + gemm1_clamp_limit=gemm1_clamp_limit, with_bias=with_bias, ) @@ -163,7 +162,8 @@ class EPMoE(FusedMoE): ) assert self.quant_method is not None - assert self.activation == "silu" + assert self.moe_runner_config.activation == "silu" + hidden_states_shape = hidden_states.shape hidden_states_dtype = hidden_states.dtype hidden_states_device = hidden_states.device @@ -327,8 +327,8 @@ class EPMoE(FusedMoE): m_max * self.start_expert_id, BLOCK_SIZE=512, ) - if self.routed_scaling_factor is not None: - output *= self.routed_scaling_factor + if self.moe_runner_config.routed_scaling_factor is not None: + output *= self.moe_runner_config.routed_scaling_factor return output @@ -349,11 +349,9 @@ class DeepEPMoE(EPMoE): num_fused_shared_experts: int = 0, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, prefix: str = "", activation: str = "silu", routed_scaling_factor: Optional[float] = None, - deepep_mode: DeepEPMode = DeepEPMode.AUTO, ): super().__init__( num_experts=num_experts, @@ -364,12 +362,11 @@ class DeepEPMoE(EPMoE): num_fused_shared_experts=num_fused_shared_experts, params_dtype=params_dtype, quant_config=quant_config, - tp_size=tp_size, prefix=prefix, activation=activation, routed_scaling_factor=routed_scaling_factor, ) - self.deepep_mode = deepep_mode + self.deepep_mode = get_deepep_mode() # TODO: move to the beginning of the file from sglang.srt.distributed.parallel_state import get_tp_group @@ -383,7 +380,7 @@ class DeepEPMoE(EPMoE): num_local_experts=self.num_local_experts, hidden_size=hidden_size, params_dtype=params_dtype, - deepep_mode=deepep_mode, + deepep_mode=self.deepep_mode, async_finish=True, # TODO return_recv_hook=True, ) @@ -458,15 +455,19 @@ class DeepEPMoE(EPMoE): ) def moe_impl(self, dispatch_output: DispatchOutput): + from sglang.srt.layers.moe.token_dispatcher import DispatchOutputChecker + if _use_aiter: + assert DispatchOutputChecker.format_is_deepep(dispatch_output) # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel return self.forward_aiter(dispatch_output) if _is_npu: + assert DispatchOutputChecker.format_is_ascent_ll(dispatch_output) return self.forward_npu(dispatch_output) - if dispatch_output.format.is_deepep_normal(): + if DispatchOutputChecker.format_is_deepep_normal(dispatch_output): assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8 return self.forward_deepgemm_contiguous(dispatch_output) - elif dispatch_output.format.is_deepep_ll(): + elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output): assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8 return self.forward_deepgemm_masked(dispatch_output) else: @@ -490,7 +491,7 @@ class DeepEPMoE(EPMoE): def forward_aiter( self, - dispatch_output: DeepEPNormalOutput, + dispatch_output: Union[DeepEPNormalOutput, DeepEPLLOutput], ): hidden_states, topk_idx, topk_weights = ( dispatch_output.hidden_states, @@ -516,7 +517,7 @@ class DeepEPMoE(EPMoE): quant_type=QuantType.per_128x128, activation=( ActivationType.Silu - if self.activation == "silu" + if self.moe_runner_config.activation == "silu" else ActivationType.Gelu ), expert_mask=self.expert_mask, @@ -531,7 +532,7 @@ class DeepEPMoE(EPMoE): ) hidden_states_fp8, hidden_states_scale = hidden_states_fp8 assert self.quant_method is not None - assert self.activation == "silu" + assert self.moe_runner_config.activation == "silu" if num_recv_tokens_per_expert is None: return hidden_states_fp8.bfloat16() all_tokens = sum(num_recv_tokens_per_expert) @@ -652,7 +653,7 @@ class DeepEPMoE(EPMoE): ): hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output assert self.quant_method is not None - assert self.activation == "silu" + assert self.moe_runner_config.activation == "silu" # GroupGemm-0 num_groups, m, k = hidden_states_fp8[0].size() @@ -783,12 +784,12 @@ class DeepEPMoE(EPMoE): def get_moe_impl_class(): - if global_server_args_dict["moe_a2a_backend"].is_deepep(): + if get_moe_a2a_backend().is_deepep(): return DeepEPMoE # NEW: Direct FP4 detection (bypasses EP requirements) # Check for FP4 quantization with TRTLLM flag, regardless of EP - if global_server_args_dict.get("enable_flashinfer_trtllm_moe", False): + if get_moe_runner_backend().is_flashinfer_trtllm(): try: # Check the quantization argument directly quantization = global_server_args_dict.get("quantization") @@ -803,7 +804,7 @@ def get_moe_impl_class(): if should_use_flashinfer_trtllm_moe(): return FlashInferFusedMoE - if global_server_args_dict["enable_flashinfer_cutlass_moe"]: + if get_moe_runner_backend().is_flashinfer_cutlass(): return FusedMoE if get_moe_expert_parallel_world_size() > 1: return EPMoE diff --git a/python/sglang/srt/layers/moe/fused_moe_native.py b/python/sglang/srt/layers/moe/fused_moe_native.py index 61eacd78c..92b88b1b7 100644 --- a/python/sglang/srt/layers/moe/fused_moe_native.py +++ b/python/sglang/srt/layers/moe/fused_moe_native.py @@ -3,28 +3,22 @@ Torch-native implementation for FusedMoE. This is used for torch.compile. It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204 """ -from typing import Callable, Optional - import torch from torch.nn import functional as F from sglang.srt.layers.activation import GeluAndMul, SiluAndMul -from sglang.srt.layers.moe.topk import TopKOutput +from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig +from sglang.srt.layers.moe.topk import StandardTopKOutput def fused_moe_forward_native( layer: torch.nn.Module, x: torch.Tensor, - topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, + topk_output: StandardTopKOutput, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: - if apply_router_weight_on_input: + if moe_runner_config.apply_router_weight_on_input: raise NotImplementedError() topk_weights, topk_ids, _ = topk_output @@ -33,12 +27,12 @@ def fused_moe_forward_native( w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) w2_weights = layer.w2_weight[topk_ids] x1 = torch.einsum("ti,taoi -> tao", x, w1_weights) - if activation == "silu": + if moe_runner_config.activation == "silu": x1 = F.silu(x1) - elif activation == "gelu": + elif moe_runner_config.activation == "gelu": x1 = F.gelu(x1) else: - raise ValueError(f"Unsupported activation: {activation=}") + raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}") x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) @@ -47,16 +41,11 @@ def fused_moe_forward_native( def moe_forward_native( layer: torch.nn.Module, x: torch.Tensor, - topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, + topk_output: StandardTopKOutput, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: - if apply_router_weight_on_input: + if moe_runner_config.apply_router_weight_on_input: raise NotImplementedError() topk_weights, topk_ids, _ = topk_output @@ -72,12 +61,12 @@ def moe_forward_native( sorted_tokens = x[idxs // topk_ids.shape[1]] tokens_per_expert = tokens_per_expert.cpu().numpy() - if activation == "silu": + if moe_runner_config.activation == "silu": act = SiluAndMul() - elif activation == "gelu": + elif moe_runner_config.activation == "gelu": act = GeluAndMul() else: - raise ValueError(f"Unsupported activation: {activation=}") + raise ValueError(f"Unsupported activation: {moe_runner_config.activation=}") outputs = [] start_idx = 0 diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 2cd0099b4..0d89ebc88 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -2,17 +2,20 @@ """Fused MoE kernel.""" +from __future__ import annotations + import functools import json import logging import os -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import torch import triton import triton.language as tl -from sglang.srt.layers.moe.topk import TopKOutput +from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig +from sglang.srt.layers.moe.topk import StandardTopKOutput from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_fp8, scaled_fp8_quant, @@ -1025,8 +1028,8 @@ def inplace_fused_experts( a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, routed_scaling_factor: Optional[float] = None, - activation_alpha: Optional[float] = None, - swiglu_limit: Optional[float] = None, + gemm1_alpha: Optional[float] = None, + gemm1_limit: Optional[float] = None, ) -> None: fused_experts_impl( hidden_states, @@ -1053,8 +1056,8 @@ def inplace_fused_experts( block_shape, False, routed_scaling_factor, - activation_alpha, - swiglu_limit, + gemm1_alpha, + gemm1_limit, ) @@ -1081,8 +1084,8 @@ def inplace_fused_experts_fake( a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, routed_scaling_factor: Optional[float] = None, - activation_alpha: Optional[float] = None, - swiglu_limit: Optional[float] = None, + gemm1_alpha: Optional[float] = None, + gemm1_limit: Optional[float] = None, ) -> None: pass @@ -1119,8 +1122,8 @@ def outplace_fused_experts( block_shape: Optional[List[int]] = None, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, - activation_alpha: Optional[float] = None, - swiglu_limit: Optional[float] = None, + gemm1_alpha: Optional[float] = None, + gemm1_limit: Optional[float] = None, ) -> torch.Tensor: return fused_experts_impl( hidden_states, @@ -1147,8 +1150,8 @@ def outplace_fused_experts( block_shape, no_combine=no_combine, routed_scaling_factor=routed_scaling_factor, - activation_alpha=activation_alpha, - swiglu_limit=swiglu_limit, + gemm1_alpha=gemm1_alpha, + gemm1_limit=gemm1_limit, ) @@ -1176,8 +1179,8 @@ def outplace_fused_experts_fake( block_shape: Optional[List[int]] = None, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, - activation_alpha: Optional[float] = None, - swiglu_limit: Optional[float] = None, + gemm1_alpha: Optional[float] = None, + gemm1_limit: Optional[float] = None, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -1194,12 +1197,10 @@ def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_output: TopKOutput, + topk_output: StandardTopKOutput, + moe_runner_config: MoeRunnerConfig, b1: Optional[torch.Tensor] = None, b2: Optional[torch.Tensor] = None, - inplace: bool = False, - activation: str = "silu", - apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -1212,14 +1213,10 @@ def fused_experts( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - activation_alpha: Optional[float] = None, - swiglu_limit: Optional[float] = None, ): topk_weights, topk_ids, _ = topk_output - if inplace: - assert not no_combine, "no combine + inplace makes no sense" + if moe_runner_config.inplace: + assert not moe_runner_config.no_combine, "no combine + inplace makes no sense" torch.ops.sglang.inplace_fused_experts( hidden_states, w1, @@ -1228,8 +1225,8 @@ def fused_experts( topk_ids, b1, b2, - activation, - apply_router_weight_on_input, + moe_runner_config.activation, + moe_runner_config.apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, @@ -1242,9 +1239,9 @@ def fused_experts( a1_scale, a2_scale, block_shape, - routed_scaling_factor, - activation_alpha, - swiglu_limit, + moe_runner_config.routed_scaling_factor, + moe_runner_config.gemm1_alpha, + moe_runner_config.gemm1_clamp_limit, ) return hidden_states else: @@ -1256,8 +1253,8 @@ def fused_experts( topk_ids, b1, b2, - activation, - apply_router_weight_on_input, + moe_runner_config.activation, + moe_runner_config.apply_router_weight_on_input, use_fp8_w8a8, use_int8_w8a8, use_int8_w8a16, @@ -1270,10 +1267,10 @@ def fused_experts( a1_scale, a2_scale, block_shape, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, - activation_alpha=activation_alpha, - swiglu_limit=swiglu_limit, + no_combine=moe_runner_config.no_combine, + routed_scaling_factor=moe_runner_config.routed_scaling_factor, + gemm1_alpha=moe_runner_config.gemm1_alpha, + gemm1_limit=moe_runner_config.gemm1_clamp_limit, ) @@ -1370,11 +1367,11 @@ def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor): @torch.compile -def swiglu_with_alpha_and_limit(x, alpha, limit): +def swiglu_with_alpha_and_limit(x, gemm1_alpha, gemm1_limit): gate, up = x[..., ::2], x[..., 1::2] - gate = gate.clamp(min=None, max=limit) - up = up.clamp(min=-limit, max=limit) - return gate * torch.sigmoid(gate * alpha) * (up + 1) + gate = gate.clamp(min=None, max=gemm1_limit) + up = up.clamp(min=-gemm1_limit, max=gemm1_limit) + return gate * torch.sigmoid(gate * gemm1_alpha) * (up + 1) def fused_experts_impl( @@ -1402,8 +1399,8 @@ def fused_experts_impl( block_shape: Optional[List[int]] = None, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, - activation_alpha: Optional[float] = None, - swiglu_limit: Optional[float] = None, + gemm1_alpha: Optional[float] = None, + gemm1_limit: Optional[float] = None, ): padded_size = padding_size if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter: @@ -1533,12 +1530,12 @@ def fused_experts_impl( block_shape=block_shape, ) if activation == "silu": - if activation_alpha is not None: - assert swiglu_limit is not None + if gemm1_alpha is not None: + assert gemm1_limit is not None intermediate_cache2 = swiglu_with_alpha_and_limit( intermediate_cache1.view(-1, N), - activation_alpha, - swiglu_limit, + gemm1_alpha, + gemm1_limit, ) elif _is_cuda: silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) @@ -1547,10 +1544,8 @@ def fused_experts_impl( intermediate_cache2, intermediate_cache1.view(-1, N) ) elif activation == "gelu": - assert ( - activation_alpha is None - ), "activation_alpha is not supported for gelu" - assert swiglu_limit is None, "swiglu_limit is not supported for gelu" + assert gemm1_alpha is None, "gemm1_alpha is not supported for gelu" + assert gemm1_limit is None, "gemm1_limit is not supported for gelu" if _is_cuda: gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2) else: @@ -1641,12 +1636,10 @@ def fused_moe( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_output: TopKOutput, + topk_output: StandardTopKOutput, + moe_runner_config: MoeRunnerConfig = MoeRunnerConfig(), b1: Optional[torch.Tensor] = None, b2: Optional[torch.Tensor] = None, - inplace: bool = False, - activation: str = "silu", - apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -1659,10 +1652,6 @@ def fused_moe( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - activation_alpha: Optional[float] = None, - swiglu_limit: Optional[float] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -1672,11 +1661,10 @@ def fused_moe( - hidden_states (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - - topk_output (TopKOutput): The top-k output of the experts. + - topk_output (StandardTopKOutput): The top-k output of the experts. + - moe_runner_config (MoeRunnerConfig): The configuration for the MoE runner. - b1 (Optional[torch.Tensor]): Optional bias for w1. - b2 (Optional[torch.Tensor]): Optional bias for w2. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner @@ -1696,9 +1684,9 @@ def fused_moe( a2. - block_shape: (Optional[List[int]]): Optional block size for block-wise quantization. - - activation_alpha (Optional[float]): Optional alpha for the activation + - gemm1_alpha (Optional[float]): Optional gemm1_alpha for the activation function. - - swiglu_limit (Optional[float]): Optional limit for the swiglu activation + - gemm1_limit (Optional[float]): Optional gemm1_limit for the swiglu activation function. Returns: @@ -1710,11 +1698,9 @@ def fused_moe( w1, w2, topk_output, + moe_runner_config=moe_runner_config, b1=b1, b2=b2, - inplace=inplace, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, @@ -1727,8 +1713,4 @@ def fused_moe( a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, - activation_alpha=activation_alpha, - swiglu_limit=swiglu_limit, ) diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 990d88aed..46473ac4c 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -1,10 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py -import datetime -import glob import logging -import os -import sys from enum import Enum from typing import List, Optional, Tuple @@ -22,8 +18,12 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( use_symmetric_memory, ) from sglang.srt.eplb.expert_location import get_global_expert_location_metadata -from sglang.srt.layers.moe.topk import StandardTopKOutput -from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe +from sglang.srt.layers.moe import ( + MoeRunnerConfig, + get_moe_runner_backend, + should_use_flashinfer_trtllm_moe, +) +from sglang.srt.layers.moe.topk import TopKOutput, TopKOutputChecker from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -126,7 +126,6 @@ class FusedMoE(torch.nn.Module): params_dtype: Optional[torch.dtype] = None, reduce_results: bool = False, quant_config: Optional[QuantizationConfig] = None, - tp_size: Optional[int] = None, prefix: str = "", activation: str = "silu", apply_router_weight_on_input: bool = False, @@ -134,9 +133,8 @@ class FusedMoE(torch.nn.Module): inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, - enable_flashinfer_cutlass_moe: Optional[bool] = False, - activation_alpha: Optional[float] = None, - swiglu_limit: Optional[float] = None, + gemm1_alpha: Optional[float] = None, + gemm1_clamp_limit: Optional[float] = None, use_weight_loader_fused: bool = False, with_bias=False, ): @@ -153,9 +151,17 @@ class FusedMoE(torch.nn.Module): self.expert_map_cpu = None self.expert_map_gpu = None - # For activation - self.activation_alpha = activation_alpha - self.swiglu_limit = swiglu_limit + self.moe_runner_config = MoeRunnerConfig( + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + inplace=inplace, + no_combine=no_combine, + routed_scaling_factor=routed_scaling_factor, + gemm1_alpha=gemm1_alpha, + gemm1_clamp_limit=gemm1_clamp_limit, + ) + + enable_flashinfer_cutlass_moe = get_moe_runner_backend().is_flashinfer_cutlass() if enable_flashinfer_cutlass_moe and quant_config is None: logger.warning("Disable flashinfer MoE when quantization config is None.") @@ -184,20 +190,12 @@ class FusedMoE(torch.nn.Module): * self.num_local_experts ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu") - self.routed_scaling_factor = routed_scaling_factor assert intermediate_size % self.moe_tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.moe_tp_size self.reduce_results = reduce_results - self.activation = activation - self.apply_router_weight_on_input = apply_router_weight_on_input self.use_presharded_weights = use_presharded_weights - self.inplace = inplace - self.no_combine = no_combine - - self.use_triton_kernels = ( - not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"] - ) + self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel() if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod( self.use_triton_kernels @@ -207,14 +205,12 @@ class FusedMoE(torch.nn.Module): assert self.quant_method is not None self.quant_config = quant_config - self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get( - "enable_flashinfer_mxfp4_moe", False - ) + self.use_flashinfer_mxfp4_moe = get_moe_runner_backend().is_flashinfer_mxfp4() # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic if ( self.quant_config is not None and self.quant_config.get_name() == "mxfp4" - and self.use_enable_flashinfer_mxfp4_moe + and self.use_flashinfer_mxfp4_moe ): hidden_size = round_up(hidden_size, 256) self.quant_method.create_weights( @@ -794,7 +790,7 @@ class FusedMoE(torch.nn.Module): f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded." ) - def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput): + def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): origin_hidden_states_dim = hidden_states.shape[-1] assert self.quant_method is not None @@ -803,40 +799,22 @@ class FusedMoE(torch.nn.Module): # If we are in EP mode, we need to move the expert map to GPU. self.expert_map_gpu = self.expert_map_cpu.to(device="cuda") - if self.expert_map_gpu is not None and isinstance( - topk_output, StandardTopKOutput - ): - topk_output = topk_output._replace( - topk_ids=self.expert_map_gpu[topk_output.topk_ids] - ) + if self.expert_map_gpu is not None: + if TopKOutputChecker.format_is_standard(topk_output): + topk_output = topk_output._replace( + topk_ids=self.expert_map_gpu[topk_output.topk_ids] + ) + elif TopKOutputChecker.format_is_triton_kernel(topk_output): + raise NotImplementedError() # Matrix multiply. with use_symmetric_memory(get_tp_group()) as sm: - kwargs = {} - if self.activation_alpha is not None: - kwargs["activation_alpha"] = self.activation_alpha - if self.swiglu_limit is not None: - kwargs["swiglu_limit"] = self.swiglu_limit final_hidden_states = self.quant_method.apply( layer=self, x=hidden_states, topk_output=topk_output, - activation=self.activation, - apply_router_weight_on_input=self.apply_router_weight_on_input, - routed_scaling_factor=self.routed_scaling_factor, - **( - dict( - tp_rank=self.moe_tp_rank, - tp_size=self.moe_tp_size, - ep_rank=self.moe_ep_rank, - ep_size=self.moe_ep_size, - ) - if self.quant_method.__class__.__name__ - == "ModelOptNvFp4FusedMoEMethod" - else {} - ), - **kwargs, + moe_runner_config=self.moe_runner_config, ) sm.tag(final_hidden_states) @@ -944,24 +922,10 @@ class FusedMoE(torch.nn.Module): class FlashInferFusedMoE(FusedMoE): def __init__(self, *args, **kwargs): - renormalize = kwargs.pop("renormalize", True) - num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0) - use_grouped_topk = kwargs.pop("use_grouped_topk", False) - num_expert_group = kwargs.pop("num_expert_group", None) - topk_group = kwargs.pop("topk_group", None) - correction_bias = kwargs.pop("correction_bias", None) super().__init__(*args, **kwargs) - self.renormalize = renormalize - self.num_fused_shared_experts = num_fused_shared_experts - self.use_grouped_topk = use_grouped_topk - if self.use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - self.num_expert_group = num_expert_group - self.topk_group = topk_group - self.correction_bias = correction_bias self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe() - def forward(self, hidden_states: torch.Tensor, topk_output: tuple): + def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): assert self.use_flashinfer_trtllm_moe assert ( self.activation == "silu" @@ -974,20 +938,14 @@ class FlashInferFusedMoE(FusedMoE): self.num_fused_shared_experts == 0 ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe" - # TRTLLM mode expects (TopK_config, router_logits) tuple - if not isinstance(topk_output, tuple) or len(topk_output) != 2: - raise ValueError( - f"FlashInferFusedMoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}" - ) - _, router_logits = topk_output + assert TopKOutputChecker.format_is_bypassed(topk_output) # Matrix multiply. final_hidden_states = self.quant_method.apply_with_router_logits( layer=self, x=hidden_states, - router_logits=router_logits, - activation=self.activation, - routed_scaling_factor=self.routed_scaling_factor, + topk_output=topk_output, + moe_runner_config=self.moe_runner_config, ) if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1): @@ -1000,28 +958,8 @@ class FlashInferFP4MoE(FusedMoE): """FP4 TRTLLM MoE implementation using FlashInfer.""" def __init__(self, *args, **kwargs): - # Extract DeepSeek-specific parameters - renormalize = kwargs.pop("renormalize", True) - num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0) - use_grouped_topk = kwargs.pop("use_grouped_topk", False) - num_expert_group = kwargs.pop("num_expert_group", None) - topk_group = kwargs.pop("topk_group", None) - correction_bias = kwargs.pop("correction_bias", None) - - # Extract additional TopK parameters that were previously extracted in forward - routed_scaling_factor = kwargs.pop("routed_scaling_factor", None) - super().__init__(*args, **kwargs) - # Store DeepSeek parameters - self.renormalize = renormalize - self.num_fused_shared_experts = num_fused_shared_experts - self.use_grouped_topk = use_grouped_topk - self.num_expert_group = num_expert_group - self.topk_group = topk_group - self.correction_bias = correction_bias - self.routed_scaling_factor = routed_scaling_factor - # --------------------------------------------------------------------- # Helper: quantize hidden states to FP4 each forward pass # --------------------------------------------------------------------- @@ -1052,21 +990,17 @@ class FlashInferFP4MoE(FusedMoE): return hs_fp4, hs_sf - def forward(self, hidden_states: torch.Tensor, topk_output): + def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): """Forward pass using FP4 TRTLLM kernel. Args: hidden_states: Input tensor - topk_output: Should be tuple of (TopK_config, router_logits) for TRTLLM mode + topk_output: TopKOutput object with Bypassed format """ + assert TopKOutputChecker.format_is_bypassed(topk_output) - # TRTLLM mode expects (TopK_config, router_logits) tuple - if not isinstance(topk_output, tuple) or len(topk_output) != 2: - raise ValueError( - f"FlashInferFP4MoE expects (TopK_config, router_logits) tuple, got {type(topk_output)}" - ) - - _, router_logits = topk_output + router_logits = topk_output.router_logits + topk_config = topk_output.topk_config hs_fp4, hs_scale_linear = self._quantize_hidden_states_fp4(hidden_states) @@ -1074,7 +1008,7 @@ class FlashInferFP4MoE(FusedMoE): result = trtllm_fp4_block_scale_moe( routing_logits=router_logits, - routing_bias=self.correction_bias.to(hidden_states.dtype), + routing_bias=topk_config.correction_bias.to(hidden_states.dtype), hidden_states=hs_fp4, hidden_states_scale=hs_scale_linear.view(torch.float8_e4m3fn).flatten(), gemm1_weights=self.gemm1_weights_fp4_shuffled.data, @@ -1094,15 +1028,15 @@ class FlashInferFP4MoE(FusedMoE): output1_scale_gate_scalar=self.g1_alphas.data, output2_scale_scalar=self.g2_alphas.data, num_experts=self.num_experts, - top_k=self.top_k, - n_group=self.num_expert_group, - topk_group=self.topk_group, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, intermediate_size=self.intermediate_size_per_partition, local_expert_offset=self.moe_ep_rank * self.num_local_experts, local_num_experts=self.num_local_experts, - routed_scaling_factor=self.routed_scaling_factor, + routed_scaling_factor=self.moe_runner_config.routed_scaling_factor, tile_tokens_dim=_get_tile_tokens_dim( - hidden_states.shape[0], self.top_k, self.num_local_experts + hidden_states.shape[0], topk_config.top_k, self.num_local_experts ), routing_method_type=RoutingMethodType.DeepSeekV3, do_finalize=True, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py index e99dc683a..5d39b8bbc 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py @@ -18,6 +18,7 @@ from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx from triton_kernels.swiglu import swiglu_fn if TYPE_CHECKING: + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import TopKOutput @@ -55,8 +56,7 @@ def triton_kernel_moe_forward( w1: torch.Tensor, w2: torch.Tensor, topk_output: TopKOutput, - inplace: bool = False, - activation: str = "silu", + moe_runner_config: MoeRunnerConfig, apply_router_weight_on_input: bool = False, use_fp8_w8a8: bool = False, per_channel_quant: bool = False, @@ -69,7 +69,10 @@ def triton_kernel_moe_forward( block_shape: Optional[list[int]] = None, ) -> torch.Tensor: - assert topk_output.format.is_triton_kernel() + from sglang.srt.layers.moe.topk import TopKOutputChecker + + assert TopKOutputChecker.format_is_triton_kernel(topk_output) + routing_data, gather_idx, scatter_idx = topk_output return triton_kernel_fused_experts( @@ -79,8 +82,8 @@ def triton_kernel_moe_forward( routing_data, gather_idx, scatter_idx, - inplace=inplace, - activation=activation, + inplace=False, # triton kernel doesn't support inplace + activation=moe_runner_config.activation, apply_router_weight_on_input=apply_router_weight_on_input, use_fp8_w8a8=use_fp8_w8a8, per_channel_quant=per_channel_quant, @@ -192,8 +195,7 @@ def triton_kernel_moe_with_bias_forward( w2_pcg, b2: torch.Tensor, topk_output: TopKOutput, - inplace: bool = False, - activation: str = "silu", + moe_runner_config: MoeRunnerConfig, use_fp8_w8a8: bool = False, per_channel_quant: bool = False, global_num_experts: int = -1, @@ -203,10 +205,11 @@ def triton_kernel_moe_with_bias_forward( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None, - activation_alpha: Optional[float] = None, - swiglu_limit: Optional[int] = None, ) -> torch.Tensor: - assert topk_output.format.is_triton_kernel() + from sglang.srt.layers.moe.topk import TopKOutputChecker + + assert TopKOutputChecker.format_is_triton_kernel(topk_output) + routing_data, gather_idx, scatter_idx = topk_output return triton_kernel_fused_experts_with_bias( @@ -220,8 +223,8 @@ def triton_kernel_moe_with_bias_forward( routing_data=routing_data, gather_indx=gather_idx, scatter_indx=scatter_idx, - inplace=inplace, - activation=activation, + inplace=False, # triton kernel doesn't support inplace + activation=moe_runner_config.activation, use_fp8_w8a8=use_fp8_w8a8, per_channel_quant=per_channel_quant, global_num_experts=global_num_experts, @@ -231,8 +234,8 @@ def triton_kernel_moe_with_bias_forward( a1_scale=a1_scale, a2_scale=a2_scale, block_shape=block_shape, - activation_alpha=activation_alpha, - swiglu_limit=swiglu_limit, + gemm1_alpha=moe_runner_config.gemm1_alpha, + gemm1_clamp_limit=moe_runner_config.gemm1_clamp_limit, ) @@ -258,10 +261,9 @@ def triton_kernel_fused_experts_with_bias( a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, block_shape: Optional[list[int]] = None, - activation_alpha: Optional[float] = None, - swiglu_limit: Optional[int] = None, + gemm1_alpha: Optional[float] = None, + gemm1_clamp_limit: Optional[float] = None, ) -> torch.Tensor: - # print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype) assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported" assert per_channel_quant == False, "per_channel_quant is not supported" assert expert_map == None, "expert_map is not supported" @@ -307,7 +309,7 @@ def triton_kernel_fused_experts_with_bias( act = FusedActivation( FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), - (activation_alpha, swiglu_limit), + (gemm1_alpha, gemm1_clamp_limit), 2, ) diff --git a/python/sglang/srt/layers/moe/moe_runner/__init__.py b/python/sglang/srt/layers/moe/moe_runner/__init__.py new file mode 100644 index 000000000..9a7fa9c29 --- /dev/null +++ b/python/sglang/srt/layers/moe/moe_runner/__init__.py @@ -0,0 +1,3 @@ +from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig + +__all__ = ["MoeRunnerConfig"] diff --git a/python/sglang/srt/layers/moe/moe_runner/base.py b/python/sglang/srt/layers/moe/moe_runner/base.py new file mode 100644 index 000000000..854aeb0e6 --- /dev/null +++ b/python/sglang/srt/layers/moe/moe_runner/base.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class MoeRunnerConfig: + activation: str = "silu" + apply_router_weight_on_input: bool = False + inplace: bool = True + no_combine: bool = False + routed_scaling_factor: Optional[float] = None + gemm1_alpha: Optional[float] = None + gemm1_clamp_limit: Optional[float] = None diff --git a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py index 274626424..7802968ac 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/__init__.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/__init__.py @@ -2,20 +2,26 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import ( BaseDispatcher, BaseDispatcherConfig, DispatchOutput, + DispatchOutputChecker, DispatchOutputFormat, ) from sglang.srt.layers.moe.token_dispatcher.deepep import ( + AscendDeepEPLLOutput, DeepEPConfig, DeepEPDispatcher, DeepEPLLOutput, DeepEPNormalOutput, ) +from sglang.srt.layers.moe.token_dispatcher.standard import StandardDispatchOutput __all__ = [ + "AscendDeepEPLLOutput", "BaseDispatcher", "BaseDispatcherConfig", "DispatchOutput", "DispatchOutputFormat", + "DispatchOutputChecker", + "StandardDispatchOutput", "DeepEPConfig", "DeepEPDispatcher", "DeepEPNormalOutput", diff --git a/python/sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py b/python/sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py index 19661652f..d5ff8cf77 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py @@ -2,35 +2,76 @@ from __future__ import annotations from abc import ABC, abstractmethod from enum import Enum, auto -from typing import Protocol, runtime_checkable +from typing import TYPE_CHECKING, Protocol, TypeGuard, Union, runtime_checkable import torch +if TYPE_CHECKING: + from sglang.srt.layers.moe.token_dispatcher import ( + AscendDeepEPLLOutput, + DeepEPLLOutput, + DeepEPNormalOutput, + StandardDispatchOutput, + ) -class MoEA2ABackend(Enum): - none = "none" - deepep = "deepep" - def is_none(self): - return self == MoEA2ABackend.none +class DispatchOutputChecker: - def is_deepep(self): - return self == MoEA2ABackend.deepep + @staticmethod + def format_is_standard( + dispatch_output: DispatchOutput, + ) -> TypeGuard[StandardDispatchOutput]: + return dispatch_output.format.is_standard() + + @staticmethod + def format_is_deepep_normal( + dispatch_output: DispatchOutput, + ) -> TypeGuard[DeepEPNormalOutput]: + return dispatch_output.format.is_deepep_normal() + + @staticmethod + def format_is_deepep_ll( + dispatch_output: DispatchOutput, + ) -> TypeGuard[DeepEPLLOutput]: + return dispatch_output.format.is_deepep_ll() + + @staticmethod + def format_is_deepep( + dispatch_output: DispatchOutput, + ) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]: + return dispatch_output.format.is_deepep() + + @staticmethod + def format_is_ascent_ll( + dispatch_output: DispatchOutput, + ) -> TypeGuard[AscendDeepEPLLOutput]: + return dispatch_output.format.is_ascent_ll() class DispatchOutputFormat(Enum): - standard = auto() - deepep_normal = auto() - deepep_ll = auto() + + STANDARD = auto() + DEEPEP_NORMAL = auto() + DEEPEP_LL = auto() + ASCENT_LL = auto() def is_standard(self) -> bool: - return self == DispatchOutputFormat.standard + return self == DispatchOutputFormat.STANDARD def is_deepep_normal(self) -> bool: - return self == DispatchOutputFormat.deepep_normal + return self == DispatchOutputFormat.DEEPEP_NORMAL def is_deepep_ll(self) -> bool: - return self == DispatchOutputFormat.deepep_ll + return self == DispatchOutputFormat.DEEPEP_LL + + def is_deepep(self) -> bool: + return self in [ + DispatchOutputFormat.DEEPEP_NORMAL, + DispatchOutputFormat.DEEPEP_LL, + ] + + def is_ascent_ll(self) -> bool: + return self == DispatchOutputFormat.ASCENT_LL @runtime_checkable diff --git a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py index 372717bf9..3e070d814 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/deepep.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/deepep.py @@ -2,27 +2,17 @@ from __future__ import annotations import logging from dataclasses import dataclass -from typing import ( - TYPE_CHECKING, - List, - NamedTuple, - Optional, - Protocol, - Tuple, - Union, - runtime_checkable, -) +from typing import TYPE_CHECKING, List, NamedTuple, Optional, Tuple, Union from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.layers.moe import DeepEPMode, get_deepep_config, is_tbo_enabled from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import ( BaseDispatcher, BaseDispatcherConfig, DispatchOutput, DispatchOutputFormat, ) -from sglang.srt.layers.moe.utils import DeepEPMode from sglang.srt.layers.quantization import deep_gemm_wrapper -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import ( get_bool_env_var, get_int_env_var, @@ -72,7 +62,7 @@ class DeepEPNormalOutput(NamedTuple): @property def format(self) -> DispatchOutputFormat: - return DispatchOutputFormat.deepep_normal + return DispatchOutputFormat.DEEPEP_NORMAL class DeepEPLLOutput(NamedTuple): @@ -86,7 +76,7 @@ class DeepEPLLOutput(NamedTuple): @property def format(self) -> DispatchOutputFormat: - return DispatchOutputFormat.deepep_ll + return DispatchOutputFormat.DEEPEP_LL class AscendDeepEPLLOutput(NamedTuple): @@ -101,7 +91,7 @@ class AscendDeepEPLLOutput(NamedTuple): @property def format(self) -> DispatchOutputFormat: - return DispatchOutputFormat.deepep_ll + return DispatchOutputFormat.ASCENT_LL assert isinstance(DeepEPNormalOutput, DispatchOutput) @@ -128,8 +118,8 @@ class DeepEPBuffer: hidden_size: int, param_bytes: int, deepep_mode: DeepEPMode, - num_max_dispatch_tokens_per_rank: int = None, - num_experts: int = None, + num_max_dispatch_tokens_per_rank: int = -1, + num_experts: int = -1, ): if cls._buffer is not None: return cls._buffer @@ -156,8 +146,8 @@ class DeepEPBuffer: num_rdma_bytes, ) if deepep_mode.enable_low_latency(): - assert num_max_dispatch_tokens_per_rank is not None - assert num_experts is not None and num_experts % group.size() == 0 + assert num_max_dispatch_tokens_per_rank != -1 + assert num_experts != -1 and num_experts % group.size() == 0 num_rdma_bytes = max( Buffer.get_low_latency_rdma_size_hint( num_max_dispatch_tokens_per_rank, @@ -181,7 +171,7 @@ class DeepEPBuffer: ).multi_processor_count if ( (deepep_mode != DeepEPMode.LOW_LATENCY) - and not global_server_args_dict["enable_two_batch_overlap"] + and not is_tbo_enabled() and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2) ): logger.warning( @@ -226,7 +216,7 @@ class DeepEPConfig(BaseDispatcherConfig): _instance = None def __init__(self): - config_str = global_server_args_dict["deepep_config"] + config_str = get_deepep_config() if config_str: config_parsed = load_json_config(config_str) if torch.distributed.get_rank() == 0: diff --git a/python/sglang/srt/layers/moe/token_dispatcher/standard.py b/python/sglang/srt/layers/moe/token_dispatcher/standard.py index 4a2d2dd6b..3e09e0bf6 100644 --- a/python/sglang/srt/layers/moe/token_dispatcher/standard.py +++ b/python/sglang/srt/layers/moe/token_dispatcher/standard.py @@ -13,7 +13,7 @@ class StandardDispatchOutput(NamedTuple): @property def format(self) -> DispatchOutputFormat: - return DispatchOutputFormat.standard + return DispatchOutputFormat.STANDARD assert isinstance(StandardDispatchOutput, DispatchOutput) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 8830cd272..3df33898a 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -14,9 +14,18 @@ from __future__ import annotations +import logging import math +from dataclasses import dataclass from enum import Enum, auto -from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable +from typing import ( + Callable, + NamedTuple, + Optional, + Protocol, + TypeGuard, + runtime_checkable, +) import torch import torch.nn.functional as F @@ -28,7 +37,10 @@ from sglang.srt.eplb.expert_location_dispatch import ( ExpertLocationDispatchInfo, topk_ids_logical_to_physical, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.layers.moe import ( + get_moe_runner_backend, + should_use_flashinfer_trtllm_moe, +) from sglang.srt.utils import ( cpu_has_amx_support, get_bool_env_var, @@ -43,6 +55,7 @@ try: from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing except ImportError: pass +logger = logging.getLogger(__name__) _is_cuda = is_cuda() @@ -65,13 +78,48 @@ if _use_aiter: if _is_npu: import torch_npu +# -------------------------------- TopKConfig --------------------------------------- + + +@dataclass +class TopKConfig: + top_k: int + use_grouped_topk: bool = False + topk_group: int = 0 + num_expert_group: int = 0 + renormalize: bool = True + num_fused_shared_experts: int = 0 + custom_routing_function: Optional[Callable] = None + correction_bias: Optional[torch.Tensor] = None + torch_native: bool = False + routed_scaling_factor: Optional[float] = None + apply_routed_scaling_factor_on_output: bool = False + # -------------------------------- TopKOutput --------------------------------------- +class TopKOutputChecker: + + @staticmethod + def format_is_standard(topk_output: TopKOutput) -> TypeGuard[StandardTopKOutput]: + return topk_output.format.is_standard() + + @staticmethod + def format_is_triton_kernel( + topk_output: TopKOutput, + ) -> TypeGuard[TritonKernelTopKOutput]: + return topk_output.format.is_triton_kernel() + + @staticmethod + def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]: + return topk_output.format.is_bypassed() + + class TopKOutputFormat(Enum): STANDARD = auto() TRITON_KERNEL = auto() + BYPASSED = auto() def is_standard(self) -> bool: return self == TopKOutputFormat.STANDARD @@ -79,6 +127,9 @@ class TopKOutputFormat(Enum): def is_triton_kernel(self) -> bool: return self == TopKOutputFormat.TRITON_KERNEL + def is_bypassed(self) -> bool: + return self == TopKOutputFormat.BYPASSED + @runtime_checkable class TopKOutput(Protocol): @@ -114,6 +165,20 @@ class TritonKernelTopKOutput(NamedTuple): return TopKOutputFormat.TRITON_KERNEL +class BypassedTopKOutput(NamedTuple): + """Bypassed top-k output format.""" + + hidden_states: torch.Tensor + router_logits: torch.Tensor + topk_config: TopKConfig + num_token_non_padded: Optional[torch.Tensor] = None + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None + + @property + def format(self) -> TopKOutputFormat: + return TopKOutputFormat.BYPASSED + + # -------------------------------- TopK --------------------------------------- @@ -124,8 +189,8 @@ class TopK(CustomOp): top_k: int, *, use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, + topk_group: int = 0, + num_expert_group: int = 0, renormalize: bool = True, num_fused_shared_experts: int = 0, custom_routing_function: Optional[Callable] = None, @@ -136,19 +201,23 @@ class TopK(CustomOp): # NOTE: scoring_func is not used for now, but we keep it for future use # see https://github.com/sgl-project/sglang/pull/4505 for more details super().__init__() + if use_grouped_topk: assert num_expert_group is not None and topk_group is not None - self.top_k = top_k - self.use_grouped_topk = use_grouped_topk - self.renormalize = renormalize - self.topk_group = topk_group - self.num_expert_group = num_expert_group - self.num_fused_shared_experts = num_fused_shared_experts - self.custom_routing_function = custom_routing_function - self.correction_bias = correction_bias - self.routed_scaling_factor = routed_scaling_factor - self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"] + self.topk_config = TopKConfig( + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + num_fused_shared_experts=num_fused_shared_experts, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + routed_scaling_factor=routed_scaling_factor, + ) + + self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel() def forward_native( self, @@ -158,20 +227,11 @@ class TopK(CustomOp): num_token_non_padded: Optional[torch.Tensor] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, ) -> TopKOutput: - torch_native = True + self.topk_config.torch_native = True return select_experts( hidden_states=hidden_states, router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=self.use_grouped_topk, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - custom_routing_function=self.custom_routing_function, - correction_bias=self.correction_bias, - torch_native=torch_native, - routed_scaling_factor=self.routed_scaling_factor, + topk_config=self.topk_config, num_token_non_padded=num_token_non_padded, expert_location_dispatch_info=expert_location_dispatch_info, ) @@ -187,24 +247,28 @@ class TopK(CustomOp): if self.use_triton_kernels: # renormalize=True is equivalent to sm_first=False routing_data, gather_idx, scatter_idx = routing( - router_logits, self.top_k, sm_first=not self.renormalize + router_logits, + self.topk_config.top_k, + sm_first=not self.topk_config.renormalize, ) return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx) + elif ( + should_use_flashinfer_trtllm_moe() + or get_moe_runner_backend().is_flashinfer_mxfp4() + ): + return BypassedTopKOutput( + hidden_states=hidden_states, + router_logits=router_logits, + topk_config=self.topk_config, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + ) else: - torch_native = False + self.topk_config.torch_native = False return select_experts( hidden_states=hidden_states, router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=self.use_grouped_topk, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - custom_routing_function=self.custom_routing_function, - correction_bias=self.correction_bias, - torch_native=torch_native, - routed_scaling_factor=self.routed_scaling_factor, + topk_config=self.topk_config, num_token_non_padded=num_token_non_padded, expert_location_dispatch_info=expert_location_dispatch_info, ) @@ -220,15 +284,7 @@ class TopK(CustomOp): return select_experts( hidden_states=hidden_states, router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=self.use_grouped_topk, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - custom_routing_function=self.custom_routing_function, - correction_bias=self.correction_bias, - routed_scaling_factor=self.routed_scaling_factor, + topk_config=self.topk_config, num_token_non_padded=num_token_non_padded, expert_location_dispatch_info=expert_location_dispatch_info, ) @@ -244,35 +300,29 @@ class TopK(CustomOp): global_num_experts = router_logits.shape[-1] # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if global_num_experts == 256: + if global_num_experts == 256 and self.topk_config.renormalize is False: + + routed_scaling_factor = self.topk_config.routed_scaling_factor or 1 router_logits = router_logits.to(torch.float32) + return torch_npu.npu_moe_gating_top_k( router_logits, - k=self.top_k, - bias=self.correction_bias.to(torch.float32), - k_group=self.topk_group, - group_count=self.num_expert_group, + k=self.topk_config.top_k, + bias=self.topk_config.correction_bias.to(torch.float32), + k_group=self.topk_config.topk_group, + group_count=self.topk_config.num_expert_group, group_select_mode=1, renorm=0, norm_type=1, - routed_scaling_factor=1, + routed_scaling_factor=routed_scaling_factor, eps=float(1e-20), ) else: - torch_native = True + self.topk_config.torch_native = True return select_experts( hidden_states=hidden_states, router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=self.use_grouped_topk, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - custom_routing_function=self.custom_routing_function, - correction_bias=self.correction_bias, - torch_native=torch_native, - routed_scaling_factor=self.routed_scaling_factor, + topk_config=self.topk_config, num_token_non_padded=num_token_non_padded, expert_location_dispatch_info=expert_location_dispatch_info, ) @@ -670,20 +720,23 @@ else: def select_experts( hidden_states: torch.Tensor, router_logits: torch.Tensor, - top_k: int, + topk_config: TopKConfig, *, - use_grouped_topk: bool = False, - renormalize: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, - torch_native: bool = False, - routed_scaling_factor: Optional[float] = None, num_token_non_padded: Optional[torch.Tensor] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, -) -> TopKOutput: +) -> StandardTopKOutput: + + top_k = topk_config.top_k + use_grouped_topk = topk_config.use_grouped_topk + topk_group = topk_config.topk_group + num_expert_group = topk_config.num_expert_group + renormalize = topk_config.renormalize + num_fused_shared_experts = topk_config.num_fused_shared_experts + custom_routing_function = topk_config.custom_routing_function + correction_bias = topk_config.correction_bias + torch_native = topk_config.torch_native + routed_scaling_factor = topk_config.routed_scaling_factor + router_logits, correction_bias = ( expert_location_dispatch.transform_select_experts_inputs( router_logits=router_logits, diff --git a/python/sglang/srt/layers/moe/utils.py b/python/sglang/srt/layers/moe/utils.py index f08b34e40..40bd10e23 100644 --- a/python/sglang/srt/layers/moe/utils.py +++ b/python/sglang/srt/layers/moe/utils.py @@ -1,55 +1,80 @@ +from __future__ import annotations + import importlib.util from enum import Enum from functools import lru_cache +from typing import TYPE_CHECKING, Optional from packaging import version as pkg_version -from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.utils import logger - -@lru_cache(maxsize=1) -def should_use_flashinfer_trtllm_moe(): - result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and ( - not importlib.util.find_spec("flashinfer") - or pkg_version.parse(__import__("flashinfer").__version__) - >= pkg_version.parse("0.2.9rc1") - ) - return result +if TYPE_CHECKING: + from sglang.srt.server_args import ServerArgs class MoeA2ABackend(Enum): - STANDARD = ("standard", "none") + NONE = "none" DEEPEP = "deepep" @classmethod def _missing_(cls, value): if value is None: - return cls.STANDARD + return cls.NONE for member in cls: - if value in member.value: + if value == member.value: return member raise ValueError(f"No {cls.__name__} member for value {value}") + def is_none(self): + return self == MoeA2ABackend.NONE + def is_deepep(self): return self == MoeA2ABackend.DEEPEP - def is_standard(self): - return self == MoeA2ABackend.STANDARD + +class MoeRunnerBackend(Enum): + + AUTO = "auto" + TRITON = "triton" + TRITON_KERNEL = "triton_kernel" + FLASHINFER = "flashinfer_trtllm" + FLASHINFER_CUTLASS = "flashinfer_cutlass" + FLASHINFER_MXFP4 = "flashinfer_mxfp4" + + def is_auto(self): + return self == MoeRunnerBackend.AUTO + + def is_triton(self): + return self == MoeRunnerBackend.TRITON + + def is_triton_kernel(self): + return self == MoeRunnerBackend.TRITON_KERNEL + + def is_flashinfer_trtllm(self): + return self == MoeRunnerBackend.FLASHINFER + + def is_flashinfer_cutlass(self): + return self == MoeRunnerBackend.FLASHINFER_CUTLASS + + def is_flashinfer_mxfp4(self): + return self == MoeRunnerBackend.FLASHINFER_MXFP4 class DeepEPMode(Enum): + NORMAL = "normal" LOW_LATENCY = "low_latency" AUTO = "auto" - def enable_normal(self): + def enable_normal(self) -> bool: return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO] - def enable_low_latency(self): + def enable_low_latency(self) -> bool: return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO] - def resolve(self, is_extend_in_batch: bool): + def resolve(self, is_extend_in_batch: bool) -> DeepEPMode: if self != DeepEPMode.AUTO: return self @@ -57,3 +82,96 @@ class DeepEPMode(Enum): return DeepEPMode.NORMAL else: return DeepEPMode.LOW_LATENCY + + def is_normal(self) -> bool: + return self == DeepEPMode.NORMAL + + def is_low_latency(self) -> bool: + return self == DeepEPMode.LOW_LATENCY + + def is_auto(self) -> bool: + return self == DeepEPMode.AUTO + + +MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None +MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None +DEEPEP_MODE: Optional[DeepEPMode] = None +IS_TBO_ENABLED: Optional[bool] = None +TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None +DEEPEP_CONFIG: Optional[str] = None + + +def initialize_moe_config(server_args: ServerArgs): + global MOE_A2A_BACKEND + global MOE_RUNNER_BACKEND + global DEEPEP_MODE + global DEEPEP_CONFIG + global IS_TBO_ENABLED + global TBO_TOKEN_DISTRIBUTION_THRESHOLD + + MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend) + MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend) + DEEPEP_MODE = DeepEPMode(server_args.deepep_mode) + DEEPEP_CONFIG = server_args.deepep_config or "" + IS_TBO_ENABLED = server_args.enable_two_batch_overlap + TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold + + +def get_moe_a2a_backend() -> MoeA2ABackend: + global MOE_A2A_BACKEND + if MOE_A2A_BACKEND is None: + logger.warning("MOE_A2A_BACKEND is not initialized, using default backend") + MOE_A2A_BACKEND = MoeA2ABackend(None) + return MOE_A2A_BACKEND + + +def get_moe_runner_backend() -> MoeRunnerBackend: + global MOE_RUNNER_BACKEND + if MOE_RUNNER_BACKEND is None: + logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend") + MOE_RUNNER_BACKEND = MoeRunnerBackend("triton") + return MOE_RUNNER_BACKEND + + +def get_deepep_mode() -> DeepEPMode: + global DEEPEP_MODE + if DEEPEP_MODE is None: + logger.warning("DEEPEP_MODE is not initialized, using auto mode") + DEEPEP_MODE = DeepEPMode("auto") + return DEEPEP_MODE + + +def get_deepep_config() -> str: + global DEEPEP_CONFIG + if DEEPEP_CONFIG is None: + logger.warning("DEEPEP_CONFIG is not initialized, using default config") + DEEPEP_CONFIG = "" + return DEEPEP_CONFIG + + +def is_tbo_enabled() -> bool: + global IS_TBO_ENABLED + if IS_TBO_ENABLED is None: + logger.warning("IS_TBO_ENABLED is not initialized, using False") + IS_TBO_ENABLED = False + return IS_TBO_ENABLED + + +def get_tbo_token_distribution_threshold() -> float: + global TBO_TOKEN_DISTRIBUTION_THRESHOLD + if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None: + logger.warning( + "TBO_TOKEN_DISTRIBUTION_THRESHOLD is not initialized, using 0.48" + ) + TBO_TOKEN_DISTRIBUTION_THRESHOLD = 0.48 + return TBO_TOKEN_DISTRIBUTION_THRESHOLD + + +@lru_cache(maxsize=1) +def should_use_flashinfer_trtllm_moe(): + result = get_moe_runner_backend().is_flashinfer_trtllm() and ( + not importlib.util.find_spec("flashinfer") + or pkg_version.parse(__import__("flashinfer").__version__) + >= pkg_version.parse("0.2.9rc1") + ) + return result diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index 5eee0c98e..19deb7dd1 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -33,7 +33,8 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter if TYPE_CHECKING: - from sglang.srt.layers.moe.topk import TopKOutput + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig + from sglang.srt.layers.moe.topk import StandardTopKOutput from sglang.srt.utils import is_cuda, is_hip @@ -739,13 +740,12 @@ class AWQMoEMethod(FusedMoEMethodBase): self, layer: torch.nn.Module, x: torch.Tensor, - topk_output: TopKOutput, - *, - activation: str = "silu", - **kwargs, + topk_output: StandardTopKOutput, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: - - assert activation == "silu", "Only SiLU activation is supported." + assert ( + moe_runner_config.activation == "silu" + ), "Only SiLU activation is supported." # The input must currently be float16 orig_dtype = x.dtype diff --git a/python/sglang/srt/layers/quantization/base_config.py b/python/sglang/srt/layers/quantization/base_config.py index bf24c3701..ec2b4edb1 100644 --- a/python/sglang/srt/layers/quantization/base_config.py +++ b/python/sglang/srt/layers/quantization/base_config.py @@ -9,6 +9,7 @@ import torch from torch import nn if TYPE_CHECKING: + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import TopKOutput @@ -100,12 +101,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py index 62dc45ad9..a5966c4d5 100644 --- a/python/sglang/srt/layers/quantization/blockwise_int8.py +++ b/python/sglang/srt/layers/quantization/blockwise_int8.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch from torch.nn import Module @@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.utils import set_weight_attrs if TYPE_CHECKING: + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import TopKOutput ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -348,12 +349,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts @@ -363,15 +359,11 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): layer.w13_weight, layer.w2_weight, topk_output=topk_output, - inplace=inplace, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, + moe_runner_config=moe_runner_config, use_int8_w8a8=True, w1_scale=(layer.w13_weight_scale_inv), w2_scale=(layer.w2_weight_scale_inv), a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, block_shape=self.quant_config.weight_block_size, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, ) diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index c2e908f8c..c10515107 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -23,6 +23,7 @@ from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs if TYPE_CHECKING: from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( CompressedTensorsConfig, @@ -269,12 +270,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton import fused_experts @@ -283,8 +279,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): layer.w13_weight, layer.w2_weight, topk_output=topk_output, - inplace=inplace, - activation=activation, + moe_runner_config=moe_runner_config, use_fp8_w8a8=True, per_channel_quant=self.weight_quant.strategy == QuantizationStrategy.CHANNEL, @@ -292,8 +287,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - apply_router_weight_on_input=apply_router_weight_on_input, - routed_scaling_factor=routed_scaling_factor, ) @@ -601,12 +594,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - **kwargs, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: - assert activation == "silu", "Only SiLU activation is supported." + assert ( + moe_runner_config.activation == "silu" + ), "Only SiLU activation is supported." topk_weights, topk_ids, router_logits = topk_output diff --git a/python/sglang/srt/layers/quantization/fp4.py b/python/sglang/srt/layers/quantization/fp4.py index 68d463cc3..a03f200c8 100644 --- a/python/sglang/srt/layers/quantization/fp4.py +++ b/python/sglang/srt/layers/quantization/fp4.py @@ -41,6 +41,7 @@ from sglang.srt.utils import ( ) if TYPE_CHECKING: + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import TopKOutput logger = logging.getLogger(__name__) @@ -220,22 +221,10 @@ class MxFp4LinearMethod(LinearMethodBase): return out -class MxFp4MoEMethod: - def __new__(cls, *args, **kwargs): - if not hasattr(cls, "_initialized"): - original_init = cls.__init__ - new_cls = type( - cls.__name__, - (FusedMoEMethodBase,), - { - "__init__": original_init, - **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, - }, - ) - obj = super(new_cls, new_cls).__new__(new_cls) - obj.__init__(*args, **kwargs) - return obj - return super().__new__(cls) +class MxFp4MoEMethod(FusedMoEMethodBase): + + def __init__(self, quant_config: Mxfp4Config): + self.quant_config = quant_config @staticmethod def get_moe_method( @@ -364,12 +353,7 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: topk_weights, topk_ids, _ = topk_output @@ -383,7 +367,9 @@ class W4A4MXFp4MoEDynamicMethod(MxFp4MoEMethod): w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, activation=( - ActivationType.Silu if activation == "silu" else ActivationType.Gelu + ActivationType.Silu + if moe_runner_config.activation == "silu" + else ActivationType.Gelu ), doweight_stage1=False, ) @@ -497,12 +483,7 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: topk_weights, topk_ids, _ = topk_output @@ -516,7 +497,9 @@ class W4A4MXFp4MoEStaticMethod(MxFp4MoEMethod): w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, activation=( - ActivationType.Silu if activation == "silu" else ActivationType.Gelu + ActivationType.Silu + if moe_runner_config.activation == "silu" + else ActivationType.Gelu ), doweight_stage1=False, ) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 956264fc9..14ce92f36 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -79,6 +79,7 @@ from sglang.srt.utils import ( ) if TYPE_CHECKING: + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config @@ -982,12 +983,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts @@ -996,7 +992,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): topk_weights, topk_ids, _ = topk_output x, topk_weights = apply_topk_weights_cpu( - apply_router_weight_on_input, topk_weights, x + moe_runner_config.apply_router_weight_on_input, topk_weights, x ) return torch.ops.sgl_kernel.fused_experts_cpu( @@ -1021,8 +1017,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer, x, topk_output, - activation, - no_combine, + moe_runner_config.activation, + moe_runner_config.no_combine, ) if ret is not None: return ret @@ -1060,8 +1056,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): use_fp8_blockscale=True, ) # TODO: Fuse into select_experts - if routed_scaling_factor is not None: - output *= routed_scaling_factor + if moe_runner_config.routed_scaling_factor is not None: + output *= moe_runner_config.routed_scaling_factor return output # Expert fusion with FP8 quantization return fused_experts( @@ -1069,9 +1065,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): layer.w13_weight, layer.w2_weight, topk_output=topk_output, - inplace=inplace and not no_combine, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, + moe_runner_config=moe_runner_config, use_fp8_w8a8=True, w1_scale=( layer.w13_weight_scale_inv @@ -1084,26 +1078,32 @@ class Fp8MoEMethod(FusedMoEMethodBase): a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, block_shape=self.quant_config.weight_block_size, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, ) def apply_with_router_logits( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - *, - activation: str = "silu", - routed_scaling_factor: Optional[float] = None, + topk_output: TopKOutput, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: + + activation = moe_runner_config.activation + routed_scaling_factor = moe_runner_config.routed_scaling_factor + + from flashinfer.fused_moe import trtllm_fp8_block_scale_moe + + from sglang.srt.layers.moe.topk import TopKOutputChecker + + assert TopKOutputChecker.format_is_bypassed(topk_output) + router_logits = topk_output.router_logits + topk_config = topk_output.topk_config assert ( activation == "silu" ), "Only silu is supported for flashinfer blockscale fp8 moe" a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1]) # NOTE: scales of hidden states have to be transposed! a_sf_t = a_sf.t().contiguous() - from flashinfer.fused_moe import trtllm_fp8_block_scale_moe return trtllm_fp8_block_scale_moe( routing_logits=router_logits.to(torch.float32), @@ -1115,9 +1115,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): gemm2_weights=layer.w2_weight, gemm2_weights_scale=layer.w2_weight_scale_inv, num_experts=layer.num_experts, - top_k=layer.top_k, - n_group=layer.num_expert_group, - topk_group=layer.topk_group, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, intermediate_size=layer.w2_weight.shape[2], local_expert_offset=layer.moe_ep_rank * layer.num_local_experts, local_num_experts=layer.num_local_experts, diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index f97a574c8..259d0098b 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -113,6 +113,7 @@ def normalize_e4m3fn_to_e4m3fnuz( return weight, weight_scale, input_scale +# TODO(ch-wan): define these backends in --moe-runner-backend def cutlass_block_fp8_supported() -> bool: if not get_bool_env_var("SGLANG_SUPPORT_CUTLASS_BLOCK_FP8"): return False diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py index 967fd0550..c770708b0 100644 --- a/python/sglang/srt/layers/quantization/gptq.py +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -44,6 +44,7 @@ from sglang.srt.layers.quantization.utils import ( ) if TYPE_CHECKING: + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.utils import is_cuda @@ -1056,13 +1057,13 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - **kwargs, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: # Delay the import to avoid circular dependency - assert activation == "silu", "Only SiLU activation is supported." + assert ( + moe_runner_config.activation == "silu" + ), "Only SiLU activation is supported." # The input must currently be float16 orig_dtype = x.dtype diff --git a/python/sglang/srt/layers/quantization/marlin_utils.py b/python/sglang/srt/layers/quantization/marlin_utils.py index 28f232332..d76b900ae 100644 --- a/python/sglang/srt/layers/quantization/marlin_utils.py +++ b/python/sglang/srt/layers/quantization/marlin_utils.py @@ -28,6 +28,7 @@ from sglang.srt.utils import get_device_capability, is_cuda if TYPE_CHECKING: from sglang.srt.layers.linear import LinearBase + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE try: from vllm import _custom_ops as ops @@ -216,13 +217,13 @@ def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: )[0] -def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: +def check_moe_marlin_supports_layer(layer: FusedMoE, group_size: int) -> bool: hidden_size = layer.hidden_size intermediate_size_per_partition = layer.intermediate_size_per_partition # apply_router_weight_on_input is not supported for moe marlin - supports_router_weight = not layer.apply_router_weight_on_input + supports_router_weight = not layer.moe_runner_config.apply_router_weight_on_input # moe marlin requires the activation to be silu - supports_activation = layer.activation == "silu" + supports_activation = layer.moe_runner_config.activation == "silu" # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size) # down: (n, k) = (hidden_size, intermediate_size_per_partition) diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 103f675d2..a77d504a2 100755 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -7,8 +7,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter +from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType -from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, @@ -30,10 +30,11 @@ from sglang.srt.layers.quantization.utils import ( requantize_with_max_scale, ) from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import is_cuda, next_power_of_2 if TYPE_CHECKING: + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import TopKOutput if is_cuda(): @@ -422,12 +423,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts @@ -436,15 +432,13 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): layer.w13_weight, layer.w2_weight, topk_output=topk_output, - inplace=inplace, - activation=activation, + moe_runner_config=moe_runner_config, use_fp8_w8a8=True, per_channel_quant=False, # ModelOpt uses per-tensor quantization w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - no_combine=no_combine, ) @@ -741,8 +735,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): @property def enable_flashinfer_cutlass_moe(self) -> bool: + from sglang.srt.layers.moe import get_moe_runner_backend + """Access the global enable_flashinfer_cutlass_moe setting.""" - return global_server_args_dict.get("enable_flashinfer_cutlass_moe", False) + return get_moe_runner_backend().is_flashinfer_cutlass() def create_weights( self, @@ -1160,21 +1156,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): def apply( self, - layer: torch.nn.Module, + layer: FusedMoE, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - ep_rank: Optional[int] = None, - ep_size: Optional[int] = None, - tp_rank: Optional[int] = None, - tp_size: Optional[int] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: - assert activation == "silu", "Only SiLU activation is supported." + assert ( + moe_runner_config.activation == "silu" + ), "Only SiLU activation is supported." # Check if this is a FlashInferFP4MoE layer that should handle its own forward if hasattr(layer, "gemm1_weights_fp4_shuffled"): @@ -1183,7 +1172,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): if self.enable_flashinfer_cutlass_moe: assert ( - not apply_router_weight_on_input + not moe_runner_config.apply_router_weight_on_input ), "apply_router_weight_on_input is not supported for Flashinfer" # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision # and fp4 quantized weights loaded from the checkpoint @@ -1205,14 +1194,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): layer.w2_blockscale_swizzled.view(torch.int32), layer.g2_alphas, ], - ep_size=ep_size, - ep_rank=ep_rank, - tp_size=tp_size, - tp_rank=tp_rank, + ep_size=layer.moe_ep_size, + ep_rank=layer.moe_ep_rank, + tp_size=layer.moe_tp_size, + tp_rank=layer.moe_tp_rank, tune_max_num_tokens=next_power_of_2(x.shape[0]), )[0] - if routed_scaling_factor is not None: - output *= routed_scaling_factor + if moe_runner_config.routed_scaling_factor is not None: + output *= moe_runner_config.routed_scaling_factor return output from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 @@ -1231,8 +1220,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): topk_weights=topk_weights, topk_ids=topk_ids, params=layer.cutlass_moe_params, - apply_router_weight_on_input=apply_router_weight_on_input, + apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input, ).to(x.dtype) - if routed_scaling_factor is not None: - output *= routed_scaling_factor + if moe_runner_config.routed_scaling_factor is not None: + output *= moe_runner_config.routed_scaling_factor return output diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py b/python/sglang/srt/layers/quantization/moe_wna16.py index fbbf11066..7f2c78cbb 100644 --- a/python/sglang/srt/layers/quantization/moe_wna16.py +++ b/python/sglang/srt/layers/quantization/moe_wna16.py @@ -22,6 +22,7 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs logger = logging.getLogger(__name__) if TYPE_CHECKING: + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import TopKOutput @@ -353,17 +354,14 @@ class MoeWNA16Method(FusedMoEMethodBase): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: # avoid circular import from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - assert activation == "silu", "Only SiLU activation is supported." + assert ( + moe_runner_config.activation == "silu" + ), "Only SiLU activation is supported." weight_bits = self.quant_config.weight_bits has_zp = self.quant_config.has_zp @@ -373,8 +371,7 @@ class MoeWNA16Method(FusedMoEMethodBase): layer.w13_qweight, layer.w2_qweight, topk_output=topk_output, - inplace=inplace, - apply_router_weight_on_input=apply_router_weight_on_input, + moe_runner_config=moe_runner_config, use_int4_w4a16=weight_bits == 4, use_int8_w8a16=weight_bits == 8, w1_scale=layer.w13_scales, @@ -382,8 +379,6 @@ class MoeWNA16Method(FusedMoEMethodBase): w1_zp=layer.w13_qzeros if has_zp else None, w2_zp=layer.w2_qzeros if has_zp else None, block_shape=[0, layer.group_size], - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, ) @staticmethod @@ -486,16 +481,16 @@ class MoeWNA16Method(FusedMoEMethodBase): ) if "w13_qzeros" in weight_name: - tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[ - tp_rank - ] + tensor = loaded_weight.view( + layer.moe_tp_size, -1, loaded_weight.size(1) + )[tp_rank] if shard_id == "w1": param.data[expert_id, : shard_size // 2] = tensor else: param.data[expert_id, shard_size // 2 :] = tensor elif "w2_qzeros" in weight_name: param.data[expert_id] = loaded_weight.view( - loaded_weight.size(0), layer.tp_size, -1 + loaded_weight.size(0), layer.moe_tp_size, -1 )[:, tp_rank] else: weight_loader(param, loaded_weight, weight_name, shard_id, expert_id) diff --git a/python/sglang/srt/layers/quantization/mxfp4.py b/python/sglang/srt/layers/quantization/mxfp4.py index 46db5f03f..5eaa21d1e 100644 --- a/python/sglang/srt/layers/quantization/mxfp4.py +++ b/python/sglang/srt/layers/quantization/mxfp4.py @@ -16,14 +16,13 @@ from __future__ import annotations -import importlib.util import logging from typing import TYPE_CHECKING, List, Optional import torch -import triton.language as tl from torch.nn.parameter import Parameter +from sglang.srt.layers.moe.utils import get_moe_runner_backend from sglang.srt.layers.quantization.base_config import ( FusedMoEMethodBase, QuantizationConfig, @@ -31,7 +30,6 @@ from sglang.srt.layers.quantization.base_config import ( ) from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.layers.utils import is_sm100_supported -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.utils import ( direct_register_custom_op, get_bool_env_var, @@ -60,6 +58,7 @@ if is_flashinfer_available(): logger = logging.getLogger(__name__) if TYPE_CHECKING: + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import TopKOutput OCP_MX_BLOCK_SIZE = 32 @@ -218,15 +217,13 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): self, prefix: str, ): - from sglang.srt.managers.schedule_batch import global_server_args_dict - super().__init__() self.prefix = prefix self.topk_indices_dtype = None - self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"] + self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel() self.with_bias = False - self.use_flashinfer = global_server_args_dict["enable_flashinfer_mxfp4_moe"] + self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4() self.triton_kernel_moe_forward = None self.triton_kernel_moe_with_bias_forward = None @@ -348,6 +345,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): logger, f"Shuffling MoE weights for FlashInfer MXFP4 moe kernel (layer: {self.prefix}), it might take a while...", ) + # TODO: these values are hardcoded for now, we need to get them from the model layer.gemm1_alpha = Parameter( torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(), requires_grad=False, @@ -573,14 +571,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - activation_alpha: Optional[float] = None, - swiglu_limit: Optional[float] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: if self.use_flashinfer: # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance @@ -637,9 +628,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): b1=layer.w13_weight_bias, b2=layer.w2_weight_bias, topk_output=topk_output, - activation=activation, - activation_alpha=activation_alpha, - swiglu_limit=swiglu_limit, + moe_runner_config=moe_runner_config, ) else: return self.triton_kernel_moe_forward( @@ -647,6 +636,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): w1=layer.w13_weight, w2=layer.w2_weight, topk_output=topk_output, + moe_runner_config=moe_runner_config, ) else: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts @@ -656,13 +646,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): w1=layer.w13_weight, w2=layer.w2_weight, topk_output=topk_output, + moe_runner_config=moe_runner_config, b1=layer.w13_weight_bias, b2=layer.w2_weight_bias, - inplace=inplace, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, - activation_alpha=activation_alpha, - swiglu_limit=swiglu_limit, ) diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 9c33e3173..67d3ce327 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -1,7 +1,7 @@ from __future__ import annotations -import importlib -from typing import TYPE_CHECKING, Callable, List, Optional +import importlib.util +from typing import TYPE_CHECKING, List, Optional import torch import torch.nn.functional as F @@ -24,7 +24,7 @@ from sglang.srt.utils import ( ) if TYPE_CHECKING: - from sglang.srt.layers.moe.ep_moe.layer import EPMoE + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import TopKOutput has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None @@ -221,31 +221,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - activation_alpha: Optional[float] = None, - swiglu_limit: Optional[float] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: - kwargs = {} - if activation_alpha is not None: - kwargs["activation_alpha"] = activation_alpha - if swiglu_limit is not None: - kwargs["swiglu_limit"] = swiglu_limit return self.forward( x=x, layer=layer, topk_output=topk_output, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - inplace=inplace, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, - **kwargs, + moe_runner_config=moe_runner_config, ) def forward_cuda( @@ -253,18 +236,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, - activation_alpha: Optional[float] = None, - swiglu_limit: Optional[float] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: if self.use_triton_kernels: if self.with_bias: + assert self.triton_kernel_moe_with_bias_forward is not None return self.triton_kernel_moe_with_bias_forward( hidden_states=x, w1=layer.w13_weight, @@ -272,24 +249,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): b1=layer.w13_weight_bias, b2=layer.w2_weight_bias, topk_output=topk_output, - activation=activation, - activation_alpha=activation_alpha, - swiglu_limit=swiglu_limit, + moe_runner_config=moe_runner_config, w1_pcg=None, w2_pcg=None, ) else: + assert self.triton_kernel_moe_forward is not None return self.triton_kernel_moe_forward( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_output=topk_output, + moe_runner_config=moe_runner_config, ) else: if _use_aiter: - assert not no_combine, "unsupported" + assert not moe_runner_config.no_combine, "unsupported" topk_weights, topk_ids, _ = topk_output - if apply_router_weight_on_input: + if moe_runner_config.apply_router_weight_on_input: assert ( topk_weights.dim() == 2 ), "`topk_weights` should be in shape (num_tokens, topk)" @@ -309,7 +286,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): topk_ids, activation=( ActivationType.Silu - if activation == "silu" + if moe_runner_config.activation == "silu" else ActivationType.Gelu ), ) @@ -325,13 +302,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): b1=getattr(layer, "w13_weight_bias", None), b2=getattr(layer, "w2_weight_bias", None), topk_output=topk_output, - inplace=inplace and not no_combine, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, - activation_alpha=activation_alpha, - swiglu_limit=swiglu_limit, + moe_runner_config=moe_runner_config, ) def forward_cpu( @@ -339,21 +310,21 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: - assert activation == "silu", f"activation = {activation} is not supported." + assert ( + moe_runner_config.activation == "silu" + ), f"activation = {moe_runner_config.activation} is not supported." - if use_intel_amx_backend(layer) and not apply_router_weight_on_input: + if ( + use_intel_amx_backend(layer) + and not moe_runner_config.apply_router_weight_on_input + ): from sglang.srt.layers.moe.topk import apply_topk_weights_cpu topk_weights, topk_ids, _ = topk_output x, topk_weights = apply_topk_weights_cpu( - apply_router_weight_on_input, topk_weights, x + moe_runner_config.apply_router_weight_on_input, topk_weights, x ) return torch.ops.sgl_kernel.fused_experts_cpu( x, @@ -378,11 +349,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer, x, topk_output, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - inplace=inplace, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, + moe_runner_config, ) def forward_npu( @@ -390,12 +357,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_native import moe_forward_native @@ -403,11 +365,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): layer, x, topk_output, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - inplace=inplace, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, + moe_runner_config, ) def forward_tpu(self, *args, **kwargs) -> torch.Tensor: diff --git a/python/sglang/srt/layers/quantization/w4afp8.py b/python/sglang/srt/layers/quantization/w4afp8.py index 7a471870a..9be54d05a 100644 --- a/python/sglang/srt/layers/quantization/w4afp8.py +++ b/python/sglang/srt/layers/quantization/w4afp8.py @@ -18,7 +18,9 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.utils import set_weight_attrs if TYPE_CHECKING: - from sglang.srt.layers.moe.ep_moe.layer import EPMoE, TopKOutput + from sglang.srt.layers.moe import MoeRunnerConfig + from sglang.srt.layers.moe.ep_moe.layer import EPMoE + from sglang.srt.layers.moe.topk import StandardTopKOutput ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -280,11 +282,8 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): self, layer: EPMoE, x: torch.Tensor, - topk_output: TopKOutput, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - routed_scaling_factor: Optional[float] = None, - **kwargs, + topk_output: StandardTopKOutput, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: # TODO(ch-wan): move it out of this class @@ -324,6 +323,6 @@ class W4AFp8MoEMethod(FusedMoEMethodBase): layer.w13_input_scale, layer.w2_input_scale, ) - if routed_scaling_factor is not None: - output *= routed_scaling_factor + if moe_runner_config.routed_scaling_factor is not None: + output *= moe_runner_config.routed_scaling_factor return output diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py index e486fef0b..5e1aa41a6 100644 --- a/python/sglang/srt/layers/quantization/w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -26,7 +26,8 @@ from sglang.srt.layers.quantization.fp8_utils import ( from sglang.srt.utils import set_weight_attrs if TYPE_CHECKING: - from sglang.srt.layers.moe.topk import TopKOutput + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig + from sglang.srt.layers.moe.topk import StandardTopKOutput _is_fp8_fnuz = is_fp8_fnuz() @@ -269,13 +270,8 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase): self, layer: torch.nn.Module, x: torch.Tensor, - topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, + topk_output: StandardTopKOutput, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts @@ -284,15 +280,11 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase): layer.w13_weight, layer.w2_weight, topk_output=topk_output, - inplace=inplace, - apply_router_weight_on_input=apply_router_weight_on_input, - activation=activation, + moe_runner_config=moe_runner_config, use_fp8_w8a8=True, per_channel_quant=True, w1_scale=(layer.w13_weight_scale), w2_scale=(layer.w2_weight_scale), a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, ) diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 843fffe7b..abcf334e0 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -49,6 +49,7 @@ from sglang.srt.utils import ( ) if TYPE_CHECKING: + from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import TopKOutput _is_cuda = is_cuda() @@ -487,12 +488,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, - *, - activation: str = "silu", - apply_router_weight_on_input: bool = False, - inplace: bool = True, - no_combine: bool = False, - routed_scaling_factor: Optional[float] = None, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts @@ -501,7 +497,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): topk_weights, topk_ids, _ = topk_output x, topk_weights = apply_topk_weights_cpu( - apply_router_weight_on_input, topk_weights, x + moe_runner_config.apply_router_weight_on_input, topk_weights, x ) return torch.ops.sgl_kernel.fused_experts_cpu( x, @@ -525,17 +521,13 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): layer.w13_weight, layer.w2_weight, topk_output=topk_output, - inplace=inplace, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, + moe_runner_config=moe_runner_config, use_int8_w8a8=True, per_channel_quant=True, w1_scale=(layer.w13_weight_scale), w2_scale=(layer.w2_weight_scale), a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, - no_combine=no_combine, - routed_scaling_factor=routed_scaling_factor, ) @@ -982,7 +974,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): layer, x, topk_output: TopKOutput, - **kwargs, + moe_runner_config: MoeRunnerConfig, ) -> torch.Tensor: topk_weights, topk_ids, _ = topk_output diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 8b1b11bdf..770fd8cee 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -52,6 +52,7 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import ( ScheduleBatchDisaggregationDecodeMixin, ) from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank +from sglang.srt.layers.moe import is_tbo_enabled from sglang.srt.mem_cache.allocator import ( BaseTokenToKVPoolAllocator, SWATokenToKVPoolAllocator, @@ -84,17 +85,10 @@ GLOBAL_SERVER_ARGS_KEYS = [ "device", "disable_chunked_prefix_cache", "disable_radix_cache", - "enable_two_batch_overlap", - "tbo_token_distribution_threshold", "enable_dp_lm_head", - "moe_a2a_backend", - "deepep_mode", - "enable_flashinfer_cutlass_moe", - "enable_flashinfer_trtllm_moe", "enable_flashinfer_allreduce_fusion", "moe_dense_tp_size", "ep_dispatch_algorithm", - "deepep_config", "ep_num_redundant_experts", "enable_nan_detection", "flashinfer_mla_disable_ragged", @@ -107,8 +101,6 @@ GLOBAL_SERVER_ARGS_KEYS = [ "triton_attention_reduce_in_fp32", "num_reserved_decode_tokens", "weight_loader_disable_mmap", - "enable_triton_kernel_moe", - "enable_flashinfer_mxfp4_moe", "enable_multimodal", "enable_symm_mem", "quantization", diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 95a529c89..04e6f13b0 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -64,7 +64,7 @@ from sglang.srt.hf_transformers_utils import ( ) from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend +from sglang.srt.layers.moe import initialize_moe_config from sglang.srt.managers.io_struct import ( AbortReq, CloseSessionReqInput, @@ -245,6 +245,9 @@ class Scheduler( ) ) + # Init model config + self.model_config = ModelConfig.from_server_args(server_args) + # Init inter-process communication context = zmq.Context(2) self.idle_sleeper = None @@ -292,6 +295,9 @@ class Scheduler( # Init tokenizer self.init_tokenizer() + # Init moe config + self.init_moe_config() + # Set reasoning_parser and think_end_id if --reasoning_parser is enabled if self.server_args.reasoning_parser and self.tokenizer: reasoning_parser = ReasoningParser( @@ -538,8 +544,6 @@ class Scheduler( def init_tokenizer(self): server_args = self.server_args - - self.model_config = ModelConfig.from_server_args(server_args) self.is_generation = self.model_config.is_generation if server_args.skip_tokenizer_init: @@ -761,6 +765,10 @@ class Scheduler( # The prefill requests that are in the middle of kv sending self.disagg_prefill_inflight_queue: List[Req] = [] + def init_moe_config(self): + if hasattr(self.model_config.hf_config, "num_experts_per_tok"): + initialize_moe_config(self.server_args) + @DynamicGradMode() def event_loop_normal(self): """A normal scheduler loop.""" @@ -1823,11 +1831,6 @@ class Scheduler( disable_cuda_graph=self.server_args.disable_cuda_graph, spec_algorithm=self.spec_algorithm, speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens, - enable_two_batch_overlap=self.server_args.enable_two_batch_overlap, - enable_deepep_moe=MoeA2ABackend( - self.server_args.moe_a2a_backend - ).is_deepep(), - deepep_mode=DeepEPMode(self.server_args.deepep_mode), require_mlp_tp_gather=require_mlp_tp_gather(self.server_args), disable_overlap_schedule=self.server_args.disable_overlap_schedule, ) @@ -1922,9 +1925,6 @@ class Scheduler( disable_cuda_graph: bool, spec_algorithm, speculative_num_draft_tokens, - enable_two_batch_overlap: bool, - enable_deepep_moe: bool, - deepep_mode: DeepEPMode, require_mlp_tp_gather: bool, disable_overlap_schedule: bool, ): @@ -1972,9 +1972,6 @@ class Scheduler( is_extend_in_batch, *tbo_preparer.prepare_all_gather( local_batch, - deepep_mode, - enable_deepep_moe, - enable_two_batch_overlap, ), ], dtype=torch.int64, diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 5222bff0a..717a16ef0 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -60,7 +60,6 @@ from sglang.srt.layers.dp_attention import ( initialize_dp_attention, ) from sglang.srt.layers.logits_processor import LogitsProcessorOutput -from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend from sglang.srt.layers.quantization import ( deep_gemm_wrapper, monkey_patch_isinstance_for_vllm_base_layer, @@ -219,8 +218,6 @@ class ModelRunner: # TODO it is indeed not a "server args" "use_mla_backend": self.use_mla_backend, "speculative_algorithm": self.spec_algorithm, - "moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend), - "deepep_mode": DeepEPMode(server_args.deepep_mode), } ) diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 15cef015c..74de384b3 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -32,7 +32,9 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.fused_moe_triton import fused_moe +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -104,6 +106,11 @@ class DbrxExperts(nn.Module): self.params_dtype = params_dtype self.router = DbrxRouter(config, self.params_dtype) + self.topk = TopK( + self.top_k, + renormalize=True, + ) + self.moe_runner_config = MoeRunnerConfig(inplace=True) self.ws = nn.Parameter( torch.empty( self.num_total_experts, @@ -169,14 +176,13 @@ class DbrxExperts(nn.Module): hidden_states = hidden_states.view(-1, self.d_model) # router_logits: (num_tokens, n_experts) router_logits = self.router(hidden_states) + topk_output = self.topk(hidden_states, router_logits) final_hidden_states = fused_moe( hidden_states, self.ws, self.w2s, - router_logits, - self.top_k, - renormalize=True, - inplace=True, + topk_output, + self.moe_runner_config, ) if self.tp_size > 1: @@ -293,7 +299,7 @@ class DbrxFusedNormAttention(nn.Module): position_ids: torch.Tensor, hidden_states: torch.Tensor, forward_batch: ForwardBatch, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: residual = hidden_states hidden_states = self.norm_1(hidden_states) x = self.attn( diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index f2f0d0344..ef431e00d 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -37,6 +37,7 @@ from sglang.srt.layers.linear import ( ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton import fused_moe +from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention @@ -180,7 +181,7 @@ class DeepseekMoE(nn.Module): w1=self.w1, w2=self.w2, topk_output=topk_output, - inplace=True, + moe_runner_config=MoeRunnerConfig(inplace=True), ) if self.config.n_shared_experts is not None: diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 90efc4067..8d51d7823 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -50,7 +50,6 @@ from sglang.srt.layers.communicator import ( from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, get_attention_tp_size, - get_local_attention_dp_size, is_dp_attention_enabled, ) from sglang.srt.layers.layernorm import RMSNorm @@ -61,9 +60,10 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopK -from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_kernel import ( @@ -336,30 +336,6 @@ class DeepseekV2MoE(nn.Module): quant_config=quant_config, routed_scaling_factor=self.routed_scaling_factor, prefix=add_prefix("experts", prefix), - **( - dict(deepep_mode=global_server_args_dict["deepep_mode"]) - if global_server_args_dict["moe_a2a_backend"].is_deepep() - else {} - ), - # Additional args for FusedMoE - **( - dict( - enable_flashinfer_cutlass_moe=True, - ) - if global_server_args_dict["enable_flashinfer_cutlass_moe"] - else {} - ), - **( - dict( - renormalize=config.norm_topk_prob, - use_grouped_topk=True, - num_expert_group=config.n_group, - topk_group=config.topk_group, - correction_bias=self.gate.e_score_correction_bias, - ) - if should_use_flashinfer_trtllm_moe() - else {} - ), ) self.shared_experts_is_int8 = False @@ -377,7 +353,7 @@ class DeepseekV2MoE(nn.Module): prefix=add_prefix("shared_experts", prefix), **( dict(tp_rank=0, tp_size=1) - if global_server_args_dict["moe_a2a_backend"].is_deepep() + if get_moe_a2a_backend().is_deepep() else {} ), ) @@ -407,7 +383,7 @@ class DeepseekV2MoE(nn.Module): self.top_k = config.num_experts_per_tok - if global_server_args_dict["moe_a2a_backend"].is_deepep(): + if get_moe_a2a_backend().is_deepep(): # TODO: we will support tp < ep in the future self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( @@ -431,12 +407,12 @@ class DeepseekV2MoE(nn.Module): num_local_experts=config.n_routed_experts // self.tp_size, hidden_size=config.hidden_size, params_dtype=config.torch_dtype, - deepep_mode=global_server_args_dict["deepep_mode"], + deepep_mode=get_deepep_mode(), async_finish=True, return_recv_hook=True, ) - self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep() + self._enable_deepep_moe = get_moe_a2a_backend().is_deepep() def get_moe_weights(self): return [ @@ -484,13 +460,7 @@ class DeepseekV2MoE(nn.Module): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) kwargs = {"hidden_states": hidden_states} - - # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple - # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput - if should_use_flashinfer_trtllm_moe(): - kwargs["topk_output"] = (self.topk, router_logits) - else: - kwargs["topk_output"] = self.topk(hidden_states, router_logits) + kwargs["topk_output"] = self.topk(hidden_states, router_logits) final_hidden_states = self.experts(**kwargs) if not _is_cuda: @@ -520,13 +490,7 @@ class DeepseekV2MoE(nn.Module): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) kwargs = {"hidden_states": hidden_states} - - # FlashInferFP4MoE (TRTLLM path) expects (TopK, router_logits) tuple - # Regular FusedMoE (CUTLASS path) expects StandardTopKOutput - if should_use_flashinfer_trtllm_moe(): - kwargs["topk_output"] = (self.topk, router_logits) - else: - kwargs["topk_output"] = self.topk(hidden_states, router_logits) + kwargs["topk_output"] = self.topk(hidden_states, router_logits) final_hidden_states = self.experts(**kwargs) if not _is_cuda and not _use_aiter: @@ -2478,17 +2442,15 @@ class DeepseekV2ForCausalLM(nn.Module): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts + self.num_fused_shared_experts, ) if self.quant_config and self.quant_config.get_name() == "w4afp8": - expert_params_mapping += ( - get_moe_impl_class().make_expert_input_scale_params_mapping( - num_experts=self.config.n_routed_experts - ) + expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping( + num_experts=self.config.n_routed_experts ) # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None diff --git a/python/sglang/srt/models/ernie4.py b/python/sglang/srt/models/ernie4.py index 6cd41f399..78a7b4b94 100644 --- a/python/sglang/srt/models/ernie4.py +++ b/python/sglang/srt/models/ernie4.py @@ -31,13 +31,13 @@ from sglang.srt.layers.communicator import enable_moe_dense_fully_dp from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.deepseek_v2 import DeepseekV2MLP as Ernie4MLP @@ -361,7 +361,7 @@ class Ernie4_5_ForCausalLM(nn.Module): class Ernie4_5_MoeForCausalLM(Ernie4_5_ForCausalLM): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/python/sglang/srt/models/glm4_moe.py b/python/sglang/srt/models/glm4_moe.py index e0f0b373d..6e4b16e78 100644 --- a/python/sglang/srt/models/glm4_moe.py +++ b/python/sglang/srt/models/glm4_moe.py @@ -39,7 +39,6 @@ from sglang.srt.layers.communicator import ( from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, get_attention_tp_size, - get_local_attention_dp_size, is_dp_attention_enabled, ) from sglang.srt.layers.layernorm import RMSNorm @@ -51,9 +50,10 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopK -from sglang.srt.layers.moe.utils import should_use_flashinfer_trtllm_moe from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_kernel import ( is_fp8_fnuz, @@ -76,10 +76,7 @@ from sglang.srt.models.deepseek_v2 import ( DeepseekV2Model, DeepseekV2MoE, ) -from sglang.srt.two_batch_overlap import ( - MaybeTboDeepEPDispatcher, - model_forward_maybe_tbo, -) +from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher from sglang.srt.utils import ( BumpAllocator, LazyValue, @@ -414,19 +411,15 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn ) - self.topk = ( - TopK( - top_k=config.num_experts_per_tok + self.num_fused_shared_experts, - renormalize=config.norm_topk_prob, - use_grouped_topk=True, - num_expert_group=config.n_group, - num_fused_shared_experts=self.num_fused_shared_experts, - topk_group=config.topk_group, - correction_bias=self.gate.e_score_correction_bias, - routed_scaling_factor=self.routed_scaling_factor, - ) - if not should_use_flashinfer_trtllm_moe() - else None + self.topk = TopK( + top_k=config.num_experts_per_tok + self.num_fused_shared_experts, + renormalize=config.norm_topk_prob, + use_grouped_topk=True, + num_expert_group=config.n_group, + num_fused_shared_experts=self.num_fused_shared_experts, + topk_group=config.topk_group, + correction_bias=self.gate.e_score_correction_bias, + routed_scaling_factor=self.routed_scaling_factor, ) self.experts = get_moe_impl_class()( @@ -441,31 +434,6 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): quant_config=quant_config, routed_scaling_factor=self.routed_scaling_factor, prefix=add_prefix("experts", prefix), - **( - dict(deepep_mode=global_server_args_dict["deepep_mode"]) - if global_server_args_dict["moe_a2a_backend"].is_deepep() - else {} - ), - # Additional args for FusedMoE - **( - dict( - enable_flashinfer_cutlass_moe=True, - ) - if global_server_args_dict["enable_flashinfer_cutlass_moe"] - else {} - ), - **( - dict( - renormalize=config.norm_topk_prob, - use_grouped_topk=True, - num_expert_group=config.n_group, - num_fused_shared_experts=self.num_fused_shared_experts, - topk_group=config.topk_group, - correction_bias=self.gate.e_score_correction_bias, - ) - if should_use_flashinfer_trtllm_moe() - else {} - ), ) self.shared_experts_is_int8 = False @@ -496,7 +464,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): self.top_k = config.num_experts_per_tok - if global_server_args_dict["moe_a2a_backend"].is_deepep(): + if get_moe_a2a_backend().is_deepep(): # TODO: we will support tp < ep in the future self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( @@ -520,12 +488,12 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): num_local_experts=config.n_routed_experts // self.tp_size, hidden_size=config.hidden_size, params_dtype=config.torch_dtype, - deepep_mode=global_server_args_dict["deepep_mode"], + deepep_mode=get_deepep_mode(), async_finish=True, return_recv_hook=True, ) - self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep() + self._enable_deepep_moe = get_moe_a2a_backend().is_deepep() def forward_normal_dual_stream( self, @@ -542,10 +510,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) kwargs = {"hidden_states": hidden_states} - if self.topk is not None: - kwargs["topk_output"] = self.topk(hidden_states, router_logits) - else: - kwargs["router_logits"] = router_logits + kwargs["topk_output"] = self.topk(hidden_states, router_logits) final_hidden_states = self.experts(**kwargs) if not _is_cuda: final_hidden_states *= self.routed_scaling_factor @@ -588,10 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) kwargs = {"hidden_states": hidden_states} - if self.topk is not None: - kwargs["topk_output"] = self.topk(hidden_states, router_logits) - else: - kwargs["router_logits"] = router_logits + kwargs["topk_output"] = self.topk(hidden_states, router_logits) final_hidden_states = self.experts(**kwargs) if not _is_cuda and not _use_aiter: # fused in biased_grouped_topk so we can skip here @@ -761,8 +723,6 @@ class Glm4MoeModel(DeepseekV2Model): ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.dp_size = get_local_attention_dp_size() - class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): @@ -789,7 +749,6 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], ) self.logits_processor = LogitsProcessor(config) - self.dp_size = get_local_attention_dp_size() self._routed_experts_weights_of_layer = LazyValue( lambda: { @@ -953,7 +912,7 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/python/sglang/srt/models/glm4v_moe.py b/python/sglang/srt/models/glm4v_moe.py index 140b6e135..576cb3490 100644 --- a/python/sglang/srt/models/glm4v_moe.py +++ b/python/sglang/srt/models/glm4v_moe.py @@ -8,19 +8,11 @@ from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig from sglang.srt.distributed import ( get_moe_expert_parallel_world_size, - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - parallel_state, - tensor_model_parallel_all_reduce, ) from sglang.srt.hf_transformers_utils import get_processor -from sglang.srt.layers.dp_attention import ( - get_attention_tp_rank, - get_attention_tp_size, - get_local_attention_dp_size, -) from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead @@ -49,7 +41,6 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): config.moe_layer_freq = 1 self.config = config self.tp_size = get_tensor_model_parallel_world_size() - self.dp_size = get_local_attention_dp_size() self.quant_config = quant_config self.determine_num_fused_shared_experts("Glm4MoeForCausalLM") self.num_fused_shared_experts = ( @@ -232,7 +223,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/python/sglang/srt/models/gpt_oss.py b/python/sglang/srt/models/gpt_oss.py index b5057fb3e..93c4bda49 100644 --- a/python/sglang/srt/models/gpt_oss.py +++ b/python/sglang/srt/models/gpt_oss.py @@ -40,7 +40,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, get_attention_tp_size, - get_local_attention_dp_size, is_dp_attention_enabled, ) from sglang.srt.layers.layernorm import RMSNorm @@ -50,9 +49,10 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe import get_moe_a2a_backend from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopK -from sglang.srt.layers.moe.utils import DeepEPMode from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4 from sglang.srt.layers.radix_attention import RadixAttention @@ -110,16 +110,13 @@ class GptOssSparseMoeBlock(nn.Module): self.tp_size = get_tensor_model_parallel_world_size() self.layer_id = layer_id self.activation = config.hidden_act - self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702) - self.swiglu_limit = config.swiglu_limit + self.gemm1_alpha = getattr(config, "hidden_act_alpha", 1.702) + self.gemm1_clamp_limit = config.swiglu_limit - if global_server_args_dict["enable_flashinfer_mxfp4_moe"]: - self.topk = None - else: - self.topk = TopK( - top_k=config.num_experts_per_tok, - renormalize=True, - ) + self.topk = TopK( + top_k=config.num_experts_per_tok, + renormalize=True, + ) self.top_k = config.num_experts_per_tok experts_type = get_moe_impl_class() @@ -129,11 +126,9 @@ class GptOssSparseMoeBlock(nn.Module): quant_config.get_name() if quant_config is not None else None ) extra_kwargs = { - "enable_flashinfer_cutlass_moe": global_server_args_dict[ - "enable_flashinfer_cutlass_moe" - ], # for moe gate_up_proj and down_proj and their bias loading - "use_weight_loader_fused": quant_config_name != "mxfp4", + "use_weight_loader_fused": quant_config_name + != "mxfp4" } self.experts = experts_type( num_experts=config.num_local_experts @@ -144,15 +139,10 @@ class GptOssSparseMoeBlock(nn.Module): intermediate_size=config.intermediate_size, quant_config=quant_config, activation=self.activation, - activation_alpha=self.activation_alpha, - swiglu_limit=self.swiglu_limit, + gemm1_alpha=self.gemm1_alpha, + gemm1_clamp_limit=self.gemm1_clamp_limit, with_bias=True, prefix=add_prefix("experts", prefix), - **( - dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]]) - if global_server_args_dict["moe_a2a_backend"].is_deepep() - else {} - ), **extra_kwargs, ) @@ -171,7 +161,7 @@ class GptOssSparseMoeBlock(nn.Module): forward_batch: Optional[ForwardBatch] = None, should_allreduce_fusion: bool = False, ) -> torch.Tensor: - if not global_server_args_dict["moe_a2a_backend"].is_deepep(): + if not get_moe_a2a_backend().is_deepep(): return self.forward_normal(hidden_states, should_allreduce_fusion) else: raise Exception("forward_deepep branch not implemented yet") @@ -189,17 +179,10 @@ class GptOssSparseMoeBlock(nn.Module): should_allreduce_fusion: bool = False, ) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (num_tokens, n_experts) router_logits, _ = self.router(hidden_states) - - kwargs = {"hidden_states": hidden_states} - if self.topk is not None: - kwargs["topk_output"] = self.topk(hidden_states, router_logits) - else: - kwargs["topk_output"] = (self.top_k, router_logits) - final_hidden_states = self.experts(**kwargs) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) if self.tp_size > 1 and not should_allreduce_fusion: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) @@ -436,7 +419,6 @@ class GptOssDecoderLayer(nn.Module): self.attn_tp_size = get_attention_tp_size() self.attn_tp_rank = get_attention_tp_rank() - self.local_dp_size = get_local_attention_dp_size() # GptOss all layers are sparse and have no nextn now self.is_layer_sparse = True @@ -1060,7 +1042,7 @@ class GptOssForCausalLM(nn.Module): ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ] - expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused( + expert_params_mapping = FusedMoE.make_expert_params_mapping_fused( ckpt_gate_up_proj_name="gate_up_proj", ckpt_down_proj_name="down_proj", ckpt_gate_up_proj_bias_name="gate_up_proj_bias", diff --git a/python/sglang/srt/models/granitemoe.py b/python/sglang/srt/models/granitemoe.py index 2da7d857f..d65b9ec06 100644 --- a/python/sglang/srt/models/granitemoe.py +++ b/python/sglang/srt/models/granitemoe.py @@ -76,7 +76,6 @@ class GraniteMoeMoE(nn.Module): params_dtype=params_dtype, reduce_results=True, quant_config=quant_config, - tp_size=tp_size, prefix=f"{prefix}.experts", ) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 36c5a40dc..254d46d7b 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -135,7 +135,6 @@ class Grok1MoE(nn.Module): intermediate_size=intermediate_size, params_dtype=params_dtype, quant_config=quant_config, - tp_size=tp_size, activation="gelu", **kwargs, ) diff --git a/python/sglang/srt/models/interns1.py b/python/sglang/srt/models/interns1.py index 75f2cb775..d72deca41 100644 --- a/python/sglang/srt/models/interns1.py +++ b/python/sglang/srt/models/interns1.py @@ -6,6 +6,7 @@ from transformers import PretrainedConfig from sglang.srt.distributed import parallel_state from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.mm_utils import ( MultiModalityDataPaddingPatternTokenPairs, @@ -254,7 +255,7 @@ class InternS1ForConditionalGeneration(nn.Module): ] expert_params_mapping = [] if "Qwen3MoeForCausalLM" in self.config.text_config.architectures: - expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/python/sglang/srt/models/internvl.py b/python/sglang/srt/models/internvl.py index db093dd08..94470cc0a 100644 --- a/python/sglang/srt/models/internvl.py +++ b/python/sglang/srt/models/internvl.py @@ -12,7 +12,7 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo from sglang.srt.distributed import parallel_state from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention -from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.mm_utils import ( MultiModalityDataPaddingPatternTokenPairs, @@ -616,7 +616,7 @@ class InternVLChatModel(nn.Module): ("gate_up_proj", "up_proj", 1), ] - expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py index cf851bd1e..e05d96527 100644 --- a/python/sglang/srt/models/llama4.py +++ b/python/sglang/srt/models/llama4.py @@ -31,7 +31,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, get_attention_tp_size, - get_local_attention_dp_size, is_dp_attention_enabled, ) from sglang.srt.layers.layernorm import RMSNorm @@ -364,7 +363,6 @@ class Llama4DecoderLayer(nn.Module): rope_theta = config.rope_theta rope_scaling = config.rope_scaling max_position_embeddings = config.max_position_embeddings - self.local_dp_size = get_local_attention_dp_size() self.attn_tp_size = get_attention_tp_size() self.attn_tp_rank = get_attention_tp_rank() diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index 1156c3e47..821dfa98a 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -37,7 +37,6 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix, is_cuda diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 5b8609bdc..c5f04a4fc 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -47,7 +47,6 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import add_prefix, make_layers @@ -104,7 +103,6 @@ class MixtralMoE(nn.Module): intermediate_size=intermediate_size, params_dtype=params_dtype, quant_config=quant_config, - tp_size=tp_size, prefix=add_prefix("experts", prefix), ) diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index e2db2dceb..a74a2968d 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -89,7 +89,6 @@ class OlmoeMoE(nn.Module): intermediate_size=intermediate_size, reduce_results=True, quant_config=quant_config, - tp_size=tp_size, layer_id=layer_id, prefix=add_prefix("experts", prefix), ) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 81cd97c0e..a3427e068 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -17,8 +17,6 @@ """Inference-only Qwen2MoE model compatible with HuggingFace weights.""" import logging -from dataclasses import dataclass -from enum import Enum, auto from typing import Any, Dict, Iterable, Optional, Tuple, Union import torch @@ -31,10 +29,7 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from sglang.srt.eplb.expert_distribution import ( - ExpertDistributionRecorder, - get_global_expert_distribution_recorder, -) +from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.communicator import ( @@ -45,7 +40,6 @@ from sglang.srt.layers.communicator import ( from sglang.srt.layers.dp_attention import ( get_attention_tp_rank, get_attention_tp_size, - get_local_attention_dp_size, is_dp_attention_enabled, ) from sglang.srt.layers.layernorm import RMSNorm @@ -55,8 +49,8 @@ from sglang.srt.layers.linear import ( ReplicatedLinear, RowParallelLinear, ) -from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput -from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -149,14 +143,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module): intermediate_size=config.moe_intermediate_size, quant_config=quant_config, prefix=add_prefix("experts", prefix), - # Additional args for FusedMoE - **( - dict( - enable_flashinfer_cutlass_moe=True, - ) - if global_server_args_dict["enable_flashinfer_cutlass_moe"] - else {} - ), ) self.gate = ReplicatedLinear( @@ -340,7 +326,6 @@ class Qwen2MoeDecoderLayer(nn.Module): self.attn_tp_size = get_attention_tp_size() self.attn_tp_rank = get_attention_tp_rank() - self.local_dp_size = get_local_attention_dp_size() # Qwen2MoE all layers are sparse and have no nextn now self.is_layer_sparse = True diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index c17402863..fcb45b947 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -28,50 +28,35 @@ from sglang.srt.distributed import ( get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - parallel_state, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo -from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes -from sglang.srt.layers.dp_attention import ( - get_attention_tp_rank, - get_attention_tp_size, - get_local_attention_dp_size, -) +from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( - MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) -from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe import get_moe_a2a_backend from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.utils import get_layer_id -from sglang.srt.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) +from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode -from sglang.srt.model_executor.forward_batch_info import ( - ForwardBatch, - ForwardMode, - PPProxyTensors, -) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP from sglang.srt.models.qwen2_moe import Qwen2MoeModel -from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty Qwen3MoeConfig = None @@ -112,19 +97,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module): intermediate_size=config.moe_intermediate_size, quant_config=quant_config, prefix=add_prefix("experts", prefix), - **( - dict(deepep_mode=global_server_args_dict["deepep_mode"]) - if global_server_args_dict["moe_a2a_backend"].is_deepep() - else {} - ), - # Additional args for FusedMoE - **( - dict( - enable_flashinfer_cutlass_moe=True, - ) - if global_server_args_dict["enable_flashinfer_cutlass_moe"] - else {} - ), ) self.gate = ReplicatedLinear( @@ -135,7 +107,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): prefix=add_prefix("gate", prefix), ) - if global_server_args_dict["moe_a2a_backend"].is_deepep(): + if get_moe_a2a_backend().is_deepep(): # TODO: we will support tp < ep in the future self.ep_size = get_moe_expert_parallel_world_size() self.num_experts = ( @@ -150,7 +122,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module): use_reduce_scatter: bool = False, ) -> torch.Tensor: - if not global_server_args_dict["moe_a2a_backend"].is_deepep(): + if not get_moe_a2a_backend().is_deepep(): return self.forward_normal(hidden_states, use_reduce_scatter) else: return self.forward_deepep(hidden_states, forward_batch) @@ -491,7 +463,6 @@ class Qwen3MoeDecoderLayer(nn.Module): self.attn_tp_size = get_attention_tp_size() self.attn_tp_rank = get_attention_tp_rank() - self.local_dp_size = get_local_attention_dp_size() # Qwen3MoE all layers are sparse and have no nextn now self.is_layer_sparse = True @@ -778,7 +749,7 @@ class Qwen3MoeForCausalLM(nn.Module): ("gate_up_proj", "up_proj", 1), ] - expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( + expert_params_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", diff --git a/python/sglang/srt/models/step3_vl.py b/python/sglang/srt/models/step3_vl.py index 64bb2183c..a93bf69e7 100644 --- a/python/sglang/srt/models/step3_vl.py +++ b/python/sglang/srt/models/step3_vl.py @@ -38,6 +38,7 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe import get_moe_a2a_backend from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.topk import TopK @@ -150,7 +151,7 @@ class Step3TextMoEMLP(nn.Module): prefix=add_prefix("gate", prefix), ) - if global_server_args_dict["moe_a2a_backend"].is_deepep(): + if get_moe_a2a_backend().is_deepep(): raise NotImplementedError("DeepEP MoE is not supported yet in Step3 model.") def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index 0ea9ed950..6067acec6 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -33,7 +33,9 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.fused_moe_triton import fused_moe +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -121,6 +123,7 @@ class XverseMoE(nn.Module): ] ) self.pack_params() + self.moe_runner_config = MoeRunnerConfig(inplace=True) self.router = ReplicatedLinear( config.hidden_size, @@ -129,6 +132,10 @@ class XverseMoE(nn.Module): quant_config=None, prefix=add_prefix("router", prefix), ) + self.topk = TopK( + top_k=self.top_k, + renormalize=getattr(self.config, "norm_topk_prob", False), + ) if config.num_shared_experts is not None: intermediate_size = config.intermediate_size * config.num_shared_experts @@ -167,14 +174,13 @@ class XverseMoE(nn.Module): shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.router(hidden_states) + topk_output = self.topk(hidden_states, router_logits) final_hidden_states = fused_moe( hidden_states, self.w1, self.w2, - router_logits, - self.top_k, - renormalize=getattr(self.config, "norm_topk_prob", False), - inplace=True, + topk_output, + self.moe_runner_config, ) if self.config.num_shared_experts is not None: diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 0d6c794e6..c4f664872 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -37,6 +37,7 @@ from sglang.srt.utils import ( is_hip, is_port_available, is_remote_url, + is_triton_kernels_available, is_valid_ipv6_address, nullable_str, ) @@ -175,9 +176,15 @@ class ServerArgs: # Expert parallelism ep_size: int = 1 - moe_a2a_backend: Optional[Literal["deepep"]] = None - enable_flashinfer_cutlass_moe: bool = False - enable_flashinfer_trtllm_moe: bool = False + moe_a2a_backend: Literal["none", "deepep"] = "none" + moe_runner_backend: Literal[ + "auto", + "triton", + "triton_kernel", + "flashinfer_trtllm", + "flashinfer_cutlass", + "flashinfer_mxfp4", + ] = "auto" enable_flashinfer_allreduce_fusion: bool = False deepep_mode: Literal["auto", "normal", "low_latency"] = "auto" ep_num_redundant_experts: int = 0 @@ -250,8 +257,6 @@ class ServerArgs: disable_chunked_prefix_cache: bool = False disable_fast_image_processor: bool = False enable_return_hidden_states: bool = False - enable_triton_kernel_moe: bool = False - enable_flashinfer_mxfp4_moe: bool = False scheduler_recv_interval: int = 1 # Debug tensor dumps @@ -282,6 +287,9 @@ class ServerArgs: # Deprecated arguments enable_ep_moe: bool = False enable_deepep_moe: bool = False + enable_flashinfer_cutlass_moe: bool = False + enable_flashinfer_trtllm_moe: bool = False + enable_triton_kernel_moe: bool = False def __post_init__(self): # Check deprecated arguments @@ -298,6 +306,21 @@ class ServerArgs: print_deprecated_warning( "NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead." ) + if self.enable_triton_kernel_moe: + self.moe_runner_backend = "triton_kernel" + print_deprecated_warning( + "NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead." + ) + if self.enable_flashinfer_cutlass_moe: + self.moe_runner_backend = "flashinfer_cutlass" + print_deprecated_warning( + "NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead." + ) + if self.enable_flashinfer_trtllm_moe: + self.moe_runner_backend = "flashinfer_trtllm" + print_deprecated_warning( + "NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead." + ) # Set missing default values if self.tokenizer_path is None: @@ -517,7 +540,7 @@ class ServerArgs: ), "Please enable dp attention when setting enable_dp_lm_head. " # MoE kernel - if self.enable_flashinfer_cutlass_moe: + if self.moe_runner_backend == "flashinfer_cutlass": assert ( self.quantization == "modelopt_fp4" ), "modelopt_fp4 quantization is required for Flashinfer MOE" @@ -527,7 +550,7 @@ class ServerArgs: self.tp_size, ], "The expert parallel size must be 1 or the same as the tensor parallel size" - if self.enable_flashinfer_trtllm_moe: + if self.moe_runner_backend == "flashinfer_trtllm": if not self.disable_shared_experts_fusion: self.disable_shared_experts_fusion = True logger.warning( @@ -556,7 +579,7 @@ class ServerArgs: self.ep_dispatch_algorithm = "static" if self.enable_eplb: - assert self.ep_size > 1 or self.moe_a2a_backend is not None + assert self.ep_size > 1 if self.enable_expert_distribution_metrics and ( self.expert_distribution_recorder_mode is None @@ -1446,19 +1469,22 @@ class ServerArgs: parser.add_argument( "--moe-a2a-backend", type=str, - choices=["deepep"], + choices=["none", "deepep"], default=ServerArgs.moe_a2a_backend, help="Choose the backend for MoE A2A.", ) parser.add_argument( - "--enable-flashinfer-cutlass-moe", - action="store_true", - help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP", - ) - parser.add_argument( - "--enable-flashinfer-trtllm-moe", - action="store_true", - help="Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP", + "--moe-runner-backend", + type=str, + choices=[ + "auto", + "triton", + "triton_kernel", + "flashinfer_trtllm", + "flashinfer_cutlass", + ], + default=ServerArgs.moe_runner_backend, + help="Choose the runner backend for MoE.", ) parser.add_argument( "--enable-flashinfer-allreduce-fusion", @@ -1825,11 +1851,6 @@ class ServerArgs: action="store_true", help="Enable returning hidden states with responses.", ) - parser.add_argument( - "--enable-triton-kernel-moe", - action="store_true", - help="Use triton moe grouped gemm kernel.", - ) parser.add_argument( "--enable-flashinfer-mxfp4-moe", action="store_true", @@ -1965,6 +1986,21 @@ class ServerArgs: action="store_true", help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.", ) + parser.add_argument( + "--enable-flashinfer-cutlass-moe", + action="store_true", + help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP", + ) + parser.add_argument( + "--enable-flashinfer-trtllm-moe", + action="store_true", + help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP", + ) + parser.add_argument( + "--enable-triton-kernel-moe", + action="store_true", + help="(Deprecated) Use triton moe grouped gemm kernel.", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): @@ -2143,18 +2179,21 @@ class ServerArgs: ) if is_sm100_supported() and is_mxfp4_quant_format: - self.enable_flashinfer_mxfp4_moe = True - self.enable_triton_kernel_moe = False + self.moe_runner_backend = "flashinfer_mxfp4" logger.warning( "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel." ) else: - if self.enable_triton_kernel_moe: + if self.moe_runner_backend == "triton_kernel": assert ( self.ep_size == 1 ), "Triton kernel MoE is only supported when ep_size == 1" - if not self.enable_triton_kernel_moe and self.ep_size == 1: - self.enable_triton_kernel_moe = True + if ( + self.moe_runner_backend == "auto" + and self.ep_size == 1 + and is_triton_kernels_available() + ): + self.moe_runner_backend = "triton_kernel" logger.warning( "Detected GPT-OSS model, enabling triton_kernels MOE kernel." ) diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index 223ff0cbe..e02bc1fd2 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -14,8 +14,13 @@ from sglang.srt.layers.communicator import ( CommunicateSummableTensorPairFn, ScatterMode, ) +from sglang.srt.layers.moe import ( + get_deepep_mode, + get_moe_a2a_backend, + get_tbo_token_distribution_threshold, + is_tbo_enabled, +) from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher -from sglang.srt.layers.moe.utils import DeepEPMode from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ( @@ -83,7 +88,7 @@ def _is_two_chunk_split_enabled(extend_lens: Sequence[int]) -> bool: vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens) left_sum = sum(extend_lens[:vanilla_split_seq_index]) overall_sum = sum(extend_lens) - threshold = global_server_args_dict["tbo_token_distribution_threshold"] + threshold = get_tbo_token_distribution_threshold() assert threshold <= 0.5, f"{threshold=}" return left_sum < overall_sum * threshold or left_sum > overall_sum * ( 1 - threshold @@ -299,7 +304,7 @@ class TboCudaGraphRunnerPlugin: self._tbo_children_num_token_non_padded = torch.zeros((2,), dtype=torch.int32) def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int): - if not global_server_args_dict["enable_two_batch_overlap"]: + if not is_tbo_enabled(): return token_num_per_seq = get_token_num_per_seq( forward_mode=batch.forward_mode, spec_info=batch.spec_info @@ -353,10 +358,12 @@ class TboDPAttentionPreparer: def prepare_all_gather( self, local_batch: ScheduleBatch, - deepep_mode: DeepEPMode, - enable_deepep_moe: bool, - enable_two_batch_overlap: bool, ): + + deepep_mode = get_deepep_mode() + enable_deepep_moe = get_moe_a2a_backend().is_deepep() + enable_two_batch_overlap = is_tbo_enabled() + self.enable_two_batch_overlap = enable_two_batch_overlap if local_batch is not None: @@ -384,7 +391,7 @@ class TboDPAttentionPreparer: and not local_batch.forward_mode.is_target_verify() ) and enable_deepep_moe - and (resolved_deepep_mode == DeepEPMode.LOW_LATENCY) + and (resolved_deepep_mode.is_low_latency()) ) else: self.local_tbo_split_seq_index = 0 @@ -657,6 +664,7 @@ class TboForwardBatchPreparer: "req_to_token_pool", "token_to_kv_pool", "can_run_dp_cuda_graph", + "dp_padding_mode", "global_forward_mode", "spec_algorithm", "capture_hidden_mode", @@ -701,7 +709,6 @@ class TboForwardBatchPreparer: tbo_children=None, global_num_tokens_gpu=None, global_num_tokens_cpu=None, - dp_padding_mode=None, global_dp_buffer_len=global_dp_buffer_len, global_num_tokens_for_logprob_gpu=None, global_num_tokens_for_logprob_cpu=None, @@ -955,9 +962,7 @@ def _model_forward_tbo_merge_outputs(output_a, output_b): class MaybeTboDeepEPDispatcher: def __init__(self, **kwargs): - num_inner_dispatchers = ( - 2 if global_server_args_dict["enable_two_batch_overlap"] else 1 - ) + num_inner_dispatchers = 2 if is_tbo_enabled() else 1 self._inners = [ DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers) ] diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index f8be9c8d8..d15ef2a93 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -2413,7 +2413,7 @@ def require_mlp_tp_gather(server_args): return True elif not server_args.enable_dp_lm_head: return True - elif server_args.moe_a2a_backend is None: + elif server_args.moe_a2a_backend == "none": return True else: return ( @@ -2429,7 +2429,7 @@ def require_attn_tp_gather(server_args): Check if the input of attention is scattered. """ assert server_args.moe_dense_tp_size in [1, None] - if server_args.moe_a2a_backend is not None or server_args.moe_dense_tp_size == 1: + if server_args.moe_a2a_backend != "none" or server_args.moe_dense_tp_size == 1: if server_args.enable_dp_attention: return server_args.dp_size < server_args.tp_size else: diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index fd2c95608..45271e116 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -6,7 +6,7 @@ import torch from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopKConfig, select_experts from sglang.srt.layers.quantization.fp8_kernel import ( per_tensor_quant_mla_fp8, per_token_group_quant_fp8, @@ -498,11 +498,13 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase): score = torch.randn((M, E), dtype=dtype) with torch.inference_mode(): + ref_out = torch_w8a8_block_fp8_moe( + a, w1, w2, w1_s, w2_s, score, topk, block_size + ) topk_output = select_experts( hidden_states=a, router_logits=score, - top_k=topk, - renormalize=False, + topk_config=TopKConfig(top_k=topk, renormalize=False), ) out = fused_moe( a, @@ -514,9 +516,6 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase): w2_scale=w2_s, block_shape=block_size, ) - ref_out = torch_w8a8_block_fp8_moe( - a, w1, w2, w1_s, w2_s, score, topk, block_size - ) self.assertTrue( torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) diff --git a/python/sglang/test/test_block_fp8_ep.py b/python/sglang/test/test_block_fp8_ep.py index 2f92c5435..670f2e0f8 100644 --- a/python/sglang/test/test_block_fp8_ep.py +++ b/python/sglang/test/test_block_fp8_ep.py @@ -12,7 +12,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( run_moe_ep_preproess, silu_and_mul_triton_kernel, ) -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopKConfig, select_experts from sglang.test.test_utils import CustomTestCase @@ -22,35 +22,26 @@ def ep_moe( w1: torch.Tensor, w2: torch.Tensor, router_logits: torch.Tensor, - top_k: int, - renormalize: bool, + topk_config: TopKConfig, # ep config num_experts: int = 256, fp8_dtype: torch.types = torch.float8_e4m3fn, num_experts_per_partition: int = 128, start_expert_id: int = 0, end_expert_id: int = 127, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, w1_scale_inv: Optional[torch.Tensor] = None, w2_scale_inv: Optional[torch.Tensor] = None, block_shape: Optional[List[int]] = None, ): use_blockwise_fp8 = block_shape is not None - topk_weights, topk_ids, _ = select_experts( + top_k = topk_config.top_k + topk_output = select_experts( hidden_states=hidden_states, router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - # correction_bias=correction_bias, #skip this in test - custom_routing_function=custom_routing_function, + topk_config=topk_config, ) + topk_weights, topk_ids, _ = topk_output reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts) @@ -294,14 +285,18 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase): start_id = cur_rank * num_experts_per_partition end_id = start_id + num_experts_per_partition - 1 + topk_config = TopKConfig( + top_k=topk, + renormalize=False, + ) + with torch.inference_mode(): out = ep_moe( hidden_states=a, w1=w1, w2=w2, router_logits=score, - top_k=topk, - renormalize=False, + topk_config=topk_config, use_fp8_w8a8=True, w1_scale_inv=w1_s, w2_scale_inv=w2_s, @@ -316,8 +311,7 @@ class TestW8A8BlockFP8EPMoE(CustomTestCase): w1=w1_ref, w2=w2_ref, router_logits=score, - top_k=topk, - renormalize=False, + topk_config=topk_config, use_fp8_w8a8=False, w1_scale_inv=None, w2_scale_inv=None, diff --git a/python/sglang/test/test_cutlass_w4a8_moe.py b/python/sglang/test/test_cutlass_w4a8_moe.py index c823bf1f7..622941f00 100644 --- a/python/sglang/test/test_cutlass_w4a8_moe.py +++ b/python/sglang/test/test_cutlass_w4a8_moe.py @@ -6,7 +6,7 @@ import pytest import torch from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopKConfig, select_experts def pack_int4_values_to_int8(int4_values_interleaved: torch.Tensor) -> torch.Tensor: @@ -100,11 +100,12 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype): s_strides2 = c_strides2 score = torch.randn((M, E), dtype=dtype, device=device) - topk_weights, topk_ids, _ = select_experts( + topk_output = select_experts( hidden_states=a, router_logits=score, - top_k=topk, + topk_config=TopKConfig(top_k=topk, renormalize=False), ) + topk_weights, topk_ids, _ = topk_output expert_map = torch.arange(E, dtype=torch.int32, device=device) expert_map[local_e:] = E diff --git a/python/sglang/test/test_fp4_moe.py b/python/sglang/test/test_fp4_moe.py index bf2308a8f..8f8c8e8a7 100644 --- a/python/sglang/test/test_fp4_moe.py +++ b/python/sglang/test/test_fp4_moe.py @@ -9,7 +9,7 @@ from sgl_kernel import scaled_fp4_quant from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopKConfig, select_experts if torch.cuda.get_device_capability() < (10, 0): pytest.skip( @@ -163,11 +163,12 @@ def check_moe( score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids, _ = select_experts( + topk_output = select_experts( hidden_states=a, router_logits=score, - top_k=topk, + topk_config=TopKConfig(top_k=topk, renormalize=False), ) + topk_weights, topk_ids, _ = topk_output a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32) a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32) diff --git a/test/srt/quant/test_block_int8.py b/test/srt/quant/test_block_int8.py index 58bd7c1e1..f6ceb03d0 100644 --- a/test/srt/quant/test_block_int8.py +++ b/test/srt/quant/test_block_int8.py @@ -5,7 +5,7 @@ import torch from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopKConfig, select_experts from sglang.test.test_utils import CustomTestCase @@ -175,10 +175,13 @@ class TestW8A8BlockINT8FusedMoE(CustomTestCase): topk_output = select_experts( hidden_states=a, router_logits=score, - top_k=topk, + topk_config=TopKConfig(top_k=topk, renormalize=False), ) with torch.inference_mode(): + ref_out = torch_w8a8_block_int8_moe( + a, w1, w2, w1_s, w2_s, score, topk, block_size + ) out = fused_moe( a, w1, @@ -189,9 +192,6 @@ class TestW8A8BlockINT8FusedMoE(CustomTestCase): w2_scale=w2_s, block_shape=block_size, ) - ref_out = torch_w8a8_block_int8_moe( - a, w1, w2, w1_s, w2_s, score, topk, block_size - ) self.assertTrue( torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) diff --git a/test/srt/quant/test_int8_kernel.py b/test/srt/quant/test_int8_kernel.py index bbadce230..dd75d06af 100644 --- a/test/srt/quant/test_int8_kernel.py +++ b/test/srt/quant/test_int8_kernel.py @@ -5,7 +5,7 @@ import torch from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopKConfig, select_experts from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 from sglang.test.test_utils import CustomTestCase @@ -118,7 +118,7 @@ class TestW8A8Int8FusedMoE(CustomTestCase): topk_output = select_experts( hidden_states=a, router_logits=score, - top_k=topk, + topk_config=TopKConfig(top_k=topk, renormalize=False), ) out = fused_moe( a, diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py index 1a0452c41..9f2cc31b1 100644 --- a/test/srt/test_fused_moe.py +++ b/test/srt/test_fused_moe.py @@ -6,7 +6,7 @@ from tqdm import tqdm from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopKConfig, select_experts from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.utils import is_hip @@ -136,19 +136,7 @@ class TestFusedMOE(CustomTestCase): topk_output = select_experts( hidden_states=a, router_logits=score, - top_k=topk, - ) - - sglang_output = fused_moe( - a, - w1, - w2, - topk_output, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a1_scale, - a2_scale=a2_scale, + topk_config=TopKConfig(top_k=topk, renormalize=False), ) torch_output = self.torch_naive_moe( @@ -162,6 +150,18 @@ class TestFusedMOE(CustomTestCase): a1_scale, a2_scale, ) + + sglang_output = fused_moe( + a, + w1, + w2, + topk_output, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + ) torch.testing.assert_close( sglang_output, torch_output, rtol=rtol, atol=atol ) @@ -174,7 +174,7 @@ class TestFusedMOE(CustomTestCase): topk_output = select_experts( hidden_states=a, router_logits=score, - top_k=topk, + topk_config=TopKConfig(top_k=topk, renormalize=False), ) triton_output = fused_moe(a, w1, w2, topk_output) diff --git a/test/srt/test_triton_moe_channel_fp8_kernel.py b/test/srt/test_triton_moe_channel_fp8_kernel.py index 577570757..bbe44308f 100644 --- a/test/srt/test_triton_moe_channel_fp8_kernel.py +++ b/test/srt/test_triton_moe_channel_fp8_kernel.py @@ -5,7 +5,7 @@ import torch from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopKConfig, select_experts from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.test.test_utils import CustomTestCase @@ -130,7 +130,7 @@ class TestW8A8FP8FusedMoE(CustomTestCase): topk_output = select_experts( hidden_states=a, router_logits=score, - top_k=topk, + topk_config=TopKConfig(top_k=topk, renormalize=False), ) out = fused_moe( a, diff --git a/test/srt/test_triton_moe_wna16.py b/test/srt/test_triton_moe_wna16.py index 51583c2f2..b447b532f 100644 --- a/test/srt/test_triton_moe_wna16.py +++ b/test/srt/test_triton_moe_wna16.py @@ -5,7 +5,7 @@ import torch from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopKConfig, select_experts NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] @@ -223,7 +223,7 @@ def test_fused_moe_wn16( topk_output = select_experts( hidden_states=a, router_logits=score, - top_k=topk, + topk_config=TopKConfig(top_k=topk), ) triton_output = fused_moe(