# sglang/python/sglang/srt/layers/moe/topk.py
# Copyright 2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import annotations
import logging
import math
from dataclasses import dataclass
from enum import Enum, auto
from typing import (
    TYPE_CHECKING,
    Callable,
    NamedTuple,
    Optional,
    Protocol,
    TypeGuard,
    runtime_checkable,
)
import torch
import torch.nn.functional as F
from sglang.srt.custom_op import CustomOp
from sglang.srt.eplb import expert_location_dispatch
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location_dispatch import (
    ExpertLocationDispatchInfo,
    topk_ids_logical_to_physical,
)
from sglang.srt.layers.moe import (
    get_moe_runner_backend,
    should_use_flashinfer_trtllm_moe,
)
from sglang.srt.utils import (
    cpu_has_amx_support,
    get_bool_env_var,
    get_compiler_backend,
    is_cpu,
    is_cuda,
    is_hip,
    is_npu,
)

if TYPE_CHECKING:
    from sglang.srt.layers.quantization import QuantizationConfig

try:
    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
except ImportError:
    pass

logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _is_cuda:
    from sgl_kernel import moe_fused_gate

if _is_cuda or _is_hip:
    from sgl_kernel import topk_softmax

if _use_aiter:
    try:
        from aiter import biased_grouped_topk as aiter_biased_grouped_topk
    except ImportError:
        raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")

if _is_npu:
    import torch_npu

# -------------------------------- TopKConfig ---------------------------------------
@dataclass
class TopKConfig:
    top_k: int
    use_grouped_topk: bool = False
    topk_group: Optional[int] = None
    num_expert_group: Optional[int] = None
    renormalize: bool = True
    num_fused_shared_experts: int = 0
    custom_routing_function: Optional[Callable] = None
    correction_bias: Optional[torch.Tensor] = None
    torch_native: bool = False
    routed_scaling_factor: Optional[float] = None
    apply_routed_scaling_factor_on_output: bool = False
    output_format: Optional[TopKOutputFormat] = None

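# Illustrative sketch (not part of the original file): a DeepSeek-style grouped top-k
# configuration might be constructed like the following; the concrete values are
# hypothetical and should come from the model config in practice.
#
#     topk_config = TopKConfig(
#         top_k=8,
#         use_grouped_topk=True,
#         topk_group=4,
#         num_expert_group=8,
#         renormalize=True,
#         routed_scaling_factor=2.5,
#     )
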
# -------------------------------- TopKOutput ---------------------------------------
class TopKOutputChecker:
    @staticmethod
    def format_is_standard(topk_output: TopKOutput) -> TypeGuard[StandardTopKOutput]:
        return topk_output.format.is_standard()

    @staticmethod
    def format_is_triton_kernel(
        topk_output: TopKOutput,
    ) -> TypeGuard[TritonKernelTopKOutput]:
        return topk_output.format.is_triton_kernel()

    @staticmethod
    def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
        return topk_output.format.is_bypassed()

class TopKOutputFormat(Enum):
    STANDARD = auto()
    TRITON_KERNEL = auto()
    BYPASSED = auto()

    def is_standard(self) -> bool:
        return self == TopKOutputFormat.STANDARD

    def is_triton_kernel(self) -> bool:
        return self == TopKOutputFormat.TRITON_KERNEL

    def is_bypassed(self) -> bool:
        return self == TopKOutputFormat.BYPASSED

@runtime_checkable
class TopKOutput(Protocol):
"""Protocol for top-k outputs in different formats."""
@property
def format(self) -> TopKOutputFormat:
"""The format of the output."""
...
class StandardTopKOutput(NamedTuple):
"""Standard top-k output format."""
topk_weights: torch.Tensor
topk_ids: torch.Tensor
router_logits: torch.Tensor
@property
def format(self) -> TopKOutputFormat:
return TopKOutputFormat.STANDARD
class TritonKernelTopKOutput(NamedTuple):
"""Triton kernel top-k output format."""
routing_data: RoutingData
gather_indx: GatherIndx
scatter_indx: ScatterIndx
@property
def format(self) -> TopKOutputFormat:
return TopKOutputFormat.TRITON_KERNEL
class BypassedTopKOutput(NamedTuple):
"""Bypassed top-k output format."""
hidden_states: torch.Tensor
router_logits: torch.Tensor
topk_config: TopKConfig
num_token_non_padded: Optional[torch.Tensor] = None
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None
@property
def format(self) -> TopKOutputFormat:
return TopKOutputFormat.BYPASSED
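# Illustrative sketch (not part of the original file): downstream MoE runners can
# narrow a TopKOutput to a concrete format via TopKOutputChecker. `dispatch` below is
# a hypothetical caller, not an API defined in this module.
#
#     def dispatch(topk_output: TopKOutput):
#         if TopKOutputChecker.format_is_standard(topk_output):
#             topk_weights, topk_ids, router_logits = topk_output
#         elif TopKOutputChecker.format_is_triton_kernel(topk_output):
#             routing_data = topk_output.routing_data
#         else:  # bypassed: routing happens inside the fused MoE kernel
#             hidden_states = topk_output.hidden_states
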
# -------------------------------- TopK ---------------------------------------
class TopK(CustomOp):
    def __init__(
        self,
        top_k: int,
        *,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        renormalize: bool = True,
        num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        correction_bias: Optional[torch.Tensor] = None,
        quant_config: Optional[QuantizationConfig] = None,
        routed_scaling_factor: Optional[float] = None,
        apply_routed_scaling_factor_on_output: Optional[bool] = False,
        output_format: Optional[TopKOutputFormat] = None,
    ):
        # NOTE: scoring_func is not used for now, but we keep it for future use
        # see https://github.com/sgl-project/sglang/pull/4505 for more details
        super().__init__()

        if use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None

        if (
            quant_config is not None
            and quant_config.get_name() == "modelopt_fp4"
            and should_use_flashinfer_trtllm_moe()
        ):
            # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
            correction_bias = correction_bias.to(torch.bfloat16)

        self.topk_config = TopKConfig(
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
            output_format=output_format,
        )

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        self.topk_config.torch_native = True
        return select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            topk_config=self.topk_config,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        if self.topk_config.output_format is not None:
            output_format = self.topk_config.output_format
        elif get_moe_runner_backend().is_triton_kernel():
            output_format = TopKOutputFormat.TRITON_KERNEL
        elif (
            should_use_flashinfer_trtllm_moe()
            or get_moe_runner_backend().is_flashinfer_mxfp4()
        ):
            output_format = TopKOutputFormat.BYPASSED
        else:
            output_format = TopKOutputFormat.STANDARD

        if output_format == TopKOutputFormat.TRITON_KERNEL:
            # renormalize=True is equivalent to sm_first=False
            routing_data, gather_idx, scatter_idx = routing(
                router_logits,
                self.topk_config.top_k,
                sm_first=not self.topk_config.renormalize,
            )
            return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
        elif output_format == TopKOutputFormat.BYPASSED:
            return BypassedTopKOutput(
                hidden_states=hidden_states,
                router_logits=router_logits,
                topk_config=self.topk_config,
                num_token_non_padded=num_token_non_padded,
                expert_location_dispatch_info=expert_location_dispatch_info,
            )
        else:
            self.topk_config.torch_native = False
            return select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                topk_config=self.topk_config,
                num_token_non_padded=num_token_non_padded,
                expert_location_dispatch_info=expert_location_dispatch_info,
            )

    def forward_cpu(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        return select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            topk_config=self.topk_config,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
        )

    def forward_npu(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        global_num_experts = router_logits.shape[-1]

        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
        if global_num_experts == 256:
            routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
            router_logits = router_logits.to(torch.float32)
            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                router_logits,
                k=self.topk_config.top_k,
                bias=self.topk_config.correction_bias.to(torch.float32),
                k_group=self.topk_config.topk_group,
                group_count=self.topk_config.num_expert_group,
                group_select_mode=1,
                renorm=0,
                norm_type=1,
                routed_scaling_factor=routed_scaling_factor,
                eps=float(1e-20),
            )

            if self.topk_config.renormalize:
                topk_weights_sum = (
                    topk_weights.sum(dim=-1, keepdim=True)
                    if self.topk_config.num_fused_shared_experts == 0
                    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
                )
                topk_weights = topk_weights / topk_weights_sum

            if expert_location_dispatch_info is not None:
                topk_ids = topk_ids_logical_to_physical(
                    topk_ids, expert_location_dispatch_info
                )
            get_global_expert_distribution_recorder().on_select_experts(
                topk_ids=topk_ids
            )
            return StandardTopKOutput(topk_weights, topk_ids, _)
        else:
            self.topk_config.torch_native = True
            return select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                topk_config=self.topk_config,
                num_token_non_padded=num_token_non_padded,
                expert_location_dispatch_info=expert_location_dispatch_info,
            )

    def empty_topk_output(self, device: torch.device) -> TopKOutput:
        topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
        topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
        topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
        router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
        return StandardTopKOutput(topk_weights, topk_idx, router_logits)

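# Illustrative usage sketch (not part of the original file): a MoE layer typically
# constructs this op once and invokes it each forward pass. The tensors and values
# below are hypothetical placeholders.
#
#     topk = TopK(top_k=2, renormalize=True)
#     # CustomOp dispatches the call to the platform-specific forward_* above
#     topk_output = topk(hidden_states, router_logits)
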
# ------------------------------- TopK implementation -------------------------------------
def fused_topk_torch_native(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    correction_bias: torch.Tensor = None,
):
    if correction_bias is not None:
        n_routed_experts = gating_output.shape[-1]
        scores = gating_output.softmax(dim=-1)
        scores_for_choice = scores.view(
            -1, n_routed_experts
        ) + correction_bias.unsqueeze(0)
        topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
        topk_weights = scores.gather(1, topk_ids)
    else:
        assert (
            hidden_states.shape[0] == gating_output.shape[0]
        ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
        M, _ = hidden_states.shape
        topk_weights = torch.empty(
            M, topk, dtype=torch.float32, device=hidden_states.device
        )
        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
        topk_weights = F.softmax(gating_output.float(), dim=-1)
        topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids

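# Worked example (illustrative, not part of the original file): for a single token
# with gating_output = [[2.0, 1.0, 0.5, 0.1]] and topk=2, softmax gives approximately
# [0.574, 0.211, 0.128, 0.086], so topk_ids = [[0, 1]] and topk_weights ≈ [[0.574, 0.211]].
# With renormalize=True the selected weights are rescaled to sum to 1, i.e. ≈ [[0.731, 0.269]].
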
def fused_topk_cpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    correction_bias: torch.Tensor = None,
):
    topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
        hidden_states=hidden_states,
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
    )
    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
    return topk_weights, topk_ids

def apply_topk_weights_cpu(need_apply, topk_weights, inputs):
    if not need_apply:
        return inputs, topk_weights
    # TODO: fuse below processing in fused_experts_cpu kernel
    inputs = inputs * topk_weights.to(inputs.dtype)
    topk_weights = torch.ones_like(
        topk_weights, dtype=torch.float32
    )  # clear topk_weights as already applied
    return inputs, topk_weights

def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    M, _ = hidden_states.shape
    topk_weights = torch.empty(
        M, topk, dtype=torch.float32, device=hidden_states.device
    )
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)

    topk_softmax(
        topk_weights,
        topk_ids,
        gating_output,
        renormalize,
    )

    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
    return topk_weights, topk_ids

# This is used by the Deepseek V2/V3/R1 series models
@torch.compile(dynamic=True, backend=get_compiler_backend())
def grouped_topk_gpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = torch.softmax(gating_output, dim=-1)
    # NPU compiler limitation
    if _is_npu and scores.dtype == torch.bfloat16:
        scores = scores.to(torch.float16)
    num_token = scores.shape[0]
    num_experts = scores.shape[1]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    # TODO: NPU can't support directly evaluating a comparison for now
    topk_weights, topk_ids = torch.topk(
        tmp_scores,
        k=topk,
        dim=-1,
        sorted=(True if num_fused_shared_experts > 0 else False),
    )
    if num_fused_shared_experts:
        topk_ids[:, -1] = torch.randint(
            low=num_experts,
            high=num_experts + num_fused_shared_experts,
            size=(topk_ids.size(0),),
            dtype=topk_ids.dtype,
            device=topk_ids.device,
        )
        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

    if renormalize:
        topk_weights_sum = (
            topk_weights.sum(dim=-1, keepdim=True)
            if num_fused_shared_experts == 0
            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
        )
        topk_weights = topk_weights / topk_weights_sum

    if apply_routed_scaling_factor_on_output:
        topk_weights *= routed_scaling_factor

    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
    return topk_weights, topk_ids

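# Worked example (illustrative, not part of the original file): with 8 experts split
# into num_expert_group=4 groups of 2 and per-token softmax scores
# [0.30, 0.05, 0.10, 0.08, 0.02, 0.01, 0.25, 0.19], the group maxima are
# [0.30, 0.10, 0.02, 0.25]. With topk_group=2 only groups 0 and 3 survive, every other
# expert is masked to 0, and the final topk=2 selects experts 0 and 6 with weights
# [0.30, 0.25] (renormalized to ≈ [0.545, 0.455] when renormalize=True).
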
def grouped_topk_cpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    assert not apply_routed_scaling_factor_on_output
    assert expert_location_dispatch_info is None
    return torch.ops.sgl_kernel.grouped_topk_cpu(
        hidden_states,
        gating_output,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
        num_fused_shared_experts,
        routed_scaling_factor,
        num_token_non_padded,
    )

@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
def biased_grouped_topk_impl(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = gating_output.sigmoid()
    num_token = scores.shape[0]
    num_experts = scores.shape[1]
    scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
    group_scores = (
        scores_for_choice.view(num_token, num_expert_group, -1)
        .topk(2, dim=-1)[0]
        .sum(dim=-1)
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    tmp_scores = scores_for_choice.masked_fill(
        ~score_mask.bool(), float("-inf")
    )  # [n, e]
    # TODO: NPU can't support directly evaluating a comparison for now
    _, topk_ids = torch.topk(
        tmp_scores,
        k=topk,
        dim=-1,
        sorted=(True if num_fused_shared_experts > 0 else False),
    )
    topk_weights = scores.gather(1, topk_ids)

    if num_fused_shared_experts:
        topk_ids[:, -1] = torch.randint(
            low=num_experts,
            high=num_experts + num_fused_shared_experts,
            size=(topk_ids.size(0),),
            dtype=topk_ids.dtype,
            device=topk_ids.device,
        )
        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

    if renormalize:
        topk_weights_sum = (
            topk_weights.sum(dim=-1, keepdim=True)
            if num_fused_shared_experts == 0
            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
        )
        topk_weights = topk_weights / topk_weights_sum

    if apply_routed_scaling_factor_on_output:
        topk_weights *= routed_scaling_factor

    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
    return topk_weights, topk_ids

def is_power_of_two(n):
    return n > 0 and math.log2(n).is_integer()

def _mask_topk_ids_padded_region(
    topk_ids: torch.Tensor,
    num_token_non_padded: Optional[torch.Tensor] = None,
):
    if num_token_non_padded is None:
        return
    indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
    topk_ids[indices >= num_token_non_padded, :] = -1

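# Illustrative example (not part of the original file): if topk_ids has 4 rows (tokens)
# and num_token_non_padded holds 3, the last row corresponds to padding and every entry
# of topk_ids[3] is set to -1, so downstream dispatch can recognize it as padding.
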
@torch.compile(dynamic=True, backend=get_compiler_backend())
def _biased_grouped_topk_postprocess(
    topk_ids, expert_location_dispatch_info, num_token_non_padded
):
    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
    return topk_ids

def biased_grouped_topk_gpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    assert (
        routed_scaling_factor is not None
    ), "routed_scaling_factor is required for biased_grouped_topk"
    # TODO: the moe_fused_gate kernel does not support num_fused_shared_experts > 0 yet.
    if (
        _is_cuda
        and gating_output.shape[1] // num_expert_group
        <= 32  # the moe_fused_gate kernel requires num_experts / num_expert_group <= MAX_VPT=32 for now; once the kernel handles MAX_VPT > 32, this check can be removed.
        and is_power_of_two(correction_bias.shape[0])
    ):
        topk_weights, topk_ids = moe_fused_gate(
            gating_output.to(dtype=torch.float32),
            correction_bias,
            num_expert_group,
            topk_group,
            topk,
            num_fused_shared_experts,
            routed_scaling_factor,
            apply_routed_scaling_factor_on_output,
        )
        # TODO merge into kernel
        if (expert_location_dispatch_info is not None) or (
            num_token_non_padded is not None
        ):
            topk_ids = _biased_grouped_topk_postprocess(
                topk_ids, expert_location_dispatch_info, num_token_non_padded
            )
        return topk_weights, topk_ids
    elif _use_aiter:
        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        token = gating_output.shape[0]
        device = gating_output.device
        assert (
            hidden_states.shape[0] == gating_output.shape[0]
        ), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
        topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
        topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
        aiter_biased_grouped_topk(
            gating_output.to(dtype=torch.float32),
            correction_bias,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            renormalize,
            routed_scaling_factor,
        )
        return topk_weights, topk_ids
    else:
        return biased_grouped_topk_impl(
            hidden_states,
            gating_output,
            correction_bias,
            topk,
            renormalize,
            num_expert_group,
            topk_group,
            num_fused_shared_experts=num_fused_shared_experts,
            routed_scaling_factor=routed_scaling_factor,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
        )

def biased_grouped_topk_cpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    compiled: bool = True,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    assert expert_location_dispatch_info is None
    assert not apply_routed_scaling_factor_on_output, "Not implemented"
    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
        hidden_states,
        gating_output,
        correction_bias,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
        num_fused_shared_experts,
        routed_scaling_factor,
        num_token_non_padded,
    )

if _is_cpu and _is_cpu_amx_available:
    biased_grouped_topk = biased_grouped_topk_cpu
    grouped_topk = grouped_topk_cpu
    fused_topk_native = fused_topk_cpu
    fused_topk = fused_topk_cpu
else:
    biased_grouped_topk = biased_grouped_topk_gpu
    grouped_topk = grouped_topk_gpu
    fused_topk_native = fused_topk_torch_native

def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    topk_config: TopKConfig,
    *,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> StandardTopKOutput:
    top_k = topk_config.top_k
    use_grouped_topk = topk_config.use_grouped_topk
    topk_group = topk_config.topk_group
    num_expert_group = topk_config.num_expert_group
    renormalize = topk_config.renormalize
    num_fused_shared_experts = topk_config.num_fused_shared_experts
    custom_routing_function = topk_config.custom_routing_function
    correction_bias = topk_config.correction_bias
    torch_native = topk_config.torch_native
    routed_scaling_factor = topk_config.routed_scaling_factor
    apply_routed_scaling_factor_on_output = (
        topk_config.apply_routed_scaling_factor_on_output
    )

    router_logits, correction_bias = (
        expert_location_dispatch.transform_select_experts_inputs(
            router_logits=router_logits,
            correction_bias=correction_bias,
            info=expert_location_dispatch_info,
        )
    )

    # DeepSeek V2/V3/R1 series models use grouped_top_k
    if use_grouped_topk:
        assert topk_group is not None
        assert num_expert_group is not None
        if correction_bias is None:
            topk_weights, topk_ids = grouped_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                num_fused_shared_experts=num_fused_shared_experts,
                routed_scaling_factor=routed_scaling_factor,
                num_token_non_padded=num_token_non_padded,
                expert_location_dispatch_info=expert_location_dispatch_info,
                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
            )
        else:
            topk_weights, topk_ids = biased_grouped_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                correction_bias=correction_bias,
                topk=top_k,
                renormalize=renormalize,
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                num_fused_shared_experts=num_fused_shared_experts,
                routed_scaling_factor=routed_scaling_factor,
                num_token_non_padded=num_token_non_padded,
                expert_location_dispatch_info=expert_location_dispatch_info,
                apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
            )
    elif torch_native and custom_routing_function is None:
        assert (
            num_token_non_padded is None
        ), "num_token_non_padded is not yet supported in fused_topk_native"
        assert expert_location_dispatch_info is None
        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        topk_weights, topk_ids = fused_topk_native(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            correction_bias=correction_bias,
        )
    elif custom_routing_function is None:
        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        # Qwen3MOE uses fused_topk
        topk_weights, topk_ids = fused_topk(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
        )
    else:
        assert (
            num_token_non_padded is None
        ), "num_token_non_padded is not yet supported in custom_routing_function"
        assert expert_location_dispatch_info is None
        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        topk_weights, topk_ids = custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
        )

    get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

    return StandardTopKOutput(topk_weights, topk_ids, router_logits)
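

# Illustrative usage sketch (not part of the original file): how a caller might drive
# select_experts directly; the shapes are hypothetical, and in SGLang this is normally
# reached through the TopK CustomOp defined above.
#
#     config = TopKConfig(top_k=2, renormalize=True)
#     topk_output = select_experts(
#         hidden_states=hidden_states,  # [num_tokens, hidden_dim]
#         router_logits=router_logits,  # [num_tokens, num_experts]
#         topk_config=config,
#     )
#     topk_weights, topk_ids, router_logits = topk_output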