From 15ad6c908670492243cfcb820ca24c40cc9b840d Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Sat, 19 Jul 2025 00:51:15 -0700 Subject: [PATCH] [1/N] MoE Refactor: refactor `select_experts` (#7966) --- python/sglang/srt/custom_op.py | 7 +- python/sglang/srt/layers/linear.py | 2 +- python/sglang/srt/layers/moe/ep_moe/layer.py | 87 ++------ .../sglang/srt/layers/moe/fused_moe_native.py | 54 +---- .../layers/moe/fused_moe_triton/fused_moe.py | 45 +--- .../srt/layers/moe/fused_moe_triton/layer.py | 35 +-- python/sglang/srt/layers/moe/topk.py | 176 ++++++++++++++- .../srt/layers/quantization/__init__.py | 32 +-- python/sglang/srt/layers/quantization/awq.py | 39 +--- .../srt/layers/quantization/base_config.py | 21 +- .../srt/layers/quantization/blockwise_int8.py | 35 +-- .../compressed_tensors_moe.py | 92 ++------ python/sglang/srt/layers/quantization/fp8.py | 52 +---- python/sglang/srt/layers/quantization/gptq.py | 35 +-- .../srt/layers/quantization/modelopt_quant.py | 63 +----- .../srt/layers/quantization/moe_wna16.py | 34 +-- .../sglang/srt/layers/quantization/unquant.py | 207 +++++------------- .../srt/layers/quantization/w8a8_fp8.py | 37 +--- .../srt/layers/quantization/w8a8_int8.py | 87 ++------ python/sglang/srt/models/deepseek.py | 15 +- python/sglang/srt/models/deepseek_v2.py | 52 ++--- python/sglang/srt/models/granitemoe.py | 10 +- python/sglang/srt/models/grok.py | 12 +- python/sglang/srt/models/hunyuan.py | 13 +- python/sglang/srt/models/llama4.py | 22 +- python/sglang/srt/models/mixtral.py | 11 +- python/sglang/srt/models/olmoe.py | 13 +- python/sglang/srt/models/phimoe.py | 12 +- python/sglang/srt/models/qwen2_moe.py | 14 +- python/sglang/srt/models/qwen3_moe.py | 31 ++- python/sglang/test/test_block_fp8.py | 11 +- python/sglang/test/test_block_fp8_ep.py | 2 +- python/sglang/test/test_cutlass_w4a8_moe.py | 4 +- python/sglang/test/test_fp4_moe.py | 4 +- test/srt/test_block_int8.py | 11 +- test/srt/test_fused_moe.py | 19 +- test/srt/test_int8_kernel.py | 10 +- .../srt/test_triton_moe_channel_fp8_kernel.py | 10 +- test/srt/test_triton_moe_wna16.py | 11 +- 39 files changed, 556 insertions(+), 871 deletions(-) diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py index 5b502a153..8c662b5cc 100644 --- a/python/sglang/srt/custom_op.py +++ b/python/sglang/srt/custom_op.py @@ -29,15 +29,18 @@ class CustomOp(nn.Module): self._original_forward_method = self._forward_method # NOTE: Temporarily workaround MoE + # The performance of torch.compile on this layer is not always good when bs > 1, + # so we decide to only use torch.compile when bs=1 if "FusedMoE" in self.__class__.__name__: if num_tokens == 1: from sglang.srt.layers.moe.fused_moe_native import ( fused_moe_forward_native, ) - # The performance of torch.compile on this layer is not always good when bs > 1, - # so we decide to only use torch.compile when bs =1 self._forward_method = fused_moe_forward_native + elif "TopK" in self.__class__.__name__: + if num_tokens == 1: + self._forward_method = self.forward_native else: self._forward_method = self.forward_native self.is_torch_compile = True diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 07be9a3c6..9d8ab8632 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -756,7 +756,7 @@ class QKVParallelLinear(ColumnParallelLinear): bias: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - quant_config: Optional["QuantizationConfig"] = None, + quant_config: Optional[QuantizationConfig] = None, prefix: str = "", tp_rank: Optional[int] = None, tp_size: Optional[int] = None, diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index a839b47fe..77d849f3f 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -1,17 +1,13 @@ import logging -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple -import einops import torch -from torch.nn import Module -from sglang.srt.custom_op import CustomOp from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) from sglang.srt.eplb.expert_location import get_global_expert_location_metadata -from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.layers.moe.ep_moe.kernels import ( ep_gather, ep_scatter, @@ -28,7 +24,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import ( tma_align_input_scale, ) from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, @@ -162,16 +158,9 @@ class EPMoE(torch.nn.Module): intermediate_size: int, layer_id: int, params_dtype: Optional[torch.dtype] = None, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - topk_group: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, prefix: str = "", - correction_bias: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, activation: str = "silu", routed_scaling_factor: Optional[float] = None, use_per_token_if_dynamic: bool = True, @@ -189,24 +178,12 @@ class EPMoE(torch.nn.Module): self.layer_id = layer_id self.num_experts = num_experts assert self.num_experts % self.tp_size == 0 - assert ( - num_fused_shared_experts == 0 - ), "num_fused_shared_experts is not supported in EP" - self.num_fused_shared_experts = num_fused_shared_experts self.num_experts_per_partition, self.expert_map = self.determine_expert_map() self.start_expert_id = self.tp_rank * self.num_experts_per_partition self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1 self.top_k = top_k self.intermediate_size = intermediate_size - self.renormalize = renormalize - self.use_grouped_topk = use_grouped_topk - if self.use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - self.num_expert_group = num_expert_group - self.topk_group = topk_group - self.correction_bias = correction_bias - self.custom_routing_function = custom_routing_function self.activation = activation self.routed_scaling_factor = routed_scaling_factor self.use_per_token_if_dynamic = use_per_token_if_dynamic @@ -311,33 +288,24 @@ class EPMoE(torch.nn.Module): ) return (local_num_experts, expert_map) - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8: - return self.forward_deepgemm(hidden_states, router_logits) + return self.forward_deepgemm(hidden_states, topk_output) else: - return self.forward_normal(hidden_states, router_logits) + return self.forward_normal(hidden_states, topk_output) def forward_deepgemm( - self, hidden_states: torch.Tensor, router_logits: torch.Tensor + self, + hidden_states: torch.Tensor, + topk_output: TopKOutput, ): assert self.quant_method is not None assert self.activation == "silu" hidden_states_shape = hidden_states.shape hidden_states_dtype = hidden_states.dtype hidden_states_device = hidden_states.device - topk_weights, topk_ids = select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=self.use_grouped_topk, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - correction_bias=self.correction_bias, - custom_routing_function=self.custom_routing_function, - routed_scaling_factor=self.routed_scaling_factor, - ) + + topk_weights, topk_ids, _ = topk_output if not self.use_block_quant: # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm @@ -469,8 +437,10 @@ class EPMoE(torch.nn.Module): ) return output - def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput): assert self.quant_method is not None + topk_weights, topk_ids, _ = topk_output + hidden_states_shape = hidden_states.shape hidden_states_dtype = hidden_states.dtype hidden_states_device = hidden_states.device @@ -481,23 +451,6 @@ class EPMoE(torch.nn.Module): use_per_token_if_dynamic=self.use_per_token_if_dynamic, ) - topk_weights, topk_ids = select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=self.use_grouped_topk, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - correction_bias=self.correction_bias, - custom_routing_function=self.custom_routing_function, - routed_scaling_factor=self.routed_scaling_factor, - expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( - layer_id=self.layer_id, - ), - ) - if self.use_w4afp8: local_topk_ids = topk_ids if self.expert_map is not None: @@ -916,16 +869,9 @@ class DeepEPMoE(EPMoE): intermediate_size: int, layer_id: int, params_dtype: Optional[torch.dtype] = None, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - topk_group: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, prefix: str = "", - correction_bias: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, activation: str = "silu", routed_scaling_factor: Optional[float] = None, deepep_mode: DeepEPMode = DeepEPMode.auto, @@ -937,16 +883,9 @@ class DeepEPMoE(EPMoE): intermediate_size=intermediate_size, layer_id=layer_id, params_dtype=params_dtype, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - topk_group=topk_group, quant_config=quant_config, tp_size=tp_size, prefix=prefix, - correction_bias=correction_bias, - custom_routing_function=custom_routing_function, activation=activation, routed_scaling_factor=routed_scaling_factor, ) diff --git a/python/sglang/srt/layers/moe/fused_moe_native.py b/python/sglang/srt/layers/moe/fused_moe_native.py index 25645ad00..61eacd78c 100644 --- a/python/sglang/srt/layers/moe/fused_moe_native.py +++ b/python/sglang/srt/layers/moe/fused_moe_native.py @@ -9,21 +9,14 @@ import torch from torch.nn import functional as F from sglang.srt.layers.activation import GeluAndMul, SiluAndMul -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopKOutput def fused_moe_forward_native( layer: torch.nn.Module, x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -34,20 +27,7 @@ def fused_moe_forward_native( if apply_router_weight_on_input: raise NotImplementedError() - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - torch_native=True, - ) + topk_weights, topk_ids, _ = topk_output w13_weights = layer.w13_weight[topk_ids] w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2) @@ -67,15 +47,8 @@ def fused_moe_forward_native( def moe_forward_native( layer: torch.nn.Module, x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -86,20 +59,7 @@ def moe_forward_native( if apply_router_weight_on_input: raise NotImplementedError() - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - torch_native=True, - routed_scaling_factor=routed_scaling_factor, - ) + topk_weights, topk_ids, _ = topk_output # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589 len_experts = layer.num_experts diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index baf8f5c87..a39d6d5d3 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -6,13 +6,13 @@ import functools import json import logging import os -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch import triton import triton.language as tl -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_fp8, scaled_fp8_quant, @@ -1328,8 +1328,7 @@ def fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, + topk_output: TopKOutput, inplace: bool = False, activation: str = "silu", apply_router_weight_on_input: bool = False, @@ -1348,7 +1347,7 @@ def fused_experts( no_combine: bool = False, routed_scaling_factor: Optional[float] = None, ): - + topk_weights, topk_ids, _ = topk_output if inplace: assert not no_combine, "no combine + inplace makes no sense" torch.ops.sglang.inplace_fused_experts( @@ -1732,17 +1731,10 @@ def fused_moe( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - renormalize: bool, + topk_output: TopKOutput, inplace: bool = False, activation: str = "silu", apply_router_weight_on_input: bool = False, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - topk_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, @@ -1766,16 +1758,9 @@ def fused_moe( - hidden_states (torch.Tensor): The input tensor to the MoE layer. - w1 (torch.Tensor): The first set of expert weights. - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - topk_output (TopKOutput): The top-k output of the experts. - inplace (bool): If True, perform the operation in-place. Defaults to False. - - num_expert_group: Optional[int]: additional parameter for grouped_topk - - topk_group: Optional[int]: additional parameter for grouped_topk - - use_grouped_topk: If True, use grouped_topk instead of fused_topk - note: Deepseek V2/V3/R1 series models use grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner @@ -1799,28 +1784,12 @@ def fused_moe( Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ - # Check constraints. - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - - topk_weights, topk_ids = select_experts( - hidden_states=hidden_states, - router_logits=gating_output, - use_grouped_topk=use_grouped_topk, - top_k=topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - routed_scaling_factor=routed_scaling_factor, - ) return fused_experts( hidden_states, w1, w2, - topk_weights, - topk_ids, + topk_output, inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 41ae6274b..0c3cb0422 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -2,7 +2,7 @@ import logging from enum import Enum -from typing import Callable, List, Optional, Tuple +from typing import List, Optional, Tuple import torch @@ -11,6 +11,7 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) +from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -59,22 +60,15 @@ class FusedMoE(torch.nn.Module): def __init__( self, num_experts: int, - top_k: int, hidden_size: int, intermediate_size: int, + top_k: Optional[int] = None, layer_id: Optional[int] = None, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = False, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - topk_group: Optional[int] = None, quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, prefix: str = "", - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", apply_router_weight_on_input: bool = False, use_presharded_weights: bool = False, @@ -89,6 +83,7 @@ class FusedMoE(torch.nn.Module): if params_dtype is None: params_dtype = torch.get_default_dtype() + self.top_k = top_k self.hidden_size = hidden_size self.tp_size = ( tp_size if tp_size is not None else get_tensor_model_parallel_world_size() @@ -126,19 +121,9 @@ class FusedMoE(torch.nn.Module): self.ep_rank = 0 self.local_num_experts = num_experts self.routed_scaling_factor = routed_scaling_factor - self.top_k = top_k assert intermediate_size % self.tp_size == 0 self.intermediate_size_per_partition = intermediate_size // self.tp_size self.reduce_results = reduce_results - self.renormalize = renormalize - self.use_grouped_topk = use_grouped_topk - if self.use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - self.num_expert_group = num_expert_group - self.num_fused_shared_experts = num_fused_shared_experts - self.topk_group = topk_group - self.custom_routing_function = custom_routing_function - self.correction_bias = correction_bias self.activation = activation self.apply_router_weight_on_input = apply_router_weight_on_input self.use_presharded_weights = use_presharded_weights @@ -562,22 +547,14 @@ class FusedMoE(torch.nn.Module): ) return - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): + def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput): assert self.quant_method is not None # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, x=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - renormalize=self.renormalize, - use_grouped_topk=self.use_grouped_topk, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - custom_routing_function=self.custom_routing_function, - correction_bias=self.correction_bias, + topk_output=topk_output, activation=self.activation, apply_router_weight_on_input=self.apply_router_weight_on_input, routed_scaling_factor=self.routed_scaling_factor, diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 40fc0b61f..bb3cf6515 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -12,12 +12,15 @@ # limitations under the License. # ============================================================================== +from __future__ import annotations + import math -from typing import Callable, Optional +from typing import TYPE_CHECKING, Callable, NamedTuple, Optional import torch import torch.nn.functional as F +from sglang.srt.custom_op import CustomOp from sglang.srt.eplb import expert_location_dispatch from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder from sglang.srt.eplb.expert_location_dispatch import ( @@ -52,6 +55,168 @@ if _use_aiter: except ImportError: raise ImportError("aiter is required when SGLANG_USE_AITER is set to True") +if _is_npu: + import torch_npu + + +class TopKOutput(NamedTuple): + topk_weights: torch.Tensor + topk_ids: torch.Tensor + router_logits: torch.Tensor + + +class TopK(CustomOp): + + # TODO(ch-wan): support triton_kernels + + def __init__( + self, + top_k: int, + *, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + renormalize: bool = True, + num_fused_shared_experts: int = 0, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + correction_bias: Optional[torch.Tensor] = None, + routed_scaling_factor: Optional[float] = None, + ): + # NOTE: scoring_func is not used for now, but we keep it for future use + # see https://github.com/sgl-project/sglang/pull/4505 for more details + super().__init__() + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + self.top_k = top_k + self.use_grouped_topk = use_grouped_topk + self.renormalize = renormalize + self.topk_group = topk_group + self.num_expert_group = num_expert_group + self.num_fused_shared_experts = num_fused_shared_experts + self.custom_routing_function = custom_routing_function + self.correction_bias = correction_bias + self.routed_scaling_factor = routed_scaling_factor + + def forward_native( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + *, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + ) -> TopKOutput: + torch_native = True + return select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + use_grouped_topk=self.use_grouped_topk, + renormalize=self.renormalize, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + num_fused_shared_experts=self.num_fused_shared_experts, + custom_routing_function=self.custom_routing_function, + correction_bias=self.correction_bias, + torch_native=torch_native, + routed_scaling_factor=self.routed_scaling_factor, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + ) + + def forward_cuda( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + *, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + ) -> TopKOutput: + torch_native = False + return select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + use_grouped_topk=self.use_grouped_topk, + renormalize=self.renormalize, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + num_fused_shared_experts=self.num_fused_shared_experts, + custom_routing_function=self.custom_routing_function, + correction_bias=self.correction_bias, + torch_native=torch_native, + routed_scaling_factor=self.routed_scaling_factor, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + ) + + def forward_cpu( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + *, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + ) -> TopKOutput: + return select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + use_grouped_topk=self.use_grouped_topk, + renormalize=self.renormalize, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + num_fused_shared_experts=self.num_fused_shared_experts, + custom_routing_function=self.custom_routing_function, + correction_bias=self.correction_bias, + routed_scaling_factor=self.routed_scaling_factor, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + ) + + def forward_npu( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + *, + num_token_non_padded: Optional[torch.Tensor] = None, + expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + ) -> TopKOutput: + global_num_experts = router_logits.shape[-1] + + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern + if global_num_experts == 256: + return torch_npu.npu_moe_gating_top_k( + router_logits, + k=self.top_k, + bias=self.correction_bias, + k_group=self.topk_group, + group_count=self.num_expert_group, + group_select_mode=1, + renorm=0, + norm_type=1, + routed_scaling_factor=1, + eps=float(1e-20), + ) + else: + torch_native = True + return select_experts( + hidden_states=hidden_states, + router_logits=router_logits, + top_k=self.top_k, + use_grouped_topk=self.use_grouped_topk, + renormalize=self.renormalize, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + num_fused_shared_experts=self.num_fused_shared_experts, + custom_routing_function=self.custom_routing_function, + correction_bias=self.correction_bias, + torch_native=torch_native, + routed_scaling_factor=self.routed_scaling_factor, + num_token_non_padded=num_token_non_padded, + expert_location_dispatch_info=expert_location_dispatch_info, + ) + def fused_topk_torch_native( hidden_states: torch.Tensor, @@ -436,8 +601,9 @@ def select_experts( hidden_states: torch.Tensor, router_logits: torch.Tensor, top_k: int, - use_grouped_topk: bool, - renormalize: bool, + *, + use_grouped_topk: bool = False, + renormalize: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, num_fused_shared_experts: int = 0, @@ -447,7 +613,7 @@ def select_experts( routed_scaling_factor: Optional[float] = None, num_token_non_padded: Optional[torch.Tensor] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, -): +) -> TopKOutput: router_logits, correction_bias = ( expert_location_dispatch.transform_select_experts_inputs( router_logits=router_logits, @@ -522,4 +688,4 @@ def select_experts( get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids) - return topk_weights, topk_ids + return TopKOutput(topk_weights, topk_ids, router_logits) diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index d51186465..496cbc8f5 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -1,7 +1,9 @@ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py +from __future__ import annotations + import builtins import inspect -from typing import Callable, Dict, Optional, Type, Union +from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union import torch @@ -65,6 +67,9 @@ from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + # Base quantization methods that don't depend on vllm BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "fp8": Fp8Config, @@ -186,15 +191,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"): self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -208,20 +206,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"): "self": self, "layer": layer, "x": x, - "router_logits": router_logits, - "top_k": top_k, - "renormalize": renormalize, - "use_grouped_topk": use_grouped_topk, - "topk_group": topk_group, - "num_expert_group": num_expert_group, - "custom_routing_function": custom_routing_function, + "topk_output": topk_output, } - if correction_bias is not None: - if not has_correction_bias: - raise ValueError( - "Please increase the version of your vllm. Try `pip install vllm==0.9.0.1`" - ) - kwargs["e_score_correction_bias"] = correction_bias return original_apply(**kwargs) setattr(class_obj, "apply", new_apply) diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py index c20beb2ff..0f66b954c 100644 --- a/python/sglang/srt/layers/quantization/awq.py +++ b/python/sglang/srt/layers/quantization/awq.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging import warnings -from typing import Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import torch @@ -33,6 +33,9 @@ from sglang.srt.layers.quantization.scalar_type import scalar_types from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import replace_parameter +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + try: from vllm import _custom_ops as ops @@ -737,45 +740,19 @@ class AWQMoEMethod(FusedMoEMethodBase): self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - correction_bias: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, + topk_output: TopKOutput, + *, activation: str = "silu", - routed_scaling_factor: Optional[float] = None, + **kwargs, ) -> torch.Tensor: - # Delay the import to avoid circular dependency - from sglang.srt.layers.moe.topk import select_experts assert activation == "silu", "Only SiLU activation is supported." - assert ( - scoring_func == "softmax" - ), "Only softmax score func is supported for now." # The input must currently be float16 orig_dtype = x.dtype x = x.half() - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) + topk_weights, topk_ids, router_logits = topk_output return fused_marlin_moe( x, diff --git a/python/sglang/srt/layers/quantization/base_config.py b/python/sglang/srt/layers/quantization/base_config.py index 607151671..bf24c3701 100644 --- a/python/sglang/srt/layers/quantization/base_config.py +++ b/python/sglang/srt/layers/quantization/base_config.py @@ -1,12 +1,16 @@ # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py +from __future__ import annotations import inspect from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type import torch from torch import nn +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + class QuantizeMethodBase(ABC): """Base class for different quantized methods.""" @@ -88,19 +92,22 @@ class FusedMoEMethodBase(QuantizeMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ): - raise NotImplementedError() + raise NotImplementedError @abstractmethod def apply( self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, + topk_output: TopKOutput, + *, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + inplace: bool = True, + no_combine: bool = False, + routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: - raise NotImplementedError() + raise NotImplementedError class QuantizationConfig(ABC): diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py index a1da999b3..62dc45ad9 100644 --- a/python/sglang/srt/layers/quantization/blockwise_int8.py +++ b/python/sglang/srt/layers/quantization/blockwise_int8.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import torch from torch.nn import Module @@ -21,6 +21,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.utils import set_weight_attrs +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + ACTIVATION_SCHEMES = ["static", "dynamic"] logger = logging.getLogger(__name__) @@ -344,15 +347,8 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -360,30 +356,13 @@ class BlockInt8MoEMethod(FusedMoEMethodBase): routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - from sglang.srt.layers.moe.topk import select_experts - - # Expert selection - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) # Expert fusion with INT8 quantization return fused_experts( x, layer.w13_weight, layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index b471184d2..39e5f9e25 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1,15 +1,17 @@ # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations import enum import logging from enum import Enum -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, List, Optional import torch from compressed_tensors import CompressionFormat from compressed_tensors.quantization import QuantizationStrategy +from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.layers.quantization.utils import ( @@ -20,6 +22,12 @@ from sglang.srt.layers.quantization.utils import ( ) from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( + CompressedTensorsConfig, + ) + _is_cuda = is_cuda() _is_npu = is_npu() _is_cpu_amx_available = cpu_has_amx_support() @@ -51,7 +59,7 @@ __all__ = [ ] -class CompressedTensorsMoEMethod: +class CompressedTensorsMoEMethod(FusedMoEMethodBase): def __new__(cls, *args, **kwargs): if cls is CompressedTensorsMoEMethod: return super().__new__(cls) @@ -59,7 +67,7 @@ class CompressedTensorsMoEMethod: @staticmethod def get_moe_method( - quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501 + quant_config: CompressedTensorsConfig, ) -> "CompressedTensorsMoEMethod": # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. @@ -82,9 +90,7 @@ class CompressedTensorsMoEMethod: class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): - def __init__( - self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 - ): + def __init__(self, quant_config: CompressedTensorsConfig): self.quant_config = quant_config self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights") self.input_quant = self.quant_config.target_scheme_map["Linear"].get( @@ -270,47 +276,21 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", + apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, - apply_router_weight_on_input: bool = False, routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton import fused_experts - from sglang.srt.layers.moe.topk import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) return fused_experts( x, layer.w13_weight, layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace, activation=activation, use_fp8_w8a8=True, @@ -327,9 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): - def __init__( - self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501 - ): + def __init__(self, quant_config: CompressedTensorsConfig): self.quant_config = quant_config # TODO: @dsikka: refactor this to use schemes as other kernels # are supported + check if the layer is being ignored. @@ -628,43 +606,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod): self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", - routed_scaling_factor: Optional[float] = None, + **kwargs, ) -> torch.Tensor: - from sglang.srt.layers.moe.topk import select_experts assert activation == "silu", "Only SiLU activation is supported." - if expert_map is not None: - raise NotImplementedError( - "Expert Parallelism is not supported for " "fused Marlin MoE method." - ) - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) + topk_weights, topk_ids, router_logits = topk_output return torch.ops.vllm.fused_marlin_moe( x, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index 7275ea430..23daa5d26 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import torch import torch.nn.functional as F @@ -78,6 +78,7 @@ from sglang.srt.utils import ( ) if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config _is_hip = is_hip() @@ -971,15 +972,8 @@ class Fp8MoEMethod(FusedMoEMethodBase): self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -987,26 +981,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - from sglang.srt.layers.moe.topk import select_experts - - # Expert selection - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) if use_intel_amx_backend(layer): from sglang.srt.layers.moe.topk import apply_topk_weights_cpu + topk_weights, topk_ids, _ = topk_output x, topk_weights = apply_topk_weights_cpu( apply_router_weight_on_input, topk_weights, x ) @@ -1032,8 +1011,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ret = self.maybe_apply_hip_fused_experts( layer, x, - topk_weights, - topk_ids, + topk_output, activation, no_combine, ) @@ -1048,6 +1026,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): ): from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8 + topk_weights, topk_ids, _ = topk_output return cutlass_fused_experts_fp8( x, layer.w13_weight.transpose(1, 2), @@ -1076,8 +1055,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): x, layer.w13_weight, layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace and not no_combine, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, @@ -1101,11 +1079,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): self, layer: torch.nn.Module, x: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, + topk_output: TopKOutput, activation: str = "silu", no_combine: bool = False, ) -> Optional[torch.Tensor]: + topk_weights, topk_ids, _ = topk_output if _use_hip_int4: # TODO: add triton kernel and add check _use_aiter assert not no_combine, f"{no_combine=} is not supported." @@ -1397,14 +1375,8 @@ class Fp8EPMoEMethod(Fp8MoEMethod): def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + hidden_states: torch.Tensor, + topk_output: TopKOutput, ) -> torch.Tensor: raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py index af56c3be7..4f2eba4e3 100644 --- a/python/sglang/srt/layers/quantization/gptq.py +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -3,7 +3,7 @@ from __future__ import annotations import logging from dataclasses import dataclass from fractions import Fraction -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union import torch @@ -43,6 +43,9 @@ from sglang.srt.layers.quantization.utils import ( unpack_cols, ) +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + try: from vllm import _custom_ops as ops except ImportError: @@ -1057,42 +1060,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - e_score_correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", + **kwargs, ) -> torch.Tensor: # Delay the import to avoid circular dependency - from sglang.srt.layers.moe.topk import select_experts assert activation == "silu", "Only SiLU activation is supported." - assert ( - scoring_func == "softmax" - ), "Only softmax score func is supported for now." # The input must currently be float16 orig_dtype = x.dtype x = x.half() - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - correction_bias=e_score_correction_bias, - ) + topk_weights, topk_ids, router_logits = topk_output return fused_marlin_moe( x, diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index 5263f3b92..73de5b0d1 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging -from typing import Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import torch from torch.nn.parameter import Parameter @@ -31,6 +31,9 @@ from sglang.srt.layers.quantization.utils import ( from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.utils import is_cuda, next_power_of_2 +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + if is_cuda(): from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant @@ -402,15 +405,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -418,29 +414,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase): routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - from sglang.srt.layers.moe.topk import select_experts - - # Expert selection - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) return fused_experts( x, layer.w13_weight, layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace, activation=activation, use_fp8_w8a8=True, @@ -961,15 +940,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -982,21 +954,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ) -> torch.Tensor: assert activation == "silu", "Only SiLU activation is supported." - from sglang.srt.layers.moe.topk import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) if self.enable_flashinfer_moe: assert ( @@ -1004,6 +961,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ), "apply_router_weight_on_input is not supported for Flashinfer" # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision # and fp4 quantized weights loaded from the checkpoint + topk_weights, topk_ids, _ = topk_output output = flashinfer_cutlass_fused_moe( x, topk_ids.to(torch.int), @@ -1029,6 +987,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4 + topk_weights, topk_ids, _ = topk_output return cutlass_moe_fp4( a=x, a1_gscale=layer.w13_input_scale_quant, diff --git a/python/sglang/srt/layers/quantization/moe_wna16.py b/python/sglang/srt/layers/quantization/moe_wna16.py index f83b9bb1f..fbbf11066 100644 --- a/python/sglang/srt/layers/quantization/moe_wna16.py +++ b/python/sglang/srt/layers/quantization/moe_wna16.py @@ -2,8 +2,9 @@ from __future__ import annotations import logging -from typing import Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional +import numpy as np import torch from sglang.srt.distributed import get_tensor_model_parallel_rank @@ -20,6 +21,9 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + def get_weight_perm(num_bits: int): perm_list: List[int] = [] @@ -348,15 +352,8 @@ class MoeWNA16Method(FusedMoEMethodBase): self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -365,22 +362,8 @@ class MoeWNA16Method(FusedMoEMethodBase): ) -> torch.Tensor: # avoid circular import from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - from sglang.srt.layers.moe.topk import select_experts assert activation == "silu", "Only SiLU activation is supported." - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) weight_bits = self.quant_config.weight_bits has_zp = self.quant_config.has_zp @@ -389,8 +372,7 @@ class MoeWNA16Method(FusedMoEMethodBase): x, layer.w13_qweight, layer.w2_qweight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace, apply_router_weight_on_input=apply_router_weight_on_input, use_int4_w4a16=weight_bits == 4, diff --git a/python/sglang/srt/layers/quantization/unquant.py b/python/sglang/srt/layers/quantization/unquant.py index 06afcb70b..fa4cbf582 100644 --- a/python/sglang/srt/layers/quantization/unquant.py +++ b/python/sglang/srt/layers/quantization/unquant.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import importlib -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional import torch import torch.nn.functional as F @@ -21,6 +23,9 @@ from sglang.srt.utils import ( use_intel_amx_backend, ) +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None @@ -125,25 +130,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): super().__init__() self.use_triton_kernels = use_triton_kernels - from sglang.srt.layers.moe.fused_moe_native import moe_forward_native - - if torch.cuda.is_available(): - from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - - if has_triton_kernels: - from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import ( - triton_kernel_moe_forward, - ) - else: - triton_kernel_moe_forward = None - else: - fused_experts = None # type: ignore - triton_kernel_moe_forward = None - - self.moe_forward_native = moe_forward_native - self.fused_experts = fused_experts - self.triton_kernel_moe_forward = triton_kernel_moe_forward - def create_weights( self, layer: torch.nn.Module, @@ -201,34 +187,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: - return self.forward( x=x, layer=layer, - router_logits=router_logits, - top_k=top_k, - renormalize=renormalize, - use_grouped_topk=use_grouped_topk, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, + topk_output=topk_output, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, inplace=inplace, @@ -240,15 +210,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): self, layer: torch.nn.Module, x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -257,33 +220,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ) -> torch.Tensor: if self.use_triton_kernels: - return self.triton_kernel_moe_forward( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - gating_output=router_logits, - topk=top_k, - renormalize=renormalize, - ) + # TODO(ch-wan): re-enable the Triton kernel + raise NotImplementedError("The Triton kernel is temporarily disabled.") + # return triton_kernel_moe_forward( + # hidden_states=x, + # w1=layer.w13_weight, + # w2=layer.w2_weight, + # gating_output=router_logits, + # topk=top_k, + # renormalize=renormalize, + # ) else: - from sglang.srt.layers.moe.topk import select_experts - - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) - if _use_aiter: assert not no_combine, "unsupported" + topk_weights, topk_ids, _ = topk_output if apply_router_weight_on_input: assert ( topk_weights.dim() == 2 @@ -296,7 +246,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): topk_weights = torch.ones_like( topk_weights, dtype=torch.float32 ) # topk_weights must be FP32 (float32) - return fused_moe( x, layer.w13_weight, @@ -310,12 +259,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ), ) else: - return self.fused_experts( + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + fused_experts, + ) + + return fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace and not no_combine, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, @@ -327,15 +279,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): self, layer: torch.nn.Module, x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -344,30 +289,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ) -> torch.Tensor: assert activation == "silu", f"activation = {activation} is not supported." - if use_intel_amx_backend(layer): + if use_intel_amx_backend(layer) and not apply_router_weight_on_input: + from sglang.srt.layers.moe.topk import apply_topk_weights_cpu - from sglang.srt.layers.moe.topk import ( - apply_topk_weights_cpu, - select_experts, - ) - - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) + topk_weights, topk_ids, _ = topk_output x, topk_weights = apply_topk_weights_cpu( apply_router_weight_on_input, topk_weights, x ) - return torch.ops.sgl_kernel.fused_experts_cpu( x, layer.w13_weight, @@ -385,61 +313,42 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): True, # is_vnni ) else: - return self.moe_forward_native( + from sglang.srt.layers.moe.fused_moe_native import moe_forward_native + + return moe_forward_native( layer, x, - use_grouped_topk, - top_k, - router_logits, - renormalize, - topk_group, - num_expert_group, - num_fused_shared_experts, - custom_routing_function, - correction_bias, - activation, - apply_router_weight_on_input, - inplace, - no_combine, - routed_scaling_factor, + topk_output, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + inplace=inplace, + no_combine=no_combine, + routed_scaling_factor=routed_scaling_factor, ) def forward_npu( self, layer: torch.nn.Module, x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: - return self.moe_forward_native( + from sglang.srt.layers.moe.fused_moe_native import moe_forward_native + + return moe_forward_native( layer, x, - use_grouped_topk, - top_k, - router_logits, - renormalize, - topk_group, - num_expert_group, - num_fused_shared_experts, - custom_routing_function, - correction_bias, - activation, - apply_router_weight_on_input, - inplace, - no_combine, - routed_scaling_factor, + topk_output, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + inplace=inplace, + no_combine=no_combine, + routed_scaling_factor=routed_scaling_factor, ) def forward_tpu(self, *args, **kwargs) -> torch.Tensor: @@ -508,13 +417,7 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp): def apply( self, layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, + hidden_states: torch.Tensor, + topk_output: TopKOutput, ) -> torch.Tensor: raise NotImplementedError diff --git a/python/sglang/srt/layers/quantization/w8a8_fp8.py b/python/sglang/srt/layers/quantization/w8a8_fp8.py index 871a4534c..e486fef0b 100644 --- a/python/sglang/srt/layers/quantization/w8a8_fp8.py +++ b/python/sglang/srt/layers/quantization/w8a8_fp8.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter @@ -25,6 +25,9 @@ from sglang.srt.layers.quantization.fp8_utils import ( ) from sglang.srt.utils import set_weight_attrs +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + _is_fp8_fnuz = is_fp8_fnuz() @@ -266,45 +269,23 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase): self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", + apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - from sglang.srt.layers.moe.topk import select_experts - - # Expert selection - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) return fused_experts( x, layer.w13_weight, layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace, + apply_router_weight_on_input=apply_router_weight_on_input, activation=activation, use_fp8_w8a8=True, per_channel_quant=True, diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py index 19cf49c9b..22e8b108f 100644 --- a/python/sglang/srt/layers/quantization/w8a8_int8.py +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -3,7 +3,7 @@ from __future__ import annotations import importlib import sys from types import MappingProxyType -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast import torch from torch.nn.parameter import Parameter @@ -37,6 +37,9 @@ from sglang.srt.utils import ( use_intel_amx_backend, ) +if TYPE_CHECKING: + from sglang.srt.layers.moe.topk import TopKOutput + _is_cuda = is_cuda() _is_cpu_amx_available = cpu_has_amx_support() _is_cpu = is_cpu() @@ -239,7 +242,7 @@ class W8A8Int8Config(QuantizationConfig): layer: torch.nn.Module, prefix: str, ) -> Optional[QuantizeMethodBase]: - from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if _is_npu: @@ -469,15 +472,8 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): self, layer: torch.nn.Module, x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - num_fused_shared_experts: int = 0, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, + topk_output: TopKOutput, + *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, @@ -485,26 +481,11 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts - from sglang.srt.layers.moe.topk import select_experts - - # Expert selection - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - routed_scaling_factor=routed_scaling_factor, - ) if use_intel_amx_backend(layer): from sglang.srt.layers.moe.topk import apply_topk_weights_cpu + topk_weights, topk_ids, _ = topk_output x, topk_weights = apply_topk_weights_cpu( apply_router_weight_on_input, topk_weights, x ) @@ -529,8 +510,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase): x, layer.w13_weight, layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, + topk_output=topk_output, inplace=inplace, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, @@ -907,7 +887,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): layer: torch.nn.Module, num_experts: int, hidden_size: int, - intermediate_size: List[int], + intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: @@ -984,52 +964,11 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): self, layer, x, - router_logits, - top_k, - renormalize, - use_grouped_topk, - topk_group, - num_expert_group, - num_fused_shared_experts, - custom_routing_function, - correction_bias, - activation, - apply_router_weight_on_input, - routed_scaling_factor, + topk_output: TopKOutput, **kwargs, ) -> torch.Tensor: - from sglang.srt.layers.moe.topk import select_experts - global_num_experts = router_logits.shape[-1] - # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if global_num_experts == 256: - topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( - router_logits, - k=top_k, - bias=correction_bias, - k_group=topk_group, - group_count=num_expert_group, - group_select_mode=1, - renorm=0, - norm_type=1, - routed_scaling_factor=1, - eps=float(1e-20), - ) - else: - topk_weights, topk_ids = select_experts( - hidden_states=x, - router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - num_fused_shared_experts=num_fused_shared_experts, - custom_routing_function=custom_routing_function, - correction_bias=correction_bias, - torch_native=True, - routed_scaling_factor=routed_scaling_factor, - ) + topk_weights, topk_ids, _ = topk_output topk_ids = topk_ids.to(torch.int32) topk_weights = topk_weights.to(x.dtype) return npu_fused_experts( @@ -1040,5 +979,5 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase): w2_scale=layer.w2_weight_scale, topk_weights=topk_weights, topk_ids=topk_ids, - top_k=top_k, + top_k=topk_ids.shape[1], ) diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index 95bfe001a..f2f0d0344 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -37,6 +37,7 @@ from sglang.srt.layers.linear import ( ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton import fused_moe +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -109,7 +110,10 @@ class DeepseekMoE(nn.Module): f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {self.n_routed_experts}." ) - + self.topk = TopK( + top_k=self.top_k, + renormalize=config.norm_topk_prob, + ) self.experts = nn.ModuleList( [ DeepseekMLP( @@ -170,13 +174,12 @@ class DeepseekMoE(nn.Module): shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) + topk_output = self.topk(hidden_states, router_logits) final_hidden_states = fused_moe.fused_moe( hidden_states, - self.w1, - self.w2, - router_logits, - self.top_k, - renormalize=self.config.norm_topk_prob, + w1=self.w1, + w2=self.w2, + topk_output=topk_output, inplace=True, ) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 0da956b01..9ec5db926 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -58,7 +58,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization import deep_gemm_wrapper from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8_kernel import ( @@ -303,6 +303,17 @@ class DeepseekV2MoE(nn.Module): config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn ) + self.topk = TopK( + top_k=config.num_experts_per_tok + self.num_fused_shared_experts, + renormalize=config.norm_topk_prob, + use_grouped_topk=True, + num_expert_group=config.n_group, + num_fused_shared_experts=self.num_fused_shared_experts, + topk_group=config.topk_group, + correction_bias=self.gate.e_score_correction_bias, + routed_scaling_factor=self.routed_scaling_factor, + ) + self.experts = get_moe_impl_class()( num_experts=config.n_routed_experts + self.num_fused_shared_experts @@ -311,13 +322,7 @@ class DeepseekV2MoE(nn.Module): hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, layer_id=self.layer_id, - renormalize=config.norm_topk_prob, quant_config=quant_config, - use_grouped_topk=True, - num_expert_group=config.n_group, - num_fused_shared_experts=self.num_fused_shared_experts, - topk_group=config.topk_group, - correction_bias=self.gate.e_score_correction_bias, routed_scaling_factor=self.routed_scaling_factor, prefix=add_prefix("experts", prefix), **( @@ -451,8 +456,9 @@ class DeepseekV2MoE(nn.Module): with torch.cuda.stream(self.alt_stream): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) + topk_output = self.topk(hidden_states, router_logits) final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits + hidden_states=hidden_states, topk_output=topk_output ) if not _is_cuda: final_hidden_states *= self.routed_scaling_factor @@ -473,8 +479,9 @@ class DeepseekV2MoE(nn.Module): shared_output = self._forward_shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) + topk_output = self.topk(hidden_states, router_logits) final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits + hidden_states=hidden_states, topk_output=topk_output ) if not _is_cuda and not _use_aiter: # fused in biased_grouped_topk so we can skip here @@ -490,8 +497,9 @@ class DeepseekV2MoE(nn.Module): ) -> torch.Tensor: # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) + topk_output = self.topk(hidden_states, router_logits) fused_experts_out = self.experts( - hidden_states=hidden_states, router_logits=router_logits + hidden_states=hidden_states, topk_output=topk_output ) assert use_intel_amx_backend( @@ -549,17 +557,9 @@ class DeepseekV2MoE(nn.Module): # router_logits: (num_tokens, n_experts) router_logits = self.gate(hidden_states) shared_output = self._forward_shared_experts(hidden_states) - topk_weights, topk_idx = select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=True, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - correction_bias=self.correction_bias, - routed_scaling_factor=self.routed_scaling_factor, + topk_weights, topk_idx, _ = self.topk( + hidden_states, + router_logits, num_token_non_padded=forward_batch.num_token_non_padded, expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( layer_id=self.layer_id, @@ -649,17 +649,9 @@ class DeepseekV2MoE(nn.Module): with get_global_expert_distribution_recorder().with_current_layer( self.layer_id ): - state.topk_weights_local, state.topk_idx_local = select_experts( + state.topk_weights_local, state.topk_idx_local, _ = self.topk( hidden_states=hidden_states, router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=True, - renormalize=self.renormalize, - topk_group=self.topk_group, - num_expert_group=self.num_expert_group, - num_fused_shared_experts=self.num_fused_shared_experts, - correction_bias=self.correction_bias, - routed_scaling_factor=self.routed_scaling_factor, num_token_non_padded=state.forward_batch.num_token_non_padded, expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( layer_id=self.layer_id, diff --git a/python/sglang/srt/models/granitemoe.py b/python/sglang/srt/models/granitemoe.py index b4a9c17af..1e6109209 100644 --- a/python/sglang/srt/models/granitemoe.py +++ b/python/sglang/srt/models/granitemoe.py @@ -15,6 +15,7 @@ from sglang.srt.layers.linear import ( ) from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention @@ -60,6 +61,11 @@ class GraniteMoeMoE(nn.Module): prefix=f"{prefix}.gate", ) + self.topk = TopK( + top_k=top_k, + renormalize=True, + ) + self.experts = FusedMoE( num_experts=num_experts, top_k=top_k, @@ -67,7 +73,6 @@ class GraniteMoeMoE(nn.Module): intermediate_size=intermediate_size, params_dtype=params_dtype, reduce_results=True, - renormalize=True, quant_config=quant_config, tp_size=tp_size, prefix=f"{prefix}.experts", @@ -78,7 +83,8 @@ class GraniteMoeMoE(nn.Module): orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states, router_logits) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) return final_hidden_states.view(orig_shape) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index a8cde8e09..4a46bf197 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -45,6 +45,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.moe.router import fused_moe_router_shim +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -108,6 +109,12 @@ class Grok1MoE(nn.Module): fused_moe_router_shim, self.router_logit_softcapping ) + self.topk = TopK( + top_k=top_k, + renormalize=False, + custom_routing_function=custom_routing_function, + ) + kwargs = {} if global_server_args_dict["enable_ep_moe"]: MoEImpl = EPMoE @@ -124,17 +131,16 @@ class Grok1MoE(nn.Module): hidden_size=hidden_size, intermediate_size=intermediate_size, params_dtype=params_dtype, - renormalize=False, quant_config=quant_config, tp_size=tp_size, - custom_routing_function=custom_routing_function, activation="gelu", **kwargs, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # need to assert self.gate.quant_method is unquantized - return self.experts(hidden_states, self.gate.weight) + topk_output = self.topk(hidden_states, self.gate.weight) + return self.experts(hidden_states, topk_output) class Grok1Attention(nn.Module): diff --git a/python/sglang/srt/models/hunyuan.py b/python/sglang/srt/models/hunyuan.py index f23ccc0a8..58e95bbb1 100644 --- a/python/sglang/srt/models/hunyuan.py +++ b/python/sglang/srt/models/hunyuan.py @@ -40,6 +40,7 @@ from sglang.srt.layers.linear import ( ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -152,13 +153,16 @@ class HunYuanSparseMoeBlock(nn.Module): else config.moe_intermediate_size[layer_id] ) + self.topk = TopK( + top_k=top_k, + renormalize=True if top_k > 1 else False, + ) + self.experts = FusedMoE( num_experts=config.num_experts, - top_k=top_k, hidden_size=config.hidden_size, intermediate_size=intermediate_size, reduce_results=False, - renormalize=True if top_k > 1 else False, quant_config=quant_config, ) @@ -195,9 +199,8 @@ class HunYuanSparseMoeBlock(nn.Module): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output if self.tp_size > 1: diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py index 1bb6fcc12..cf0b20800 100644 --- a/python/sglang/srt/models/llama4.py +++ b/python/sglang/srt/models/llama4.py @@ -40,6 +40,7 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -103,14 +104,17 @@ class Llama4MoE(nn.Module): prefix=add_prefix("router", prefix), ) + self.topk = TopK( + top_k=self.top_k, + renormalize=False, + custom_routing_function=Llama4MoE.custom_routing_function, + ) + self.experts = FusedMoE( num_experts=config.num_local_experts, - top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, - custom_routing_function=Llama4MoE.custom_routing_function, intermediate_size=intermediate_size_moe, reduce_results=False, - renormalize=False, quant_config=quant_config, apply_router_weight_on_input=True, prefix=add_prefix("experts", prefix), @@ -147,10 +151,8 @@ class Llama4MoE(nn.Module): # router_scores: [num_tokens, num_experts] router_logits, _ = self.router(hidden_states) shared_out = self.shared_expert(hidden_states) - routed_out = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - ) + topk_output = self.topk(hidden_states, router_logits) + routed_out = self.experts(hidden_states, topk_output) return shared_out, routed_out def _forward_core_shared_routed_overlap(self, hidden_states): @@ -163,10 +165,8 @@ class Llama4MoE(nn.Module): with self.device_module.stream(alt_stream): # router_scores: [num_tokens, num_experts] router_logits, _ = self.router(hidden_states) - routed_out = self.experts( - hidden_states=hidden_states, - router_logits=router_logits, - ) + topk_output = self.topk(hidden_states, router_logits) + routed_out = self.experts(hidden_states, topk_output) self.device_module.current_stream().wait_stream(alt_stream) return shared_out, routed_out diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 90a12f12f..b09fc2f24 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -37,6 +37,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.ep_moe.layer import EPMoE from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -86,6 +87,12 @@ class MixtralMoE(nn.Module): quant_config=None, prefix=add_prefix("gate", prefix), ) + + self.topk = TopK( + top_k=top_k, + renormalize=True, + ) + MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE self.experts = MoEImpl( num_experts=num_experts, @@ -93,7 +100,6 @@ class MixtralMoE(nn.Module): hidden_size=hidden_size, intermediate_size=intermediate_size, params_dtype=params_dtype, - renormalize=True, quant_config=quant_config, tp_size=tp_size, prefix=add_prefix("experts", prefix), @@ -105,7 +111,8 @@ class MixtralMoE(nn.Module): hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states, router_logits) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) return final_hidden_states.view(orig_shape) diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index 612120fe9..ce53f2b01 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -32,6 +32,7 @@ from sglang.srt.layers.linear import ( ) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -76,13 +77,16 @@ class OlmoeMoE(nn.Module): prefix=add_prefix("gate", prefix), ) + self.topk = TopK( + top_k=top_k, + renormalize=False, + ) + self.experts = FusedMoE( num_experts=num_experts, - top_k=top_k, hidden_size=hidden_size, intermediate_size=intermediate_size, reduce_results=True, - renormalize=False, quant_config=quant_config, tp_size=tp_size, prefix=add_prefix("experts", prefix), @@ -94,9 +98,8 @@ class OlmoeMoE(nn.Module): hidden_states = hidden_states.view(-1, self.hidden_size) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) return final_hidden_states.view(orig_shape) diff --git a/python/sglang/srt/models/phimoe.py b/python/sglang/srt/models/phimoe.py index 22ee023c8..865b94f51 100644 --- a/python/sglang/srt/models/phimoe.py +++ b/python/sglang/srt/models/phimoe.py @@ -13,6 +13,7 @@ from sglang.srt.layers.linear import ( ) from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention @@ -200,15 +201,19 @@ class PhiMoE(nn.Module): quant_config=None, ) + self.topk = TopK( + top_k=top_k, + renormalize=False, + custom_routing_function=phimoe_routing_function, + ) + self.experts = FusedMoE( num_experts=num_experts, top_k=top_k, hidden_size=hidden_size, intermediate_size=intermediate_size, reduce_results=True, - renormalize=False, quant_config=quant_config, - custom_routing_function=phimoe_routing_function, prefix=add_prefix("experts", prefix), ) @@ -219,7 +224,8 @@ class PhiMoE(nn.Module): orig_shape = hidden_states.shape hidden_states = hidden_states.view(-1, self.hidden_size) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts(hidden_states, router_logits) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) return final_hidden_states.view(orig_shape) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index fe2636ab7..e033424cf 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -61,6 +61,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -134,13 +135,17 @@ class Qwen2MoeSparseMoeBlock(nn.Module): f"the number of experts {config.num_experts}." ) + self.topk = TopK( + top_k=config.num_experts_per_tok, + renormalize=config.norm_topk_prob, + ) + self.experts = get_moe_impl_class()( layer_id=self.layer_id, - num_experts=config.num_experts, top_k=config.num_experts_per_tok, + num_experts=config.num_experts, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - renormalize=config.norm_topk_prob, quant_config=quant_config, prefix=add_prefix("experts", prefix), # Additional args for FusedMoE @@ -189,9 +194,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) if shared_output is not None: final_hidden_states = final_hidden_states + shared_output final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py index 75d3b475c..c75a38499 100644 --- a/python/sglang/srt/models/qwen3_moe.py +++ b/python/sglang/srt/models/qwen3_moe.py @@ -56,8 +56,7 @@ from sglang.srt.layers.linear import ( from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher -from sglang.srt.layers.moe.fused_moe_triton import FusedMoE -from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.moe.topk import TopK from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope @@ -102,6 +101,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module): f"the number of experts {config.num_experts}." ) + self.topk = TopK( + top_k=config.num_experts_per_tok, + renormalize=config.norm_topk_prob, + use_grouped_topk=False, + ) + self.experts = get_moe_impl_class()( num_experts=config.num_experts + global_server_args_dict["ep_num_redundant_experts"], @@ -109,7 +114,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module): layer_id=layer_id, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, - renormalize=config.norm_topk_prob, quant_config=quant_config, prefix=add_prefix("experts", prefix), **( @@ -143,7 +147,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module): config.num_experts + global_server_args_dict["ep_num_redundant_experts"] ) self.top_k = config.num_experts_per_tok - self.renormalize = config.norm_topk_prob self.deepep_dispatcher = MaybeTboDeepEPDispatcher( group=parallel_state.get_tp_group().device_group, @@ -180,9 +183,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, router_logits=router_logits - ) + topk_output = self.topk(hidden_states, router_logits) + final_hidden_states = self.experts(hidden_states, topk_output) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) @@ -195,13 +197,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module): if is_non_idle_and_non_empty(forward_mode, hidden_states): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - - topk_weights, topk_idx = select_experts( - hidden_states=hidden_states, - router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=False, - renormalize=self.renormalize, + topk_weights, topk_idx, _ = self.topk( + hidden_states, + router_logits, num_token_non_padded=forward_batch.num_token_non_padded, expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( layer_id=self.layer_id, @@ -267,12 +265,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module): with get_global_expert_distribution_recorder().with_current_layer( self.layer_id ): - state.topk_weights_local, state.topk_idx_local = select_experts( + state.topk_weights_local, state.topk_idx_local, _ = self.topk( hidden_states=hidden_states, router_logits=router_logits, - top_k=self.top_k, - use_grouped_topk=False, - renormalize=self.renormalize, num_token_non_padded=state.forward_batch.num_token_non_padded, expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( layer_id=self.layer_id, diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py index a5a338632..fd2c95608 100644 --- a/python/sglang/test/test_block_fp8.py +++ b/python/sglang/test/test_block_fp8.py @@ -6,6 +6,7 @@ import torch from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.fp8_kernel import ( per_tensor_quant_mla_fp8, per_token_group_quant_fp8, @@ -497,13 +498,17 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase): score = torch.randn((M, E), dtype=dtype) with torch.inference_mode(): + topk_output = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + renormalize=False, + ) out = fused_moe( a, w1, w2, - score, - topk, - renormalize=False, + topk_output, use_fp8_w8a8=True, w1_scale=w1_s, w2_scale=w2_s, diff --git a/python/sglang/test/test_block_fp8_ep.py b/python/sglang/test/test_block_fp8_ep.py index bd735edbd..2f92c5435 100644 --- a/python/sglang/test/test_block_fp8_ep.py +++ b/python/sglang/test/test_block_fp8_ep.py @@ -40,7 +40,7 @@ def ep_moe( block_shape: Optional[List[int]] = None, ): use_blockwise_fp8 = block_shape is not None - topk_weights, topk_ids = select_experts( + topk_weights, topk_ids, _ = select_experts( hidden_states=hidden_states, router_logits=router_logits, top_k=top_k, diff --git a/python/sglang/test/test_cutlass_w4a8_moe.py b/python/sglang/test/test_cutlass_w4a8_moe.py index acf8a27b9..c823bf1f7 100644 --- a/python/sglang/test/test_cutlass_w4a8_moe.py +++ b/python/sglang/test/test_cutlass_w4a8_moe.py @@ -100,12 +100,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype): s_strides2 = c_strides2 score = torch.randn((M, E), dtype=dtype, device=device) - topk_weights, topk_ids = select_experts( + topk_weights, topk_ids, _ = select_experts( hidden_states=a, router_logits=score, top_k=topk, - use_grouped_topk=False, - renormalize=False, ) expert_map = torch.arange(E, dtype=torch.int32, device=device) expert_map[local_e:] = E diff --git a/python/sglang/test/test_fp4_moe.py b/python/sglang/test/test_fp4_moe.py index 7e3de278c..30b1fe9db 100644 --- a/python/sglang/test/test_fp4_moe.py +++ b/python/sglang/test/test_fp4_moe.py @@ -159,12 +159,10 @@ def test_cutlass_fp4_moe_no_graph( score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = select_experts( + topk_weights, topk_ids, _ = select_experts( hidden_states=a, router_logits=score, top_k=topk, - use_grouped_topk=False, - renormalize=False, ) a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32) diff --git a/test/srt/test_block_int8.py b/test/srt/test_block_int8.py index 2b8b841f0..58bd7c1e1 100644 --- a/test/srt/test_block_int8.py +++ b/test/srt/test_block_int8.py @@ -5,6 +5,7 @@ import torch from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.topk import select_experts from sglang.test.test_utils import CustomTestCase @@ -171,14 +172,18 @@ class TestW8A8BlockINT8FusedMoE(CustomTestCase): score = torch.randn((M, E), dtype=dtype) + topk_output = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + ) + with torch.inference_mode(): out = fused_moe( a, w1, w2, - score, - topk, - renormalize=False, + topk_output, use_int8_w8a8=True, w1_scale=w1_s, w2_scale=w2_s, diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py index d1c2735d1..1a0452c41 100644 --- a/test/srt/test_fused_moe.py +++ b/test/srt/test_fused_moe.py @@ -6,6 +6,7 @@ from tqdm import tqdm from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.utils import is_hip @@ -132,13 +133,17 @@ class TestFusedMOE(CustomTestCase): input_scale=a2_scale, ) + topk_output = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + ) + sglang_output = fused_moe( a, w1, w2, - score, - topk, - renormalize=False, + topk_output, use_fp8_w8a8=True, w1_scale=w1_scale, w2_scale=w2_scale, @@ -166,7 +171,13 @@ class TestFusedMOE(CustomTestCase): w2 = self.create_random_cuda_tensor((e, k, n), dtype) score = self.create_random_cuda_tensor((m, e), dtype) - triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) + topk_output = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + ) + + triton_output = fused_moe(a, w1, w2, topk_output) torch_output = self.torch_naive_moe(a, w1, w2, score, topk) torch.testing.assert_close( triton_output, torch_output, rtol=rtol, atol=atol diff --git a/test/srt/test_int8_kernel.py b/test/srt/test_int8_kernel.py index 3e9f7a7dd..bbadce230 100644 --- a/test/srt/test_int8_kernel.py +++ b/test/srt/test_int8_kernel.py @@ -5,6 +5,7 @@ import torch from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 from sglang.test.test_utils import CustomTestCase @@ -114,13 +115,16 @@ class TestW8A8Int8FusedMoE(CustomTestCase): with torch.inference_mode(): ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) + topk_output = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + ) out = fused_moe( a, w1, w2, - score, - topk, - renormalize=False, + topk_output, use_fp8_w8a8=False, # Not using fp8 use_int8_w8a16=False, # Not using int8-w8a16 use_int8_w8a8=True, # Using int8-w8a8 diff --git a/test/srt/test_triton_moe_channel_fp8_kernel.py b/test/srt/test_triton_moe_channel_fp8_kernel.py index 89b5af650..577570757 100644 --- a/test/srt/test_triton_moe_channel_fp8_kernel.py +++ b/test/srt/test_triton_moe_channel_fp8_kernel.py @@ -5,6 +5,7 @@ import torch from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant from sglang.test.test_utils import CustomTestCase @@ -126,13 +127,16 @@ class TestW8A8FP8FusedMoE(CustomTestCase): with torch.inference_mode(): ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) + topk_output = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + ) out = fused_moe( a, w1, w2, - score, - topk, - renormalize=False, + topk_output, use_fp8_w8a8=True, # using fp8 use_int8_w8a16=False, use_int8_w8a8=False, diff --git a/test/srt/test_triton_moe_wna16.py b/test/srt/test_triton_moe_wna16.py index 2613586a8..51583c2f2 100644 --- a/test/srt/test_triton_moe_wna16.py +++ b/test/srt/test_triton_moe_wna16.py @@ -5,6 +5,7 @@ import torch from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.topk import select_experts NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] @@ -219,13 +220,17 @@ def test_fused_moe_wn16( if has_zp: w_qzeros[expert_id] = qzeros + topk_output = select_experts( + hidden_states=a, + router_logits=score, + top_k=topk, + ) + triton_output = fused_moe( a, w1_qweight, w2_qweight, - score, - topk, - renormalize=False, + topk_output, use_int4_w4a16=weight_bits == 4, use_int8_w8a16=weight_bits == 8, w1_scale=w1_scales,