898 lines
32 KiB
Python
898 lines
32 KiB
Python
# Copyright 2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import math
|
|
from dataclasses import dataclass
|
|
from enum import Enum, auto
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Callable,
|
|
NamedTuple,
|
|
Optional,
|
|
Protocol,
|
|
TypeGuard,
|
|
runtime_checkable,
|
|
)
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from sglang.srt.custom_op import CustomOp
|
|
from sglang.srt.eplb import expert_location_dispatch
|
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
|
from sglang.srt.eplb.expert_location_dispatch import (
|
|
ExpertLocationDispatchInfo,
|
|
topk_ids_logical_to_physical,
|
|
)
|
|
from sglang.srt.layers.moe import (
|
|
get_moe_runner_backend,
|
|
should_use_flashinfer_trtllm_moe,
|
|
)
|
|
from sglang.srt.utils import (
|
|
cpu_has_amx_support,
|
|
get_bool_env_var,
|
|
get_compiler_backend,
|
|
is_cpu,
|
|
is_cuda,
|
|
is_hip,
|
|
is_npu,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from sglang.srt.layers.quantization import QuantizationConfig
|
|
|
|
try:
|
|
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
|
|
except ImportError:
|
|
pass
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Platform / feature probes, evaluated once when this module is imported.
# They gate the conditional imports below and the kernel choices further down.
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_npu = is_npu()
# aiter is only considered when the SGLANG_USE_AITER env flag is set AND we run on HIP.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
|
|
|
if _is_cuda:
|
|
from sgl_kernel import moe_fused_gate
|
|
|
|
if _is_cuda or _is_hip:
|
|
from sgl_kernel import topk_softmax
|
|
if _use_aiter:
|
|
try:
|
|
from aiter import biased_grouped_topk as aiter_biased_grouped_topk
|
|
except ImportError:
|
|
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
|
|
if _is_npu:
|
|
import torch_npu
|
|
|
|
# -------------------------------- TopKConfig ---------------------------------------
|
|
|
|
|
|
@dataclass
class TopKConfig:
    """Bundle of routing options consumed by the top-k selection code.

    ``TopK.__init__`` builds one instance and ``select_experts`` reads its
    fields to decide which routing implementation to call and how.
    """

    # Number of experts picked per token (passed as ``k`` to the top-k calls).
    top_k: int
    # When True, ``select_experts`` takes the grouped-top-k branch.
    use_grouped_topk: bool = False
    # Number of expert groups kept per token in grouped routing.
    topk_group: Optional[int] = None
    # Number of groups the experts are split into in grouped routing.
    num_expert_group: Optional[int] = None
    # When True, the selected weights are divided by their sum.
    renormalize: bool = True
    # When non-zero, the grouped routines overwrite the last top-k slot with a
    # shared-expert id.
    num_fused_shared_experts: int = 0
    # Optional user routine that replaces the built-in non-grouped routing.
    custom_routing_function: Optional[Callable] = None
    # Optional per-expert bias added to the scores used for expert selection.
    correction_bias: Optional[torch.Tensor] = None
    # Set by the ``TopK.forward_*`` methods; picks ``fused_topk_native`` over
    # ``fused_topk`` in ``select_experts``.
    torch_native: bool = False
    # Scaling factor forwarded to the grouped routing routines.
    routed_scaling_factor: Optional[float] = None
    # When True, grouped routing multiplies the weights by the scaling factor.
    apply_routed_scaling_factor_on_output: bool = False
    # Explicit output format; ``TopK.forward_cuda`` infers one when this is None.
    output_format: Optional[TopKOutputFormat] = None
|
|
|
|
|
|
# -------------------------------- TopKOutput ---------------------------------------
|
|
|
|
|
|
class TopKOutputChecker:
    """Type-narrowing predicates that inspect a top-k output's ``format``."""

    @staticmethod
    def format_is_standard(topk_output: TopKOutput) -> TypeGuard[StandardTopKOutput]:
        """Return whether ``topk_output.format`` reports the standard format."""
        output_format = topk_output.format
        return output_format.is_standard()

    @staticmethod
    def format_is_triton_kernel(
        topk_output: TopKOutput,
    ) -> TypeGuard[TritonKernelTopKOutput]:
        """Return whether ``topk_output.format`` reports the triton-kernel format."""
        output_format = topk_output.format
        return output_format.is_triton_kernel()

    @staticmethod
    def format_is_bypassed(topk_output: TopKOutput) -> TypeGuard[BypassedTopKOutput]:
        """Return whether ``topk_output.format`` reports the bypassed format."""
        output_format = topk_output.format
        return output_format.is_bypassed()
|
|
|
|
|
|
class TopKOutputFormat(Enum):
    """Tags identifying which top-k output container a routing call produced."""

    STANDARD = auto()
    TRITON_KERNEL = auto()
    BYPASSED = auto()

    def is_standard(self) -> bool:
        """Return True only for ``STANDARD``."""
        return self is TopKOutputFormat.STANDARD

    def is_triton_kernel(self) -> bool:
        """Return True only for ``TRITON_KERNEL``."""
        return self is TopKOutputFormat.TRITON_KERNEL

    def is_bypassed(self) -> bool:
        """Return True only for ``BYPASSED``."""
        return self is TopKOutputFormat.BYPASSED
|
|
|
|
|
|
@runtime_checkable
class TopKOutput(Protocol):
    """Structural type shared by every top-k output container.

    Anything exposing a ``format`` attribute satisfies it; because the protocol
    is runtime-checkable, ``isinstance`` checks work as well.
    """

    @property
    def format(self) -> TopKOutputFormat:
        """The format tag identifying the concrete output container."""
|
|
|
|
|
|
class StandardTopKOutput(NamedTuple):
    """Top-k result carrying explicit weights and expert ids.

    This is the container returned by ``select_experts``.
    """

    # Weights of the selected experts.
    topk_weights: torch.Tensor
    # Ids of the selected experts.
    topk_ids: torch.Tensor
    # Router logits associated with this routing decision.
    router_logits: torch.Tensor

    @property
    def format(self) -> TopKOutputFormat:
        """Always ``TopKOutputFormat.STANDARD``."""
        return TopKOutputFormat.STANDARD
|
|
|
|
|
|
class TritonKernelTopKOutput(NamedTuple):
    """Top-k result holding the three values returned by ``triton_kernels`` ``routing``."""

    # First value returned by ``routing``.
    routing_data: RoutingData
    # Second value returned by ``routing``.
    gather_indx: GatherIndx
    # Third value returned by ``routing``.
    scatter_indx: ScatterIndx

    @property
    def format(self) -> TopKOutputFormat:
        """Always ``TopKOutputFormat.TRITON_KERNEL``."""
        return TopKOutputFormat.TRITON_KERNEL
|
|
|
|
|
|
class BypassedTopKOutput(NamedTuple):
    """Top-k result that carries the raw routing inputs instead of a selection.

    ``TopK.forward_cuda`` returns this without running any expert selection;
    it simply packages the arguments it received together with the config.
    """

    hidden_states: torch.Tensor
    router_logits: torch.Tensor
    topk_config: TopKConfig
    # The two optional arguments of the forward call, passed through unchanged.
    num_token_non_padded: Optional[torch.Tensor] = None
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None

    @property
    def format(self) -> TopKOutputFormat:
        """Always ``TopKOutputFormat.BYPASSED``."""
        return TopKOutputFormat.BYPASSED
|
|
|
|
|
|
# -------------------------------- TopK ---------------------------------------
|
|
|
|
|
|
class TopK(CustomOp):
    """Expert-routing op: turns router logits into a ``TopKOutput``.

    The constructor stores all routing options in a ``TopKConfig``; each
    ``forward_*`` method is the platform-specific entry point and either runs
    ``select_experts`` or returns one of the alternative output containers.
    """

    def __init__(
        self,
        top_k: int,
        *,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        renormalize: bool = True,
        num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        correction_bias: Optional[torch.Tensor] = None,
        quant_config: Optional[QuantizationConfig] = None,
        routed_scaling_factor: Optional[float] = None,
        apply_routed_scaling_factor_on_output: Optional[bool] = False,
        output_format: Optional[TopKOutputFormat] = None,
    ):
        """Validate the options and store them in ``self.topk_config``.

        Grouped routing requires both ``num_expert_group`` and ``topk_group``.
        ``quant_config`` is only inspected to decide whether the correction
        bias must be cast to bfloat16; it is not stored.
        """
        # NOTE: scoring_func is not used for now, but we keep it for future use
        # see https://github.com/sgl-project/sglang/pull/4505 for more details
        super().__init__()

        if use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None

        # Only cast when a bias actually exists: models without a correction
        # bias would otherwise fail with AttributeError on ``None.to(...)``.
        if (
            correction_bias is not None
            and quant_config is not None
            and quant_config.get_name() == "modelopt_fp4"
            and should_use_flashinfer_trtllm_moe()
        ):
            # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
            correction_bias = correction_bias.to(torch.bfloat16)

        self.topk_config = TopKConfig(
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
            apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
            output_format=output_format,
        )

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        """Run ``select_experts`` with ``torch_native`` switched on.

        Note that this mutates ``self.topk_config.torch_native``.
        """
        self.topk_config.torch_native = True
        return select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            topk_config=self.topk_config,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        """Produce the output container matching the active MoE runner.

        The format is taken from ``topk_config.output_format`` when set;
        otherwise it is inferred from the MoE runner backend: triton-kernel
        backend -> ``TRITON_KERNEL``, flashinfer trtllm / mxfp4 -> ``BYPASSED``,
        anything else -> ``STANDARD``.
        """
        if self.topk_config.output_format is not None:
            output_format = self.topk_config.output_format
        elif get_moe_runner_backend().is_triton_kernel():
            output_format = TopKOutputFormat.TRITON_KERNEL
        elif (
            should_use_flashinfer_trtllm_moe()
            or get_moe_runner_backend().is_flashinfer_mxfp4()
        ):
            output_format = TopKOutputFormat.BYPASSED
        else:
            output_format = TopKOutputFormat.STANDARD

        if output_format == TopKOutputFormat.TRITON_KERNEL:
            # renormalize=True is equivalent to sm_first=False
            routing_data, gather_idx, scatter_idx = routing(
                router_logits,
                self.topk_config.top_k,
                sm_first=not self.topk_config.renormalize,
            )
            return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
        elif output_format == TopKOutputFormat.BYPASSED:
            # No selection happens here; the inputs are handed on as-is.
            return BypassedTopKOutput(
                hidden_states=hidden_states,
                router_logits=router_logits,
                topk_config=self.topk_config,
                num_token_non_padded=num_token_non_padded,
                expert_location_dispatch_info=expert_location_dispatch_info,
            )
        else:
            self.topk_config.torch_native = False
            return select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                topk_config=self.topk_config,
                num_token_non_padded=num_token_non_padded,
                expert_location_dispatch_info=expert_location_dispatch_info,
            )

    def forward_cpu(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        """Run ``select_experts`` with the config as it currently stands.

        Unlike the other entry points this one does not touch
        ``topk_config.torch_native``.
        """
        return select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            topk_config=self.topk_config,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
        )

    def forward_npu(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        *,
        num_token_non_padded: Optional[torch.Tensor] = None,
        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    ) -> TopKOutput:
        """Route on NPU, using the fused gating kernel when it applies.

        The fused ``torch_npu.npu_moe_gating_top_k`` path is taken when the
        router has 256 experts and a correction bias is configured; every
        other case falls back to ``select_experts`` with ``torch_native``
        switched on.
        """
        global_num_experts = router_logits.shape[-1]
        correction_bias = self.topk_config.correction_bias

        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
        # The kernel call below needs a bias tensor; without one, use the
        # generic path instead of failing on ``None.to(...)``.
        if global_num_experts == 256 and correction_bias is not None:

            routed_scaling_factor = self.topk_config.routed_scaling_factor or 1
            router_logits = router_logits.to(torch.float32)

            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                router_logits,
                k=self.topk_config.top_k,
                bias=correction_bias.to(torch.float32),
                k_group=self.topk_config.topk_group,
                group_count=self.topk_config.num_expert_group,
                group_select_mode=1,
                renorm=0,
                norm_type=1,
                routed_scaling_factor=routed_scaling_factor,
                eps=float(1e-20),
            )

            if self.topk_config.renormalize:
                # With fused shared experts the last slot is left out of the sum.
                topk_weights_sum = (
                    topk_weights.sum(dim=-1, keepdim=True)
                    if self.topk_config.num_fused_shared_experts == 0
                    else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
                )
                topk_weights = topk_weights / topk_weights_sum

            if expert_location_dispatch_info is not None:
                topk_ids = topk_ids_logical_to_physical(
                    topk_ids, expert_location_dispatch_info
                )
            get_global_expert_distribution_recorder().on_select_experts(
                topk_ids=topk_ids
            )

            # The kernel's third output fills the ``router_logits`` slot.
            return StandardTopKOutput(topk_weights, topk_ids, _)
        else:
            self.topk_config.torch_native = True
            return select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                topk_config=self.topk_config,
                num_token_non_padded=num_token_non_padded,
                expert_location_dispatch_info=expert_location_dispatch_info,
            )

    def empty_topk_output(self, device: torch.device) -> TopKOutput:
        """Return a zero-row ``StandardTopKOutput`` allocated on ``device``.

        The column count is ``top_k`` minus ``num_fused_shared_experts``; the
        id tensor is int32 and both float tensors are float32.
        """
        topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
        topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
        topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
        router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
        return StandardTopKOutput(topk_weights, topk_idx, router_logits)
|
|
|
|
|
|
# ------------------------------- TopK implementation -------------------------------------
|
|
|
|
|
|
def fused_topk_torch_native(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    correction_bias: Optional[torch.Tensor] = None,
):
    """Pick the top-k experts per token using plain PyTorch ops.

    Args:
        hidden_states: Only its leading dimension is read, to check it matches
            ``gating_output`` (and only when no bias is given).
        gating_output: Router logits; softmax is taken over the last dim.
        topk: Number of experts to select per token.
        renormalize: When True, divide the selected weights by their sum.
        correction_bias: Optional bias added to the softmax scores for the
            purpose of choosing experts only.

    Returns:
        ``(topk_weights, topk_ids)`` as produced by ``torch.topk`` / ``gather``.
    """
    if correction_bias is not None:
        n_routed_experts = gating_output.shape[-1]
        scores = gating_output.softmax(dim=-1)
        # The bias decides which experts win; the returned weights are the
        # unbiased scores gathered at the winning positions.
        scores_for_choice = scores.view(
            -1, n_routed_experts
        ) + correction_bias.unsqueeze(0)
        topk_ids = torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1]
        topk_weights = scores.gather(1, topk_ids)
    else:
        assert (
            hidden_states.shape[0] == gating_output.shape[0]
        ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}"
        # torch.topk allocates its own outputs, so no buffers are pre-allocated.
        scores = F.softmax(gating_output.float(), dim=-1)
        topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids
|
|
|
|
|
|
def fused_topk_cpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    correction_bias: torch.Tensor = None,
):
    """Softmax top-k routing through the ``sgl_kernel`` CPU op.

    The kernel result is post-processed by ``topk_ids_logical_to_physical``
    and the padded-row masking helper. ``correction_bias`` is accepted for
    signature compatibility but is not forwarded to the kernel.

    Returns:
        ``(topk_weights, topk_ids)``.
    """
    cpu_kernel = torch.ops.sgl_kernel.topk_softmax_cpu
    weights, expert_ids = cpu_kernel(
        hidden_states=hidden_states,
        gating_output=gating_output,
        topk=topk,
        renormalize=renormalize,
    )
    expert_ids = topk_ids_logical_to_physical(
        expert_ids, expert_location_dispatch_info
    )
    # In-place: rows at or beyond num_token_non_padded get id -1.
    _mask_topk_ids_padded_region(expert_ids, num_token_non_padded)
    return weights, expert_ids
|
|
|
|
|
|
def apply_topk_weights_cpu(need_apply, topk_weights, inputs):
    """Optionally pre-multiply ``inputs`` by the routing weights.

    When ``need_apply`` is falsy both arguments are returned untouched.
    Otherwise ``inputs`` is scaled by ``topk_weights`` (cast to the input
    dtype) and the weights are replaced by float32 ones.

    Returns:
        ``(inputs, topk_weights)`` in that order.
    """
    if need_apply:
        # TODO: fuse below processing in fused_experts_cpu kernel
        scaled_inputs = inputs * topk_weights.to(inputs.dtype)
        # The weights are already baked into the inputs, so hand back ones.
        neutral_weights = torch.ones_like(topk_weights, dtype=torch.float32)
        return scaled_inputs, neutral_weights

    return inputs, topk_weights
|
|
|
|
|
|
def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
    """Softmax top-k routing through the ``topk_softmax`` kernel.

    Output buffers (float32 weights, int32 ids) are allocated on the device
    of ``hidden_states`` and filled by the kernel; the ids then go through
    ``topk_ids_logical_to_physical`` and the padded-row masking helper.

    Returns:
        ``(topk_weights, topk_ids)``.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    num_tokens, _ = hidden_states.shape
    device = hidden_states.device

    weights_out = torch.empty(num_tokens, topk, dtype=torch.float32, device=device)
    ids_out = torch.empty(num_tokens, topk, dtype=torch.int32, device=device)

    # The kernel writes its results into the two buffers passed first.
    topk_softmax(weights_out, ids_out, gating_output, renormalize)

    ids_out = topk_ids_logical_to_physical(ids_out, expert_location_dispatch_info)
    _mask_topk_ids_padded_region(ids_out, num_token_non_padded)
    return weights_out, ids_out
|
|
|
|
|
|
# This is used by the Deepseek V2/V3/R1 series models
|
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
def grouped_topk_gpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    """Grouped top-k routing over softmax scores (no correction bias).

    Each group is scored by its best expert, the best ``topk_group`` groups
    are kept, and the top ``topk`` experts are taken from those groups only.
    With fused shared experts the last slot is overwritten by a random shared
    expert id whose weight is the sum of the others divided by
    ``routed_scaling_factor``.

    Returns:
        ``(topk_weights, topk_ids)`` cast to float32 / int32.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = torch.softmax(gating_output, dim=-1)
    # NPU compiler limitation
    if _is_npu and scores.dtype == torch.bfloat16:
        scores = scores.to(torch.float16)
    num_token, num_experts = scores.shape[0], scores.shape[1]
    experts_per_group = scores.shape[-1] // num_expert_group

    # Score every group by its single best expert.  [n, n_group]
    group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
    # Indices of the winning groups.  [n, top_k_group]
    _, group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)
    # One-hot style mask over groups.  [n, n_group]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    # Broadcast the group mask down to individual experts.  [n, e]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, experts_per_group)
        .reshape(num_token, -1)
    )
    # Experts outside the winning groups get score 0.  [n, e]
    masked_scores = scores.masked_fill(~score_mask.bool(), 0.0)

    # TODO: NPU can't support directly evaluating a comparison for now
    sort_result = True if num_fused_shared_experts > 0 else False
    topk_weights, topk_ids = torch.topk(
        masked_scores, k=topk, dim=-1, sorted=sort_result
    )

    if num_fused_shared_experts:
        # Replace the last slot with a randomly chosen shared expert.
        shared_ids = torch.randint(
            low=num_experts,
            high=num_experts + num_fused_shared_experts,
            size=(topk_ids.size(0),),
            dtype=topk_ids.dtype,
            device=topk_ids.device,
        )
        topk_ids[:, -1] = shared_ids
        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

    if renormalize:
        # The shared-expert slot, when present, is excluded from the sum.
        if num_fused_shared_experts == 0:
            weight_total = topk_weights.sum(dim=-1, keepdim=True)
        else:
            weight_total = topk_weights[:, :-1].sum(dim=-1, keepdim=True)
        topk_weights = topk_weights / weight_total
        if apply_routed_scaling_factor_on_output:
            topk_weights *= routed_scaling_factor

    topk_weights = topk_weights.to(torch.float32)
    topk_ids = topk_ids.to(torch.int32)
    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
    return topk_weights, topk_ids
|
|
|
|
|
|
def grouped_topk_cpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    """Grouped top-k routing via the ``sgl_kernel`` CPU op.

    Scaling on output and expert-location dispatch are rejected by assertion;
    every other argument is forwarded positionally to the kernel, whose
    result is returned as-is.
    """
    assert not apply_routed_scaling_factor_on_output
    assert expert_location_dispatch_info is None

    cpu_kernel = torch.ops.sgl_kernel.grouped_topk_cpu
    return cpu_kernel(
        hidden_states,
        gating_output,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
        num_fused_shared_experts,
        routed_scaling_factor,
        num_token_non_padded,
    )
|
|
|
|
|
|
@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
def biased_grouped_topk_impl(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    """Grouped top-k routing over sigmoid scores with a correction bias.

    Experts are chosen using ``sigmoid(logits) + correction_bias``: each group
    is scored by the sum of its two best biased scores, the best
    ``topk_group`` groups are kept, and the top ``topk`` experts are taken
    from them. The returned weights are the unbiased sigmoid scores of the
    chosen experts.

    Returns:
        ``(topk_weights, topk_ids)`` cast to float32 / int32.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = gating_output.sigmoid()
    num_token, num_experts = scores.shape[0], scores.shape[1]
    experts_per_group = scores.shape[-1] // num_expert_group

    # Biased scores are used for selection only.
    scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)

    # Group score = sum of the two best biased scores in the group.  [n, n_group]
    grouped_view = scores_for_choice.view(num_token, num_expert_group, -1)
    group_scores = grouped_view.topk(2, dim=-1)[0].sum(dim=-1)
    # Indices of the winning groups.  [n, top_k_group]
    _, group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)
    # Mask over groups.  [n, n_group]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    # Broadcast the group mask down to individual experts.  [n, e]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, experts_per_group)
        .reshape(num_token, -1)
    )
    # Experts outside the winning groups can never be selected.  [n, e]
    masked_scores = scores_for_choice.masked_fill(~score_mask.bool(), float("-inf"))

    # TODO: NPU can't support directly evaluating a comparison for now
    sort_result = True if num_fused_shared_experts > 0 else False
    _, topk_ids = torch.topk(masked_scores, k=topk, dim=-1, sorted=sort_result)
    topk_weights = scores.gather(1, topk_ids)

    if num_fused_shared_experts:
        # Replace the last slot with a randomly chosen shared expert.
        shared_ids = torch.randint(
            low=num_experts,
            high=num_experts + num_fused_shared_experts,
            size=(topk_ids.size(0),),
            dtype=topk_ids.dtype,
            device=topk_ids.device,
        )
        topk_ids[:, -1] = shared_ids
        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

    if renormalize:
        # The shared-expert slot, when present, is excluded from the sum.
        if num_fused_shared_experts == 0:
            weight_total = topk_weights.sum(dim=-1, keepdim=True)
        else:
            weight_total = topk_weights[:, :-1].sum(dim=-1, keepdim=True)
        topk_weights = topk_weights / weight_total
        if apply_routed_scaling_factor_on_output:
            topk_weights *= routed_scaling_factor

    topk_weights = topk_weights.to(torch.float32)
    topk_ids = topk_ids.to(torch.int32)
    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
    return topk_weights, topk_ids
|
|
|
|
|
|
def is_power_of_two(n):
    """Return True when ``n`` is a positive power of two.

    Integers are tested exactly with a bit trick: a power of two has a single
    set bit, so ``n & (n - 1)`` is zero. The previous ``math.log2`` test
    rounded large integers to float and reported e.g. ``2**53 + 1`` as a power
    of two. Non-integer numbers keep the ``log2`` based check.
    """
    if isinstance(n, int):
        return n > 0 and (n & (n - 1)) == 0
    return n > 0 and math.log2(n).is_integer()
|
|
|
|
|
|
def _mask_topk_ids_padded_region(
|
|
topk_ids: torch.Tensor,
|
|
num_token_non_padded: Optional[torch.Tensor] = None,
|
|
):
|
|
if num_token_non_padded is None:
|
|
return
|
|
indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
|
|
topk_ids[indices >= num_token_non_padded, :] = -1
|
|
|
|
|
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
def _biased_grouped_topk_postprocess(
    topk_ids, expert_location_dispatch_info, num_token_non_padded
):
    """Apply the two id post-processing steps after the fused gate kernel.

    The ids are passed through ``topk_ids_logical_to_physical`` and then the
    padded rows are masked in place; the resulting id tensor is returned.
    """
    mapped_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
    _mask_topk_ids_padded_region(mapped_ids, num_token_non_padded)
    return mapped_ids
|
|
|
|
|
|
def biased_grouped_topk_gpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    """Biased grouped top-k routing, picking the fastest available backend.

    Tried in order:
      1. ``moe_fused_gate`` on CUDA, when each group holds at most 32 experts
         and the bias length is a power of two;
      2. the aiter kernel, when aiter is enabled;
      3. the compiled PyTorch implementation ``biased_grouped_topk_impl``.

    Returns:
        ``(topk_weights, topk_ids)``.
    """
    assert (
        routed_scaling_factor is not None
    ), "routed_scaling_factor is required for biased_grouped_topk"

    # TODO: moe_fused_gate kernel is not supported for num_fused_shared_experts > 0 now.
    # The kernel currently requires num_experts / num_expert_group <= MAX_VPT (32);
    # the size check can go once the kernel handles a larger MAX_VPT.
    use_fused_gate = (
        _is_cuda
        and gating_output.shape[1] // num_expert_group <= 32
        and is_power_of_two(correction_bias.shape[0])
    )
    if use_fused_gate:
        topk_weights, topk_ids = moe_fused_gate(
            gating_output.to(dtype=torch.float32),
            correction_bias,
            num_expert_group,
            topk_group,
            topk,
            num_fused_shared_experts,
            routed_scaling_factor,
            apply_routed_scaling_factor_on_output,
        )
        # TODO merge into kernel
        needs_postprocess = (expert_location_dispatch_info is not None) or (
            num_token_non_padded is not None
        )
        if needs_postprocess:
            topk_ids = _biased_grouped_topk_postprocess(
                topk_ids, expert_location_dispatch_info, num_token_non_padded
            )
        return topk_weights, topk_ids

    if _use_aiter:
        assert not apply_routed_scaling_factor_on_output, "Not implemented"
        token = gating_output.shape[0]
        device = gating_output.device
        assert (
            hidden_states.shape[0] == gating_output.shape[0]
        ), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
        # Output buffers handed to the aiter kernel.
        topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
        topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
        aiter_biased_grouped_topk(
            gating_output.to(dtype=torch.float32),
            correction_bias,
            topk_weights,
            topk_ids,
            num_expert_group,
            topk_group,
            renormalize,
            routed_scaling_factor,
        )
        return topk_weights, topk_ids

    # Generic fallback.
    return biased_grouped_topk_impl(
        hidden_states,
        gating_output,
        correction_bias,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
        num_fused_shared_experts=num_fused_shared_experts,
        routed_scaling_factor=routed_scaling_factor,
        num_token_non_padded=num_token_non_padded,
        expert_location_dispatch_info=expert_location_dispatch_info,
        apply_routed_scaling_factor_on_output=apply_routed_scaling_factor_on_output,
    )
|
|
|
|
|
|
def biased_grouped_topk_cpu(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    compiled: bool = True,
    num_fused_shared_experts: int = 0,
    routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
    apply_routed_scaling_factor_on_output: Optional[bool] = False,
):
    """Biased grouped top-k routing via the ``sgl_kernel`` CPU op.

    Expert-location dispatch and scaling on output are rejected by assertion.
    ``compiled`` is accepted but not used; the remaining arguments are
    forwarded positionally to the kernel, whose result is returned as-is.
    """
    assert expert_location_dispatch_info is None
    assert not apply_routed_scaling_factor_on_output, "Not implemented"

    cpu_kernel = torch.ops.sgl_kernel.biased_grouped_topk_cpu
    return cpu_kernel(
        hidden_states,
        gating_output,
        correction_bias,
        topk,
        renormalize,
        num_expert_group,
        topk_group,
        num_fused_shared_experts,
        routed_scaling_factor,
        num_token_non_padded,
    )
|
|
|
|
|
|
# Bind the routing entry points used by select_experts to the implementations
# for this platform.
if not _is_cpu or not _is_cpu_amx_available:
    # Default bindings; ``fused_topk`` stays bound to the function defined above.
    biased_grouped_topk = biased_grouped_topk_gpu
    grouped_topk = grouped_topk_gpu
    fused_topk_native = fused_topk_torch_native
else:
    # CPU with AMX support: route everything through the sgl_kernel CPU ops.
    biased_grouped_topk = biased_grouped_topk_cpu
    grouped_topk = grouped_topk_cpu
    fused_topk_native = fused_topk_cpu
    fused_topk = fused_topk_cpu
|
|
|
|
|
|
def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    topk_config: TopKConfig,
    *,
    num_token_non_padded: Optional[torch.Tensor] = None,
    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> StandardTopKOutput:
    """Choose experts for every token according to ``topk_config``.

    Dispatch order:
      * grouped routing (with or without correction bias) when
        ``use_grouped_topk`` is set;
      * otherwise the custom routing function, if one is configured;
      * otherwise ``fused_topk_native`` when ``torch_native`` is set;
      * otherwise ``fused_topk``.

    The chosen ids are reported to the global expert distribution recorder
    before the result is returned.

    Returns:
        ``StandardTopKOutput(topk_weights, topk_ids, router_logits)`` where
        ``router_logits`` is the value returned by
        ``transform_select_experts_inputs``.
    """
    cfg = topk_config
    top_k = cfg.top_k
    renormalize = cfg.renormalize
    custom_routing_function = cfg.custom_routing_function
    scale_on_output = cfg.apply_routed_scaling_factor_on_output

    router_logits, correction_bias = (
        expert_location_dispatch.transform_select_experts_inputs(
            router_logits=router_logits,
            correction_bias=cfg.correction_bias,
            info=expert_location_dispatch_info,
        )
    )

    if cfg.use_grouped_topk:
        # DeepSeek V2/V3/R1 series models use grouped_top_k
        assert cfg.topk_group is not None
        assert cfg.num_expert_group is not None
        grouped_kwargs = dict(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            num_expert_group=cfg.num_expert_group,
            topk_group=cfg.topk_group,
            num_fused_shared_experts=cfg.num_fused_shared_experts,
            routed_scaling_factor=cfg.routed_scaling_factor,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
            apply_routed_scaling_factor_on_output=scale_on_output,
        )
        if correction_bias is None:
            topk_weights, topk_ids = grouped_topk(**grouped_kwargs)
        else:
            topk_weights, topk_ids = biased_grouped_topk(
                correction_bias=correction_bias, **grouped_kwargs
            )
    elif custom_routing_function is not None:
        assert (
            num_token_non_padded is None
        ), "num_token_non_padded is not yet supported in custom_routing_function"
        assert expert_location_dispatch_info is None
        assert not scale_on_output, "Not implemented"
        topk_weights, topk_ids = custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
        )
    elif cfg.torch_native:
        assert (
            num_token_non_padded is None
        ), "num_token_non_padded is not yet supported in fused_topk_native"
        assert expert_location_dispatch_info is None
        assert not scale_on_output, "Not implemented"
        topk_weights, topk_ids = fused_topk_native(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            correction_bias=correction_bias,
        )
    else:
        assert not scale_on_output, "Not implemented"
        # Qwen3MOE uses fused_topk
        topk_weights, topk_ids = fused_topk(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            num_token_non_padded=num_token_non_padded,
            expert_location_dispatch_info=expert_location_dispatch_info,
        )

    get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

    return StandardTopKOutput(topk_weights, topk_ids, router_logits)
|