first commit

This commit is contained in:
2026-03-10 13:31:25 +08:00
parent ba974cecfa
commit b62b889355
2604 changed files with 438977 additions and 0 deletions

View File

@@ -0,0 +1,23 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from . import layer, supa_moe # noqa: E402
from .layer import * # noqa: E402
# Names exported by ``from <this package> import *``: the two submodules
# imported at the top of this file.
__all__ = [
"layer",
"supa_moe",
]

View File

@@ -0,0 +1,413 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from functools import wraps
from typing import Callable, Optional
import torch
import torch_br
from fastcore.basics import patch_to
from torch_br.utils.tensor_methods import Sbp
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod)
from vllm.model_executor.utils import set_weight_attrs
from vllm_br import envs
from ..br_utils import (_convert_to_crossed_numa_tensor,
_convert_to_numa_tensor, align_n, cross_weight_32)
from .supa_moe import (fused_moe_quant_device, fused_moe_quant_dyn,
fused_oss_moe_dyn)
@patch_to(UnquantizedFusedMoEMethod)
def forward_oot(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
):
    """Out-of-tree forward of UnquantizedFusedMoEMethod for SUPA devices.

    Three kernels entry points are selected from here:

    * ``activation == "swigluoai"``: ``fused_oss_moe_dyn`` is called with the
      layer's weights *and* biases; ``router_logits`` is forwarded as-is in
      the gating-weight position.
    * otherwise ``router_logits`` is unpacked into three items (gating
      weight, shared gate_up weight, shared down weight) and
      - ``fused_moe_quant_dyn`` is used when ``x.shape[0]`` exceeds
        ``envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN`` (prefill),
      - ``fused_moe_quant_device`` is used otherwise (decode).
      Both receive ``None`` for the two weight-scale arguments.

    ``self``, ``global_num_experts``, ``expert_map``,
    ``routed_scaling_factor``, ``apply_router_weight_on_input``,
    ``enable_eplb``, ``expert_load_view``, ``logical_to_physical_map`` and
    ``logical_replica_count`` are accepted for signature compatibility but
    are not read by this implementation.

    Returns:
        Whatever the selected fused MoE function returns.
    """
    # Keyword arguments shared by every backend function.
    routing_kwargs = dict(
        renormalize=renormalize,
        inplace=True,
        use_grouped_topk=use_grouped_topk,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        ep_rank=layer.ep_rank,
        ep_size=layer.ep_size,
    )

    if activation == "swigluoai":
        return fused_oss_moe_dyn(
            x,
            layer.w13_weight,
            layer.w13_bias,
            layer.w2_weight,
            layer.w2_bias,
            router_logits,
            top_k,
            layer.intermediate_size_per_partition,
            **routing_kwargs,
        )

    gating_weight, shared_gate_up_weight, shared_down_weight = router_logits

    # More tokens than the static decoder limit -> prefill kernel path,
    # otherwise the decode kernel path.
    is_prefill = x.shape[0] > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
    moe_impl = fused_moe_quant_dyn if is_prefill else fused_moe_quant_device

    tp_group = get_tp_group()
    return moe_impl(
        x,
        shared_gate_up_weight,
        shared_down_weight,
        layer.w13_weight,
        layer.w2_weight,
        None,  # w13_scale
        None,  # w2_scale
        gating_weight,
        top_k,
        layer.intermediate_size_per_partition,
        tp_rank=tp_group.rank_in_group,
        global_rank=tp_group.rank,
        tp_size=get_tensor_model_parallel_world_size(),
        **routing_kwargs,
    )
@patch_to(UnquantizedFusedMoEMethod)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Allocate the expert parameters of ``layer`` on the CPU.

    Registers, in this order:

    * ``w13_weight``: ``[num_experts, 2 * intermediate_size_per_partition,
      hidden_size]`` (fused gate_up_proj, column parallel), uninitialized.
    * ``w13_bias``: ``[num_experts, 2 * intermediate_size_per_partition]``,
      zero-filled, only when ``self.moe.has_bias``.
    * ``w2_weight``: ``[num_experts, hidden_size,
      intermediate_size_per_partition]`` (down_proj, row parallel),
      uninitialized.
    * ``w2_bias``: ``[num_experts, hidden_size]``, zero-filled, only when
      ``self.moe.has_bias``.

    Every parameter uses ``params_dtype``, has ``requires_grad=False`` and
    gets ``extra_weight_attrs`` attached through ``set_weight_attrs``.
    """

    def _register(name: str, factory, *shape: int) -> None:
        # Build a frozen CPU parameter, attach it to the layer and tag it.
        param = torch.nn.Parameter(
            factory(*shape, device="cpu", dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)

    gate_up_rows = 2 * intermediate_size_per_partition

    # Fused gate_up_proj (column parallel).
    _register("w13_weight", torch.empty, num_experts, gate_up_rows,
              hidden_size)
    if self.moe.has_bias:
        _register("w13_bias", torch.zeros, num_experts, gate_up_rows)

    # down_proj (row parallel).
    _register("w2_weight", torch.empty, num_experts, hidden_size,
              intermediate_size_per_partition)
    if self.moe.has_bias:
        _register("w2_bias", torch.zeros, num_experts, hidden_size)
@patch_to(UnquantizedFusedMoEMethod)
def process_weights_after_loading(self: UnquantizedFusedMoEMethod,
                                  layer: FusedMoE) -> None:
    """Re-pack the loaded expert weights/biases into SUPA device tensors.

    For every local expert the per-expert slice of ``layer.w13_weight`` and
    ``layer.w2_weight`` is transposed, converted with the ``br_utils``
    helpers and copied into a freshly allocated bfloat16 device tensor; the
    parameters' ``.data`` is then swapped for those device tensors. Biases,
    when the layer has them, are copied into float32 device tensors.

    The device is treated as dual-die when ``envs.VLLM_BR_DEVICE_SPC_NUM`` is
    greater than 16; this selects the ``sbp`` argument of the allocations and
    halves some per-die dimensions.
    """
    device = torch.supa.current_device()
    total_spc = envs.VLLM_BR_DEVICE_SPC_NUM
    dual_die = total_spc > 16
    n_dies = 2 if dual_die else 1
    spc_per_die = total_spc // n_dies
    weight_sbp = "SS" if dual_die else None
    bias_sbp = "BB" if dual_die else None

    # The "swigluoai" activation keeps gate/up as loaded (no interleaving).
    is_swigluoai = layer.activation == "swigluoai"
    n_local_experts = layer.local_num_experts

    def _store_expert(dst, padded, expert_id):
        # Copy one converted expert into its slot of ``dst``; the slot offset
        # is expert_id times the size of the last two dims of ``padded``.
        plane = padded.shape[-2] * padded.shape[-1]
        dst.view_as_usharp("COLMAJOR", padded.shape, Sbp.ss(0),
                           expert_id * plane).copy_(padded)

    # ---- w13_weight -------------------------------------------------------
    # After _load_w13 the weight is column parallel with shape
    # [num_experts, 2 * intermediate_size_per_partition, hidden_size].
    # It becomes a NUMA colmajor tensor of shape
    # [total_spc * num_experts, hidden_size, w13_cols].
    w13_align = 32 if is_swigluoai else 64
    w13_cols = align_n((layer.intermediate_size_per_partition * 2) // n_dies,
                       align_size=w13_align,
                       spc_num=spc_per_die)
    numa_w13 = torch_br._empty_ut_only(
        size=(total_spc * n_local_experts, layer.hidden_size, w13_cols),
        dtype=torch.bfloat16,
        is_numa=True,
        device=device,
        tensor_type="colmajor",
        axis=0,
        sbp=weight_sbp)
    for expert_id in range(n_local_experts):
        w13_t = layer.w13_weight[expert_id].transpose(0, 1).contiguous()
        if is_swigluoai:
            padded = _convert_to_numa_tensor(w13_t, w13_align, 'COLMAJOR',
                                             w13_t.dtype)
        else:
            # First half is w1 (gate), second half is w3 (up); the helper
            # crosses the two halves.
            gate_part, up_part = w13_t.chunk(2, dim=1)
            padded = _convert_to_crossed_numa_tensor(gate_part,
                                                     up_part,
                                                     total_spc,
                                                     dim=1,
                                                     need_pad=True,
                                                     layout='COLMAJOR')
        _store_expert(numa_w13, padded, expert_id)
    layer.w13_weight.data = numa_w13

    # ---- w13_bias ---------------------------------------------------------
    if getattr(layer, "w13_bias", None) is not None:
        numa_w13_bias = torch_br._empty_ut_only(
            size=(n_local_experts, layer.intermediate_size_per_partition * 2),
            dtype=torch.float32,
            is_numa=False,
            device=device,
            tensor_type="linear_bias",
            sbp=bias_sbp)
        for expert_id in range(n_local_experts):
            bias_row = layer.w13_bias[expert_id]
            if not is_swigluoai:
                gate_bias, up_bias = bias_row.chunk(2, dim=-1)
                bias_row = cross_weight_32(
                    gate_bias,
                    up_bias,
                    total_spc,
                    dim=0,
                    need_pad=False,
                )
            numa_w13_bias[expert_id].copy_(bias_row)
        layer.w13_bias.data = numa_w13_bias

    # ---- w2_weight --------------------------------------------------------
    # After _load_w2 the weight is row parallel with shape
    # [num_experts, hidden_size, intermediate_size_per_partition].
    # It becomes a NUMA colmajor tensor of shape
    # [total_spc * num_experts, intermediate // n_dies, w2_cols].
    w2_align = 32
    w2_cols = align_n(layer.hidden_size,
                      align_size=w2_align,
                      spc_num=spc_per_die)
    numa_w2 = torch_br._empty_ut_only(
        size=(total_spc * n_local_experts,
              layer.intermediate_size_per_partition // n_dies, w2_cols),
        dtype=torch.bfloat16,
        is_numa=True,
        device=device,
        tensor_type="colmajor",
        axis=0,
        sbp=weight_sbp)
    for expert_id in range(n_local_experts):
        w2_t = layer.w2_weight[expert_id].transpose(0, 1).contiguous()
        padded = _convert_to_numa_tensor(w2_t,
                                         w2_align,
                                         'COLMAJOR',
                                         w2_t.dtype,
                                         parallel_type="row_parallel")
        _store_expert(numa_w2, padded, expert_id)
    layer.w2_weight.data = numa_w2

    # ---- w2_bias ----------------------------------------------------------
    if getattr(layer, "w2_bias", None) is not None:
        dev_w2_bias = torch.zeros((n_local_experts, layer.hidden_size),
                                  dtype=torch.float32,
                                  device=device)
        for expert_id in range(n_local_experts):
            dev_w2_bias[expert_id].copy_(layer.w2_bias[expert_id])
        layer.w2_bias.data = dev_w2_bias
@patch_to(FusedMoE)
def forward(self: FusedMoE, hidden_states: torch.Tensor,
            router_logits: torch.Tensor):
    """Run the layer's quant method and tag outputs that are already reduced.

    NOTE: despite its name, ``router_logits`` is a tuple of the gate,
    ``shared_experts.gate_up_proj`` and ``shared_experts.down_proj`` weights
    (for the "swigluoai" route of ``UnquantizedFusedMoEMethod`` it is passed
    through unchanged as the gating weight).

    After ``self.quant_method.apply`` returns, the result gets the attribute
    ``all_reduced = True`` when the decode-time fused FFN+all-reduce kernel
    conditions hold: the token count is at most
    ``envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN``, the quant method env is not
    ``"INT4"``, ``envs.VLLM_BR_USE_FUSED_ALLREDUCE`` is set and the
    ``(SPC count, tp_size)`` pair is one of the supported combinations.

    The flag is *not* set when ``UnquantizedFusedMoEMethod`` handles the
    "swigluoai" activation: that route always runs ``fused_oss_moe_dyn``,
    which is given no tensor-parallel arguments and never invokes the fused
    all-reduce kernel, so marking its output as reduced would be wrong.

    Raises:
        AssertionError: if no quant method is attached or ``dp_size != 1``.
    """
    assert self.quant_method is not None
    assert self.dp_size == 1, 'dp_size > 1 is not supported for now, please refer v0.11.0 moe codes'

    final_hidden_states = self.quant_method.apply(
        layer=self,
        x=hidden_states,
        router_logits=router_logits,
        top_k=self.top_k,
        renormalize=self.renormalize,
        use_grouped_topk=self.use_grouped_topk,
        global_num_experts=self.global_num_experts,
        expert_map=self.expert_map,
        topk_group=self.topk_group,
        num_expert_group=self.num_expert_group,
        custom_routing_function=self.custom_routing_function,
        scoring_func=self.scoring_func,
        e_score_correction_bias=self.e_score_correction_bias,
        activation=self.activation,
        apply_router_weight_on_input=self.apply_router_weight_on_input,
    )

    # NOTE: if using supa-moe-ccl kernel, add property `all_reduced` to the
    # final_hidden_states.
    support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
    tp_size = get_tensor_model_parallel_world_size()

    # The unquantized "swigluoai" route never reaches the all-reduce kernel.
    ran_without_allreduce = (self.activation == "swigluoai" and isinstance(
        self.quant_method, UnquantizedFusedMoEMethod))

    fused_allreduce_used = (
        not ran_without_allreduce
        and hidden_states.shape[0] <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
        and envs.VLLM_BR_QUANT_METHOD != "INT4"
        and envs.VLLM_BR_USE_FUSED_ALLREDUCE
        and (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
    if fused_allreduce_used:
        final_hidden_states.all_reduced = True
    return final_hidden_states
@patch_to(FusedMoE)
def _load_w13(self,
              expert_data: torch.Tensor,
              shard_dim: int,
              shard_id: str,
              loaded_weight: torch.Tensor,
              tp_rank: int,
              load_full: bool = False):
    """Copy one half (w1 or w3) of a checkpoint weight into ``expert_data``.

    ``expert_data`` holds the fused gate_up weight of a single expert, so
    each logical half spans ``expert_data.shape[shard_dim] // 2`` entries
    along ``shard_dim``.

    Args:
        expert_data: destination tensor for one expert (w1 half followed by
            w3 half along ``shard_dim``).
        shard_dim: dimension that is both TP-sharded and split into halves.
        shard_id: ``"w1"`` (gate_proj, first half) or ``"w3"`` (up_proj,
            second half).
        loaded_weight: tensor read from the checkpoint.
        tp_rank: tensor-parallel rank used to pick the slice of
            ``loaded_weight``.
        load_full: when True, ``loaded_weight`` is used as-is instead of
            being narrowed to this rank's slice. Mirrors the flag of
            ``_load_w2``; defaults to the previous behaviour.

    Raises:
        AssertionError: if ``shard_id`` is neither ``"w1"`` nor ``"w3"``.
    """
    # gate_up_proj is "MergedColumnParallel": TP sharding on the output dim.
    shard_size = expert_data.shape[shard_dim] // 2
    if not load_full:
        loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                             shard_size)

    if shard_id == "w1":
        # w1, gate_proj: first logical weight of w13.
        expert_data = expert_data.narrow(shard_dim, 0, shard_size)
    else:
        # w3, up_proj: second logical weight of w13.
        assert shard_id == "w3"
        expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)

    # The loaded tensor is moved to host memory before the in-place copy.
    expert_data.copy_(loaded_weight.cpu())
@patch_to(FusedMoE)
def _load_w2(self,
             expert_data: torch.Tensor,
             shard_dim: int,
             loaded_weight: torch.Tensor,
             tp_rank: int,
             load_full: bool = False):
    """Copy a checkpoint down_proj weight into ``expert_data``.

    down_proj is "RowParallel", so tensor-parallel sharding happens on the
    input dimension ``shard_dim``. Unless ``load_full`` is set, only the
    slice ``[tp_rank * n, (tp_rank + 1) * n)`` of ``loaded_weight`` along
    that dimension is used, where ``n = expert_data.shape[shard_dim]``.
    The selected data is moved to host memory and copied in place.
    """
    width = expert_data.shape[shard_dim]
    source = loaded_weight
    if not load_full:
        source = source.narrow(shard_dim, width * tp_rank, width)
    # w2, down_proj: the only logical weight stored in this parameter.
    expert_data.copy_(source.cpu())
def wrapper_FusedMoE_init(fn):
    """Wrap an ``__init__`` so the score-correction bias ends up in float32.

    The returned function first runs ``fn`` with the given arguments and
    then, if ``self.e_score_correction_bias`` is not None, replaces its
    ``.data`` with the result of ``.float()``. The wrapper itself returns
    None, like a regular ``__init__``.
    """

    @wraps(fn)
    def patched_init(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        bias = self.e_score_correction_bias
        if bias is None:
            return
        # Swap only the storage so the attribute object itself is preserved.
        bias.data = bias.float()

    return patched_init
# Monkey-patch FusedMoE.__init__ at import time so that every instance has
# its e_score_correction_bias (when present) converted with .float() right
# after the original constructor runs.
FusedMoE.__init__ = wrapper_FusedMoE_init(FusedMoE.__init__) # noqa: E501

View File

@@ -0,0 +1,518 @@
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from typing import Callable, Optional
import torch
import torch_br
from torch_br.utils.tensor_methods import Sbp
from vllm_br import envs
# gpt-oss moe forward version
def fused_oss_moe_dyn(
    hidden_states: torch.Tensor,
    w13: torch.Tensor,
    w13_bias: torch.Tensor,
    w2: torch.Tensor,
    w2_bias: torch.Tensor,
    gating_weight: torch.Tensor,
    topk: int,
    intermediate_size: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    ep_rank: Optional[int] = None,
    ep_size: Optional[int] = None,
) -> torch.Tensor:
    """MoE forward with biased experts and the "act_swiglu_oai" activation.

    Steps: route tokens with ``torch_br.supa_moe_router_v2_infer``, permute
    the tokens per local expert, run two fused FFN kernel calls
    (gate_up + activation with ``w13``/``w13_bias``, then down projection
    with ``w2``/``w2_bias``) and un-permute the expert outputs weighted by
    the routing probabilities.

    Args:
        hidden_states: input tokens; ``shape[0]`` is used as the token count.
        w13, w13_bias: fused gate_up expert weights and bias.
        w2, w2_bias: down-projection expert weights and bias.
        gating_weight: router weight; the total expert count is read from
            ``gating_weight.shape[-2]``.
        topk: number of experts selected per token.
        intermediate_size: width of the gate_up output buffers.
        e_score_correction_bias: forwarded to the router as ``gating_bias``.
        ep_rank, ep_size: expert-parallel rank/size. Both must be integers:
            they are used in integer arithmetic to select this rank's
            experts, so the ``None`` defaults are not usable.
        renormalize, inplace, use_grouped_topk, num_expert_group, topk_group,
        custom_routing_function, scoring_func: accepted for signature
            compatibility; not read by this function.

    Returns:
        The combined expert output with a leading dimension of size 1 added.
        When no local expert received a token, a zero tensor shaped like
        ``hidden_states`` (again with the extra leading dimension).
    """
    total_expert_num = gating_weight.shape[-2]

    # The first router output is not consumed by this function, so it is no
    # longer transposed through host memory like the other three.
    _, indices_supa, prob_per_expert, indice_per_expert = (
        torch_br.supa_moe_router_v2_infer(
            hidden_states,
            gating_weight,
            topk,
            ep_size,
            ep_rank,
            gating_bias=e_score_correction_bias))
    cur_device = hidden_states.device

    def _transpose_via_host(t: torch.Tensor) -> torch.Tensor:
        # Transpose the two dims on the host and move the result back.
        return t.cpu().permute(1, 0).contiguous().to(cur_device)

    indices_supa = _transpose_via_host(indices_supa)
    indice_per_expert = _transpose_via_host(indice_per_expert)
    prob_per_expert = _transpose_via_host(prob_per_expert)

    is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
    local_expert_num = total_expert_num // ep_size  # type: ignore
    b_seq = hidden_states.shape[0]

    # Index scratch buffer; second dim is b_seq rounded up to a multiple
    # of 64.
    indices_trans_supa = torch_br._empty_ut_only(
        size=(local_expert_num, (b_seq + 64 - 1) // 64 * 64),
        dtype=torch.int32,
        is_numa=False,
        device=hidden_states.device,
        tensor_type="colmajor",
        sbp="BB" if is_dual_die else None)

    # Token count per expert, restricted to the experts owned by this rank.
    tokens_per_expert_supa = torch.bincount(indices_supa.reshape(-1),
                                            minlength=total_expert_num)
    first_local = ep_rank * local_expert_num  # type: ignore
    tokens_per_expert_list = tokens_per_expert_supa.cpu().data.numpy().tolist(
    )[first_local:first_local + local_expert_num]

    if not any(tokens_per_expert_list):
        # Nothing was routed to this rank's experts.
        return torch.zeros_like(hidden_states).unsqueeze(0)

    expert_tokens = torch_br.supa_permutation_infer(
        global_hidden_states=hidden_states,
        indices=indice_per_expert,
        tokens_per_expert=tokens_per_expert_list,
        indices_trans=indices_trans_supa)
    assert len(expert_tokens) == local_expert_num, "Number of experts mismatch"

    out_device = expert_tokens[0].device
    hidden_size = expert_tokens[0].shape[-1]
    gate_up_outputs = []
    down_outputs = []
    for n_tokens in tokens_per_expert_list[:local_expert_num]:
        if n_tokens == 0:
            # Idle expert: zero-row placeholders keep the lists aligned.
            gate_up_outputs.append(
                torch.empty(size=(0, intermediate_size),
                            dtype=torch.bfloat16,
                            device=out_device))
            down_outputs.append(
                torch.empty(size=(0, hidden_size),
                            dtype=torch.float32,
                            device=out_device))
            continue
        gate_up_outputs.append(
            torch_br._empty_ut_only(size=(n_tokens, intermediate_size),
                                    dtype=torch.bfloat16,
                                    is_numa=False,
                                    device=out_device,
                                    tensor_type="colmajor",
                                    sbp="SB" if is_dual_die else None,
                                    axis=1))
        down_outputs.append(
            torch_br._empty_ut_only(size=(n_tokens, hidden_size),
                                    dtype=torch.float32,
                                    is_numa=False,
                                    device=out_device,
                                    tensor_type="colmajor",
                                    sbp="BB" if is_dual_die else None))

    max_tokens = max(tokens_per_expert_list)
    torch_br.supa_moe_fused_ffn_dyn_infer(gate_up_outputs,
                                          expert_tokens,
                                          w13,
                                          tokens_per_expert_list,
                                          max_tokens,
                                          bias=w13_bias,
                                          act_mode="act_swiglu_oai")
    torch_br.supa_moe_fused_ffn_dyn_infer(down_outputs,
                                          gate_up_outputs,
                                          w2,
                                          tokens_per_expert_list,
                                          max_tokens,
                                          bias=w2_bias,
                                          act_mode="act_default")
    output = torch_br.supa_unpermutation_infer(
        input_list=down_outputs,
        indices=indices_trans_supa,
        probs=prob_per_expert,
        tokens_per_expert=tokens_per_expert_list)
    return output.unsqueeze(0)
def fused_moe_quant_dyn(
hidden_states: torch.Tensor,
shared_gate_up_weight: torch.Tensor,
down_weight: torch.Tensor,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_weight: torch.Tensor,
topk: int,
intermediate_size: int,
renormalize: bool,
inplace: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = None,
global_rank: Optional[int] = None,
tp_size: Optional[int] = None,
ep_rank: Optional[int] = None,
ep_size: Optional[int] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# NOTE(review): the docstring above is stale relative to the signature. It
# lists w1, gating_output, override_config, use_fp8 and w1_scale, none of
# which are parameters here, and it omits shared_gate_up_weight, down_weight,
# w13, w13_scale, gating_weight, intermediate_size and the tp_*/ep_* arguments.
# NOTE(review): renormalize, inplace, custom_routing_function, tp_rank,
# global_rank and tp_size are accepted but never read in this body.
# NOTE(review): SOURCE lost its indentation; block nesting below has to be
# inferred and two spots are flagged where it is genuinely ambiguous.
#
# More than 16 SPCs is treated as a dual-die device throughout this function.
is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
# Total expert count is taken from the last dim of the gating weight.
total_expert_num = gating_weight.shape[-1]
# NOTE(review): this value is never read before being overwritten at the
# start of the fused-FFN branch further down.
cur_device = hidden_states.device
# Grouped top-k: a single fused kernel call returns the shared-expert output
# together with the routing results.
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
shared_output, _, indices_supa, indice_per_expert, prob_per_expert = torch_br.supa_fused_shared_router_prefill_v2_infer(
hidden_states,
shared_gate_up_weight,
down_weight,
gating_weight,
intermediate_size,
topk,
num_expert_group,
topk_group,
ep_size,
ep_rank,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
# On dual-die devices the shared output is copied into a freshly allocated
# colmajor tensor created without an sbp argument. This reads
# shared_output.shape, so it must sit inside the grouped branch
# (shared_output is None on the other path).
if is_dual_die:
shared_tmp = torch_br._empty_ut_only(size=shared_output.shape,
dtype=shared_output.dtype,
is_numa=False,
device=shared_output.device,
tensor_type="colmajor")
shared_tmp.copy_(shared_output)
shared_output = shared_tmp
# Non-grouped top-k: plain router, no shared experts allowed.
else:
assert topk_group is None, "Only support non group topk router"
assert shared_gate_up_weight is None and down_weight is None
# The gating weight is transposed before being handed to this router. Note
# the return order differs from the grouped router above (prob_per_expert
# comes before indice_per_expert here).
_, indices_supa, prob_per_expert, indice_per_expert = torch_br.supa_moe_router_v2_infer(
hidden_states,
gating_weight.permute(1, 0).contiguous(),
topk,
ep_size,
ep_rank,
gating_bias=e_score_correction_bias)
shared_output = None
# NOTE(review): with indentation lost it is unclear whether the three
# transposes below belong only to the non-grouped branch or run for both
# routers — confirm against the original file before re-indenting.
indices_supa = indices_supa.permute(1, 0).contiguous()
indice_per_expert = indice_per_expert.permute(1, 0).contiguous()
prob_per_expert = prob_per_expert.permute(1, 0).contiguous()
# ep_size must be an int here; the None default would raise on the division.
local_expert_num = total_expert_num // ep_size # type: ignore
b_seq = hidden_states.shape[0]
# Index scratch buffer; second dim is b_seq rounded up to a multiple of 64.
indices_trans_supa = torch_br._empty_ut_only(
size=(local_expert_num, (b_seq + 64 - 1) // 64 * 64),
dtype=torch.int32,
is_numa=False,
device=hidden_states.device,
tensor_type="colmajor",
sbp="BB" if is_dual_die else None)
# Token count per expert over all experts, then sliced to the window of
# experts owned by this expert-parallel rank.
tokens_per_expert_supa = torch.bincount(indices_supa.reshape(-1),
minlength=total_expert_num)
tokens_per_expert_list = tokens_per_expert_supa.cpu().data.numpy().tolist(
)[ep_rank * local_expert_num:(ep_rank + 1) * # type: ignore
local_expert_num]
# Number of local experts that received at least one token.
topk_per_expert = sum(1 for x in tokens_per_expert_list if x != 0)
if topk_per_expert > 0:
expert_tokens = torch_br.supa_permutation_infer(
global_hidden_states=hidden_states,
indices=indice_per_expert,
tokens_per_expert=tokens_per_expert_list,
indices_trans=indices_trans_supa)
assert len(
expert_tokens) == local_expert_num, "Number of experts mismatch"
supa_device = torch.supa.current_device()
# spc_num is only consumed by the per-expert path below.
spc_num = torch_br.supa.get_device_properties(
supa_device).max_compute_units
out_expert_tokens = []
# use_moe_fused_ffn_dyn is hard-coded to True, so the per-expert path is
# effectively selected only when total_expert_num == 128.
use_moe_fused_ffn_dyn = True
if not use_moe_fused_ffn_dyn or total_expert_num == 128:
# Element count of the last two dims, used as the per-expert offset step.
w13_hw = w13.shape[-2] * w13.shape[-1]
w2_hw = w2.shape[-2] * w2.shape[-1]
for i in range(local_expert_num):
expert_token = expert_tokens[i]
# Idle expert: its permuted-token entry is appended unchanged as the output.
if tokens_per_expert_list[i] == 0:
out_expert_tokens.append(expert_token)
continue
expert_gate_up_weight = w13.view_as_usharp(
"COLMAJOR", (spc_num, w13.shape[-2], w13.shape[-1]),
Sbp.ss(0), i * w13_hw)
# NOTE: this rebinds the `down_weight` parameter (the shared-expert down
# weight); the parameter was already consumed by the router call above.
down_weight = w2.view_as_usharp(
"COLMAJOR", (spc_num, w2.shape[-2], w2.shape[-1]),
Sbp.ss(0), i * w2_hw)
expert_gate_up_scale = w13_scale[
i] if w13_scale is not None else None
down_scale = w2_scale[i] if w2_scale is not None else None
gate_up_output = torch_br._empty_ut_only(
size=(expert_token.shape[0], intermediate_size),
dtype=torch.bfloat16,
is_numa=False,
device=expert_token.device,
tensor_type="colmajor",
sbp="SB" if is_dual_die else None,
axis=1)
torch_br.supa_fused_linear_infer(gate_up_output,
expert_token,
expert_gate_up_weight,
expert_gate_up_scale,
act_mode="act_swiglu")
down_output = torch_br._empty_ut_only(
size=expert_token.shape,
dtype=torch.float32,
is_numa=False,
device=gate_up_output.device,
tensor_type="colmajor",
sbp="BB" if is_dual_die else None)
torch_br.supa_fused_linear_infer(down_output, gate_up_output,
down_weight, down_scale)
out_expert_tokens.append(down_output)
# Fused path: output buffers are prepared for every local expert, then two
# kernel calls cover all experts at once.
else:
gate_up_outputs = []
cur_device = expert_tokens[0].device
hidden_size = expert_tokens[0].shape[-1]
for i in range(local_expert_num):
# Idle expert: zero-row placeholders keep both lists index-aligned.
if tokens_per_expert_list[i] == 0:
gate_up_outputs.append(
torch.empty(size=(0, intermediate_size),
dtype=torch.bfloat16,
device=cur_device))
out_expert_tokens.append(
torch.empty(size=(0, hidden_size),
dtype=torch.float32,
device=cur_device))
continue
gate_up_output = torch_br._empty_ut_only(
size=(tokens_per_expert_list[i], intermediate_size),
dtype=torch.bfloat16,
is_numa=False,
device=cur_device,
tensor_type="colmajor",
sbp="SB" if is_dual_die else None,
axis=1)
gate_up_outputs.append(gate_up_output)
down_output = torch_br._empty_ut_only(
size=(tokens_per_expert_list[i], hidden_size),
dtype=torch.float32,
is_numa=False,
device=cur_device,
tensor_type="colmajor",
sbp="BB" if is_dual_die else None)
out_expert_tokens.append(down_output)
# First call: w13 with act_mode="act_swiglu"; second call: w2 with
# act_mode="act_default".
torch_br.supa_moe_fused_ffn_dyn_infer(gate_up_outputs,
expert_tokens,
w13,
tokens_per_expert_list,
max(tokens_per_expert_list),
scales=w13_scale,
act_mode="act_swiglu")
torch_br.supa_moe_fused_ffn_dyn_infer(out_expert_tokens,
gate_up_outputs,
w2,
tokens_per_expert_list,
max(tokens_per_expert_list),
scales=w2_scale,
act_mode="act_default")
# Scatter the per-expert outputs back, passing the routing probabilities.
out_states = torch_br.supa_unpermutation_infer(
input_list=out_expert_tokens,
indices=indices_trans_supa,
probs=prob_per_expert,
tokens_per_expert=tokens_per_expert_list)
# Add the shared-expert output when the grouped router produced one.
output = out_states if shared_output is None else out_states + shared_output
# No local expert received a token: return zeros, or just the shared output.
else:
output = torch.zeros_like(
hidden_states) if shared_output is None else shared_output
# The result is returned with a leading dimension of size 1 added.
return output.unsqueeze(0)
def fused_moe_quant_device(
hidden_states: torch.Tensor,
shared_gate_up_weight: torch.Tensor,
down_weight: torch.Tensor,
w13: torch.Tensor,
w2: torch.Tensor,
w13_scale: torch.Tensor,
w2_scale: torch.Tensor,
gating_weight: torch.Tensor,
topk: int,
intermediate_size: int,
renormalize: bool,
inplace: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = None,
global_rank: Optional[int] = None,
tp_size: Optional[int] = None,
ep_rank: Optional[int] = None,
ep_size: Optional[int] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseekv2 model uses grouped_topk
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# NOTE(review): the docstring above is stale relative to the signature. It
# lists w1, gating_output, override_config, use_fp8 and w1_scale, none of
# which are parameters here, and it omits shared_gate_up_weight, down_weight,
# w13, w13_scale, gating_weight, intermediate_size and the tp_*/ep_* arguments.
# NOTE(review): renormalize, inplace and custom_routing_function are accepted
# but never read in this body.
#
# More than 16 SPCs is treated as a dual-die device.
is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
# Expert count is taken from the last dim of the gating weight.
expert_num = gating_weight.shape[-1]
# Token count is read from the second-to-last dim of hidden_states.
b_seq = hidden_states.shape[-2]
# Non-grouped routing: no shared-expert weights may be supplied.
if topk_group is None:
assert shared_gate_up_weight is None and down_weight is None
shared_output, masked_probs, hitted_experts = torch_br.supa_moe_router_decoder_infer(
hidden_states, gating_weight, topk, ep_size, ep_rank)
# Grouped routing: a fused kernel call that also takes the shared-expert
# weights.
else:
assert use_grouped_topk is True, "Only support group topk router"
assert num_expert_group is not None and topk_group is not None
# NOTE(review): in both calls below, a missing e_score_correction_bias is
# replaced by torch.empty(...), i.e. uninitialized memory. If the kernel
# reads this bias, torch.zeros would be the safe placeholder — confirm the
# kernel's behaviour.
# The v2 kernel (which takes ep_size/ep_rank) is used when expert
# parallelism is active; ep_size must be an int here.
if ep_size > 1: # type: ignore
shared_output, masked_probs, hitted_experts = torch_br.supa_fused_shared_router_v2_infer(
hidden_states,
shared_gate_up_weight,
down_weight,
gating_weight,
intermediate_size,
topk,
num_expert_group,
topk_group,
ep_size,
ep_rank,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias
if e_score_correction_bias is not None else torch.empty(
(expert_num),
dtype=torch.float32,
device=hidden_states.device))
else:
shared_output, masked_probs, hitted_experts = torch_br.supa_fused_shared_router_infer(
hidden_states,
shared_gate_up_weight,
down_weight,
gating_weight,
intermediate_size,
topk,
num_expert_group,
topk_group,
scoring_func,
e_score_correction_bias=e_score_correction_bias
if e_score_correction_bias is not None else torch.empty(
(expert_num),
dtype=torch.float32,
device=hidden_states.device))
# NOTE(review): SOURCE lost its indentation, so it is unclear whether this
# dual-die re-view applies to every routing path or only to the grouped
# branch above — confirm against the original file before re-indenting.
if is_dual_die:
shared_output = shared_output.view_as_usharp(
"COLMAJOR", shared_output.shape, Sbp.bb())
# int32 expert weights select the s4 FFN kernel.
if w13.dtype == torch.int32:
torch_br.supa_moe_fused_ffn_s4_infer(shared_output, hidden_states, w13,
w2, hitted_experts, masked_probs,
w13_scale, w2_scale)
else:
# (VLLM_BR_DEVICE_SPC_NUM, tp_size) pairs for which the fused
# FFN+all-reduce kernel is used.
support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
if envs.VLLM_BR_USE_FUSED_ALLREDUCE and b_seq <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and (
envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types:
# ffn+allreduce is only taken for the (SPC count, tp_size) pairs listed
# in support_types above.
torch_br.supa_moe_fused_ffn_allreduce(shared_output, hidden_states,
w13, w2, hitted_experts,
masked_probs, tp_rank,
tp_size, global_rank, 0,
w13_scale, w2_scale)
else:
torch_br.supa_moe_fused_ffn_infer(shared_output, hidden_states,
w13, w2, hitted_experts,
masked_probs, w13_scale,
w2_scale)
# shared_output is passed as the first argument to every FFN kernel above and
# is returned here with a leading dimension of size 1 added.
return shared_output.unsqueeze(0)