685 lines
24 KiB
Python
685 lines
24 KiB
Python
# Copyright 2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
from typing import TYPE_CHECKING, Callable, NamedTuple, Optional
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from sglang.srt.custom_op import CustomOp
|
|
from sglang.srt.eplb import expert_location_dispatch
|
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
|
from sglang.srt.eplb.expert_location_dispatch import (
|
|
ExpertLocationDispatchInfo,
|
|
topk_ids_logical_to_physical,
|
|
)
|
|
from sglang.srt.utils import (
|
|
cpu_has_amx_support,
|
|
get_bool_env_var,
|
|
get_compiler_backend,
|
|
is_cpu,
|
|
is_cuda,
|
|
is_hip,
|
|
is_npu,
|
|
)
|
|
|
|
# Platform / backend flags, evaluated once at import time.
_is_cuda = is_cuda()
_is_hip = is_hip()
# aiter kernels are only considered on HIP builds, and only when opted in via env.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_npu = is_npu()

# Backend-specific kernels are imported only on the platforms that use them.
if _is_cuda:
    from sgl_kernel import moe_fused_gate

if _is_cuda or _is_hip:
    from sgl_kernel import topk_softmax
if _use_aiter:
    try:
        from aiter import biased_grouped_topk as aiter_biased_grouped_topk
    except ImportError as err:
        # Chain explicitly so the underlying import failure is reported as the cause.
        raise ImportError(
            "aiter is required when SGLANG_USE_AITER is set to True"
        ) from err

if _is_npu:
    import torch_npu
|
|
|
|
|
|
# Result of expert selection, as returned by select_experts:
#   topk_weights  - routing weight of each selected expert, per token
#   topk_ids      - id of each selected expert, per token
#   router_logits - router logits the selection was computed from
TopKOutput = NamedTuple(
    "TopKOutput",
    [
        ("topk_weights", torch.Tensor),
        ("topk_ids", torch.Tensor),
        ("router_logits", torch.Tensor),
    ],
)
|
|
|
|
|
|
class TopK(CustomOp):
    """Per-token top-k expert selection for a mixture-of-experts layer.

    The routing configuration is captured once in ``__init__``; each
    backend-specific ``forward_*`` method applies it to a batch of router
    logits and returns a :class:`TopKOutput`.
    """

    # TODO(ch-wan): support triton_kernels

    def __init__(
        self,
        top_k: int,
        *,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        renormalize: bool = True,
        num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        correction_bias: Optional[torch.Tensor] = None,
        routed_scaling_factor: Optional[float] = None,
    ):
        """Store the routing configuration.

        Args:
            top_k: Number of experts selected per token.
            use_grouped_topk: Use grouped (optionally biased) top-k routing;
                requires ``num_expert_group`` and ``topk_group``.
            topk_group: Number of expert groups kept per token (grouped routing).
            num_expert_group: Number of groups the experts are split into
                (grouped routing).
            renormalize: Rescale the selected weights so they sum to 1 per token.
            num_fused_shared_experts: Number of shared experts fused into the
                routed-expert id space.
            custom_routing_function: Optional callable replacing the built-in
                non-grouped routing.
            scoring_func: Currently unused (see note below).
            correction_bias: Optional per-expert bias; when grouped routing is
                enabled its presence selects biased grouped top-k.
            routed_scaling_factor: Forwarded to the routing kernels.

        Raises:
            AssertionError: if ``use_grouped_topk`` is set without both
                ``num_expert_group`` and ``topk_group``.
        """
        # NOTE: scoring_func is not used for now, but we keep it for future use
        # see https://github.com/sgl-project/sglang/pull/4505 for more details
        super().__init__()
        if use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.top_k = top_k
        self.use_grouped_topk = use_grouped_topk
        self.renormalize = renormalize
        self.topk_group = topk_group
        self.num_expert_group = num_expert_group
        self.num_fused_shared_experts = num_fused_shared_experts
        self.custom_routing_function = custom_routing_function
        self.correction_bias = correction_bias
        self.routed_scaling_factor = routed_scaling_factor

    def _run_select_experts(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        torch_native: bool,
        num_token_non_padded: Optional[torch.Tensor],
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo],
    ) -> TopKOutput:
        """Run ``select_experts`` with this layer's stored configuration."""
        return select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            use_grouped_topk=self.use_grouped_topk,
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            num_fused_shared_experts=self.num_fused_shared_experts,
            custom_routing_function=self.custom_routing_function,
            correction_bias=self.correction_bias,
            torch_native=torch_native,
            routed_scaling_factor=self.routed_scaling_factor,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
        )

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        """Select experts with the pure-PyTorch (``torch_native``) routing path."""
        return self._run_select_experts(
            hidden_states,
            router_logits,
            torch_native=True,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        """Select experts with the kernel-backed (non-native) routing path."""
        return self._run_select_experts(
            hidden_states,
            router_logits,
            torch_native=False,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
        )

    def forward_cpu(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        """Select experts on CPU (non-native path; the module-level aliases pick
        the CPU kernels when AMX is available)."""
        return self._run_select_experts(
            hidden_states,
            router_logits,
            torch_native=False,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
        )

    def forward_npu(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        """Select experts on Ascend NPU.

        Uses ``torch_npu.npu_moe_gating_top_k`` for the one configuration it
        handles and falls back to the ``torch_native`` path otherwise. Both
        paths return a ``TopKOutput`` with the same post-processing as
        ``select_experts``.
        """
        global_num_experts = router_logits.shape[-1]

        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
        # The kernel is always called with a bias and group parameters, i.e. it
        # implements biased grouped top-k, and it has no notion of a fused
        # shared-expert slot; any other configuration takes the generic path.
        use_fused_kernel = (
            global_num_experts == 256
            and self.use_grouped_topk
            and self.correction_bias is not None
            and self.num_fused_shared_experts == 0
        )
        if not use_fused_kernel:
            return self._run_select_experts(
                hidden_states,
                router_logits,
                torch_native=True,
                num_token_non_padded=num_token_non_padded,
                expert_location_dispatch_info=expert_location_dispatch_info,
            )

        # Mirror select_experts: let expert-location dispatch rewrite the inputs.
        router_logits, correction_bias = (
            expert_location_dispatch.transform_select_experts_inputs(
                router_logits=router_logits,
                correction_bias=self.correction_bias,
                info=expert_location_dispatch_info,
            )
        )
        # The kernel returns three tensors; only the weights and ids are used.
        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
            router_logits,
            k=self.top_k,
            bias=correction_bias,
            k_group=self.topk_group,
            group_count=self.num_expert_group,
            group_select_mode=1,
            renorm=0,
            norm_type=1,
            routed_scaling_factor=1,
            eps=float(1e-20),
        )
        # Same post-processing every select_experts path applies.
        topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
        _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
        get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
        return TopKOutput(topk_weights, topk_ids, router_logits)
|
|
|
|
|
|
def fused_topk_torch_native(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Pure-PyTorch softmax top-k routing.

    Args:
        hidden_states: Token activations; only its leading dimension is used,
            to check that it matches ``gating_output``.
        gating_output: Router logits, one row per token.
        topk: Number of experts selected per token.
        renormalize: If True, rescale the selected weights so each row sums to 1.

    Returns:
        ``(topk_weights, topk_ids)``: float32 weights and the expert indices
        exactly as produced by ``torch.topk`` (int64).

    Raises:
        AssertionError: if the token counts of the two inputs differ.
    """
    assert (
        hidden_states.shape[0] == gating_output.shape[0]
    ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
    # Softmax in float32 regardless of the logits dtype. No output buffers are
    # pre-allocated: torch.topk returns fresh tensors anyway.
    probs = F.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
|
|
|
|
|
|
def fused_topk_cpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
    """Softmax top-k routing through the ``sgl_kernel`` CPU op.

    The kernel's expert ids are mapped from logical to physical ids, and rows at
    or beyond ``num_token_non_padded`` (when given) are masked to -1.

    Returns:
        ``(topk_weights, topk_ids)`` with the ids already post-processed.
    """
    weights, logical_ids = torch.ops.sgl_kernel.topk_softmax_cpu(
        hidden_states=hidden_states,
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
    )
    physical_ids = topk_ids_logical_to_physical(
        logical_ids, expert_location_dispatch_info
    )
    _mask_topk_ids_padded_region(physical_ids, num_token_non_padded)
    return weights, physical_ids
|
|
|
|
|
|
def apply_topk_weights_cpu(need_apply, topk_weights, inputs):
    """Optionally fold the routing weights into the inputs.

    When ``need_apply`` is false, both arguments are returned untouched.
    Otherwise ``inputs`` is multiplied by ``topk_weights`` (cast to the inputs'
    dtype), and the weights are replaced by float32 ones so they are not applied
    a second time downstream.

    Returns:
        ``(inputs, topk_weights)``.
    """
    if need_apply:
        # TODO: fuse below processing in fused_experts_cpu kernel
        inputs = inputs * topk_weights.to(inputs.dtype)
        # Weights are already baked into `inputs`; neutralize them.
        topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
    return inputs, topk_weights
|
|
|
|
|
|
def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
    """Softmax top-k routing through the ``sgl_kernel.topk_softmax`` kernel.

    Output buffers (float32 weights, int32 ids) are allocated here and filled in
    place by the kernel; the ids are then mapped from logical to physical ids and
    padding rows are masked to -1 when ``num_token_non_padded`` is given.

    Returns:
        ``(topk_weights, topk_ids)``.

    Raises:
        AssertionError: if the token counts of the two inputs differ.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    num_tokens, _ = hidden_states.shape
    device = hidden_states.device
    weights_buf = torch.empty(num_tokens, topk, dtype=torch.float32, device=device)
    ids_buf = torch.empty(num_tokens, topk, dtype=torch.int32, device=device)
    topk_softmax(weights_buf, ids_buf, gating_output, renormalize)

    ids_buf = topk_ids_logical_to_physical(ids_buf, expert_location_dispatch_info)
    _mask_topk_ids_padded_region(ids_buf, num_token_non_padded)
    return weights_buf, ids_buf
|
|
|
|
|
|
# This is used by the Deepseek V2/V3/R1 series models
@torch.compile(dynamic=True, backend=get_compiler_backend())
def grouped_topk_gpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
    """Grouped softmax top-k routing without a correction bias.

    Experts are split into ``num_expert_group`` equal, contiguous groups. Each
    group is scored by its highest expert probability, the best ``topk_group``
    groups are kept, and the ``topk`` experts are picked from those groups only.

    Args:
        hidden_states: Only its leading dimension is used, to check it matches
            ``gating_output``.
        gating_output: Router logits, ``[num_tokens, num_experts]``.
        topk: Number of selected slots per token.
        renormalize: Divide the weights by the sum of the routed weights.
        num_expert_group: Number of expert groups.
        topk_group: Number of groups kept per token.
        num_fused_shared_experts: When non-zero, the last slot is overwritten
            with a random shared-expert id in
            ``[num_experts, num_experts + num_fused_shared_experts)``. Its weight
            is the sum of the other weights divided by ``routed_scaling_factor``,
            which must then be set.
        routed_scaling_factor: See ``num_fused_shared_experts``.
        num_token_non_padded: Rows at or beyond this index get ids -1.
        expert_location_dispatch_info: Logical-to-physical expert mapping info.

    Returns:
        ``(topk_weights, topk_ids)`` as float32 / int32 tensors.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = torch.softmax(gating_output, dim=-1)
    # NPU compiler limitation
    if _is_npu and scores.dtype == torch.bfloat16:
        scores = scores.to(torch.float16)
    num_token = scores.shape[0]
    num_experts = scores.shape[1]
    group_scores = (
        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    # Experts outside the kept groups get score 0.
    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
    # With a fused shared expert the last column is overwritten below, so it
    # must hold the weakest routed pick. torch.topk only guarantees that order
    # when sorted=True; with sorted=False an arbitrary routed expert would be
    # dropped.
    topk_weights, topk_ids = torch.topk(
        tmp_scores,
        k=topk,
        dim=-1,
        sorted=(True if num_fused_shared_experts > 0 else False),
    )
    if num_fused_shared_experts:
        topk_ids[:, -1] = torch.randint(
            low=num_experts,
            high=num_experts + num_fused_shared_experts,
            size=(topk_ids.size(0),),
            dtype=topk_ids.dtype,
            device=topk_ids.device,
        )
        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

    if renormalize:
        # Normalize by the routed experts only; the shared slot is excluded.
        topk_weights_sum = (
            topk_weights.sum(dim=-1, keepdim=True)
            if num_fused_shared_experts == 0
            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
        )
        topk_weights = topk_weights / topk_weights_sum

    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
    return topk_weights, topk_ids
|
|
|
|
|
|
def grouped_topk_cpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
    """Grouped top-k routing (no correction bias) via the ``sgl_kernel`` CPU op.

    Expert-location dispatch is not supported on this path; every other argument
    is forwarded positionally to ``torch.ops.sgl_kernel.grouped_topk_cpu``.

    Raises:
        AssertionError: if ``expert_location_dispatch_info`` is given.
    """
    assert expert_location_dispatch_info is None
    kernel_args = (
        hidden_states,
        gating_output,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
        num_fused_shared_experts,
        routed_scaling_factor,
        num_token_non_padded,
    )
    return torch.ops.sgl_kernel.grouped_topk_cpu(*kernel_args)
|
|
|
|
|
|
@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
def biased_grouped_topk_impl(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
    """PyTorch biased grouped top-k routing (sigmoid scores plus correction bias).

    Selection uses ``sigmoid(gating_output) + correction_bias``. The returned
    weights are the un-biased sigmoid scores of the selected experts. Each of the
    ``num_expert_group`` contiguous groups is scored by the sum of its two
    highest biased scores. Only the best ``topk_group`` groups are eligible for
    the final ``topk`` choice. torch.compile is disabled on NPU.

    Args:
        hidden_states: Only its leading dimension is used, to check it matches
            ``gating_output``.
        gating_output: Router logits, ``[num_tokens, num_experts]``.
        correction_bias: Per-expert bias added to the scores for selection only.
        topk: Number of selected slots per token.
        renormalize: Divide the weights by the sum of the routed weights.
        num_expert_group: Number of expert groups.
        topk_group: Number of groups kept per token.
        num_fused_shared_experts: When non-zero, the last slot is overwritten
            with a random shared-expert id in
            ``[num_experts, num_experts + num_fused_shared_experts)``. Its weight
            is the sum of the other weights divided by ``routed_scaling_factor``.
        routed_scaling_factor: See ``num_fused_shared_experts``.
        num_token_non_padded: Rows at or beyond this index get ids -1.
        expert_location_dispatch_info: Logical-to-physical expert mapping info.

    Returns:
        ``(topk_weights, topk_ids)`` as float32 / int32 tensors.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = gating_output.sigmoid()
    num_token = scores.shape[0]
    num_experts = scores.shape[1]
    scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
    group_scores = (
        scores_for_choice.view(num_token, num_expert_group, -1)
        .topk(2, dim=-1)[0]
        .sum(dim=-1)
    )  # [n, n_group]
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]
    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]
    # Experts outside the kept groups can never be selected.
    tmp_scores = scores_for_choice.masked_fill(
        ~score_mask.bool(), float("-inf")
    )  # [n, e]
    # With a fused shared expert the last column is overwritten below, so it
    # must hold the weakest routed pick. torch.topk only guarantees that order
    # when sorted=True; with sorted=False an arbitrary routed expert would be
    # dropped.
    _, topk_ids = torch.topk(
        tmp_scores,
        k=topk,
        dim=-1,
        sorted=(True if num_fused_shared_experts > 0 else False),
    )
    # Weights come from the un-biased scores; the bias only affects selection.
    topk_weights = scores.gather(1, topk_ids)

    if num_fused_shared_experts:
        topk_ids[:, -1] = torch.randint(
            low=num_experts,
            high=num_experts + num_fused_shared_experts,
            size=(topk_ids.size(0),),
            dtype=topk_ids.dtype,
            device=topk_ids.device,
        )
        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

    if renormalize:
        # Normalize by the routed experts only; the shared slot is excluded.
        topk_weights_sum = (
            topk_weights.sum(dim=-1, keepdim=True)
            if num_fused_shared_experts == 0
            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
        )
        topk_weights = topk_weights / topk_weights_sum

    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
    return topk_weights, topk_ids
|
|
|
|
|
|
def is_power_of_two(n: int) -> bool:
    """Return True if the integer ``n`` is a positive power of two (1, 2, 4, ...).

    Uses the exact bit test ``n & (n - 1) == 0``. The floating-point ``log2``
    approach rounds large integers (e.g. ``2**60 + 1``) to an integral
    logarithm and misreports them as powers of two.
    """
    return n > 0 and (n & (n - 1)) == 0
|
|
|
|
|
|
def _mask_topk_ids_padded_region(
|
|
topk_ids: torch.Tensor,
|
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
|
):
|
|
if num_token_non_padded is None:
|
|
return
|
|
indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
|
|
topk_ids[indices >= num_token_non_padded, :] = -1
|
|
|
|
|
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
def _biased_grouped_topk_postprocess(
    topk_ids, expert_location_dispatch_info, num_token_non_padded
):
    """Map kernel-produced expert ids to physical ids and mask padding rows.

    Returns the (possibly new) ids tensor; padding rows are overwritten with -1.
    """
    physical_ids = topk_ids_logical_to_physical(
        topk_ids, expert_location_dispatch_info
    )
    _mask_topk_ids_padded_region(physical_ids, num_token_non_padded)
    return physical_ids
|
|
|
|
|
|
def biased_grouped_topk_gpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
    """Biased grouped top-k routing on GPU, using the fastest available backend.

    Backends, in order of preference:
      * CUDA ``moe_fused_gate``, when ``num_experts // num_expert_group <= 32``
        and the number of experts is a power of two;
      * aiter ``biased_grouped_topk``, when ``SGLANG_USE_AITER`` is enabled on HIP;
      * ``biased_grouped_topk_impl`` (torch.compile) otherwise.

    On every path the returned ids are mapped to physical experts (when
    ``expert_location_dispatch_info`` is given) and padding rows are set to -1
    (when ``num_token_non_padded`` is given).

    Returns:
        ``(topk_weights, topk_ids)``.

    Raises:
        AssertionError: if ``routed_scaling_factor`` is None, or on the aiter
            path if the token counts of the inputs differ.
    """
    assert (
        routed_scaling_factor is not None
    ), "routed_scaling_factor is required for biased_grouped_topk"
    # TODO: moe_fused_gate kernel is not supported for num_fused_shared_experts > 0 now.
    if (
        _is_cuda
        # moe_fused_gate kernel ensure that num_experts/num_expert_group does not
        # exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can
        # remove this assertion.
        and gating_output.shape[1] // num_expert_group <= 32
        and is_power_of_two(correction_bias.shape[0])
    ):
        topk_weights, topk_ids = moe_fused_gate(
            gating_output.to(dtype=torch.float32),
            correction_bias,
            num_expert_group,
            topk_group,
            topk,
            num_fused_shared_experts,
            routed_scaling_factor,
        )
    elif _use_aiter:
        token = gating_output.shape[0]
        device = gating_output.device
        assert (
            hidden_states.shape[0] == gating_output.shape[0]
        ), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
        topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
        topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
        # NOTE(review): num_fused_shared_experts is not forwarded to aiter.
        aiter_biased_grouped_topk(
            gating_output.to(dtype=torch.float32),
            correction_bias,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            renormalize,
            routed_scaling_factor,
        )
    else:
        # The PyTorch implementation does its own id post-processing.
        return biased_grouped_topk_impl(
            hidden_states,
            gating_output,
            correction_bias,
            topk,
            renormalize,
            num_expert_group,
            topk_group,
            num_fused_shared_experts=num_fused_shared_experts,
            routed_scaling_factor=routed_scaling_factor,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
        )

    # Kernel paths (CUDA and aiter) share the id post-processing.
    # TODO merge into kernel
    if (expert_location_dispatch_info is not None) or (
        num_token_non_padded is not None
    ):
        topk_ids = _biased_grouped_topk_postprocess(
            topk_ids, expert_location_dispatch_info, num_token_non_padded
        )
    return topk_weights, topk_ids
|
|
|
|
|
|
def biased_grouped_topk_cpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
    compiled: bool = True,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
    """Biased grouped top-k routing via the ``sgl_kernel`` CPU op.

    ``compiled`` is accepted for signature compatibility and is not used.
    Expert-location dispatch is not supported here.

    Raises:
        AssertionError: if ``expert_location_dispatch_info`` is given.
    """
    assert expert_location_dispatch_info is None
    kernel_args = (
        hidden_states,
        gating_output,
        correction_bias,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
        num_fused_shared_experts,
        routed_scaling_factor,
        num_token_non_padded,
    )
    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(*kernel_args)
|
|
|
|
|
|
# Bind the backend-specific routing implementations used by select_experts.
# AMX-capable CPUs use the sgl_kernel CPU ops for every variant; all other
# platforms use the GPU / PyTorch versions (fused_topk keeps its own definition).
if not (_is_cpu and _is_cpu_amx_available):
    biased_grouped_topk, grouped_topk = biased_grouped_topk_gpu, grouped_topk_gpu
    fused_topk_native = fused_topk_torch_native
else:
    biased_grouped_topk, grouped_topk = biased_grouped_topk_cpu, grouped_topk_cpu
    fused_topk_native = fused_topk = fused_topk_cpu
|
|
|
|
|
|
def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    *,
    use_grouped_topk: bool = False,
    renormalize: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    num_fused_shared_experts: int = 0,
    custom_routing_function: Optional[Callable] = None,
    correction_bias: Optional[torch.Tensor] = None,
    torch_native: bool = False,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
    """Choose the top-k experts for every token.

    The inputs are first passed through expert-location dispatch. Then:
      * ``use_grouped_topk``: ``grouped_topk`` when there is no
        ``correction_bias``, else ``biased_grouped_topk``
        (custom_routing_function is not consulted);
      * no ``custom_routing_function`` and ``torch_native``: ``fused_topk_native``;
      * no ``custom_routing_function``: ``fused_topk``;
      * otherwise: ``custom_routing_function``.
    The selected ids are reported to the global expert-distribution recorder.

    Returns:
        ``TopKOutput(topk_weights, topk_ids, router_logits)`` where
        ``router_logits`` are the (possibly transformed) logits used.

    Raises:
        AssertionError: grouped routing without ``topk_group`` /
            ``num_expert_group``, or padding / dispatch info passed to a path
            that does not support it.
    """
    router_logits, correction_bias = (
        expert_location_dispatch.transform_select_experts_inputs(
            router_logits=router_logits,
            correction_bias=correction_bias,
            info=expert_location_dispatch_info,
        )
    )

    if use_grouped_topk:
        # DeepSeek V2/V3/R1 series models use grouped_top_k
        assert topk_group is not None
        assert num_expert_group is not None
        grouped_kwargs = dict(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
            num_fused_shared_experts=num_fused_shared_experts,
            routed_scaling_factor=routed_scaling_factor,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
        )
        if correction_bias is None:
            topk_weights, topk_ids = grouped_topk(**grouped_kwargs)
        else:
            topk_weights, topk_ids = biased_grouped_topk(
                correction_bias=correction_bias, **grouped_kwargs
            )
    elif custom_routing_function is None:
        if torch_native:
            assert (
                num_token_non_padded is None
            ), "num_token_non_padded is not yet supported in fused_topk_native"
            assert expert_location_dispatch_info is None
            topk_weights, topk_ids = fused_topk_native(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
            )
        else:
            # Qwen3MOE uses fused_topk
            topk_weights, topk_ids = fused_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
                topk=top_k,
                renormalize=renormalize,
                num_token_non_padded=num_token_non_padded,
                expert_location_dispatch_info=expert_location_dispatch_info,
            )
    else:
        assert (
            num_token_non_padded is None
        ), "num_token_non_padded is not yet supported in custom_routing_function"
        assert expert_location_dispatch_info is None
        topk_weights, topk_ids = custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
        )

    get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

    return TopKOutput(topk_weights, topk_ids, router_logits)
|