first commit
This commit is contained in:
23
vllm_br/model_executor/layers/fused_moe/__init__.py
Normal file
23
vllm_br/model_executor/layers/fused_moe/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from . import layer, supa_moe # noqa: E402
|
||||
from .layer import * # noqa: E402
|
||||
|
||||
__all__ = [
|
||||
"layer",
|
||||
"supa_moe",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
413
vllm_br/model_executor/layers/fused_moe/layer.py
Normal file
413
vllm_br/model_executor/layers/fused_moe/layer.py
Normal file
@@ -0,0 +1,413 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from fastcore.basics import patch_to
|
||||
from torch_br.utils.tensor_methods import Sbp
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm_br import envs
|
||||
from ..br_utils import (_convert_to_crossed_numa_tensor,
|
||||
_convert_to_numa_tensor, align_n, cross_weight_32)
|
||||
from .supa_moe import (fused_moe_quant_device, fused_moe_quant_dyn,
|
||||
fused_oss_moe_dyn)
|
||||
|
||||
|
||||
@patch_to(UnquantizedFusedMoEMethod)
|
||||
def forward_oot(
|
||||
self,
|
||||
layer: FusedMoE,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""Forward for UnquantizedFusedMoEMethod with SUPA out-of-tree support.
|
||||
"""
|
||||
if activation == "swigluoai":
|
||||
return fused_oss_moe_dyn(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w13_bias,
|
||||
layer.w2_weight,
|
||||
layer.w2_bias,
|
||||
router_logits,
|
||||
top_k,
|
||||
layer.intermediate_size_per_partition,
|
||||
renormalize=renormalize,
|
||||
inplace=True,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
ep_rank=layer.ep_rank,
|
||||
ep_size=layer.ep_size)
|
||||
|
||||
b_seq = x.shape[0]
|
||||
gating_weight, shared_gate_up_weight, shared_down_weight = router_logits
|
||||
if b_seq > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN:
|
||||
# prefill
|
||||
return fused_moe_quant_dyn(
|
||||
x,
|
||||
shared_gate_up_weight,
|
||||
shared_down_weight,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
None,
|
||||
None,
|
||||
gating_weight,
|
||||
top_k,
|
||||
layer.intermediate_size_per_partition,
|
||||
renormalize=renormalize,
|
||||
inplace=True,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
tp_rank=get_tp_group().rank_in_group,
|
||||
global_rank=get_tp_group().rank,
|
||||
tp_size=get_tensor_model_parallel_world_size(),
|
||||
ep_rank=layer.ep_rank,
|
||||
ep_size=layer.ep_size)
|
||||
else:
|
||||
# decoder
|
||||
return fused_moe_quant_device(
|
||||
x,
|
||||
shared_gate_up_weight,
|
||||
shared_down_weight,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
None,
|
||||
None,
|
||||
gating_weight,
|
||||
top_k,
|
||||
layer.intermediate_size_per_partition,
|
||||
renormalize=renormalize,
|
||||
inplace=True,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
tp_rank=get_tp_group().rank_in_group,
|
||||
global_rank=get_tp_group().rank,
|
||||
tp_size=get_tensor_model_parallel_world_size(),
|
||||
ep_rank=layer.ep_rank,
|
||||
ep_size=layer.ep_size)
|
||||
|
||||
|
||||
@patch_to(UnquantizedFusedMoEMethod)
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
device="cpu",
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
if self.moe.has_bias:
|
||||
w13_bias = torch.nn.Parameter(torch.zeros(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
device="cpu",
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_bias", w13_bias)
|
||||
set_weight_attrs(w13_bias, extra_weight_attrs)
|
||||
|
||||
# down_proj (row parallel)
|
||||
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
device="cpu",
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
if self.moe.has_bias:
|
||||
w2_bias = torch.nn.Parameter(torch.zeros(num_experts,
|
||||
hidden_size,
|
||||
device="cpu",
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_bias", w2_bias)
|
||||
set_weight_attrs(w2_bias, extra_weight_attrs)
|
||||
|
||||
|
||||
@patch_to(UnquantizedFusedMoEMethod)
|
||||
def process_weights_after_loading(self: UnquantizedFusedMoEMethod,
|
||||
layer: FusedMoE) -> None:
|
||||
cur_device = torch.supa.current_device()
|
||||
die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
|
||||
die_num = 1 if die_spc_num <= 16 else 2
|
||||
spc_num = die_spc_num // die_num
|
||||
align_size = 32 if layer.activation == "swigluoai" else 64
|
||||
is_dual_die = (die_spc_num > 16)
|
||||
|
||||
# NOTE: w13_weight
|
||||
# after _load_w13, w13_weight is a colparallel weight, shape
|
||||
# [num_experts, 2 * intermediate_size_per_partition, hidden_size]
|
||||
# for SUPA, transform it to a NUMA colmajor weight, shape
|
||||
# [spc_num * num_experts, wk, wn_block] (wn = aligned(2 * intermediate_size_per_partition, align_size=64))
|
||||
wk = layer.hidden_size
|
||||
wn_block = align_n((layer.intermediate_size_per_partition * 2) // die_num,
|
||||
align_size=align_size,
|
||||
spc_num=spc_num)
|
||||
|
||||
supa_w13_weight = torch_br._empty_ut_only(
|
||||
size=(die_spc_num * layer.local_num_experts, wk, wn_block),
|
||||
dtype=torch.bfloat16,
|
||||
is_numa=True,
|
||||
device=cur_device,
|
||||
tensor_type="colmajor",
|
||||
axis=0,
|
||||
sbp="SS" if is_dual_die else None)
|
||||
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
expert_w13 = layer.w13_weight[expert_id].transpose(0, 1).contiguous()
|
||||
# swigluoai activation, no need do interweave
|
||||
if layer.activation and layer.activation == "swigluoai":
|
||||
pad_expert_w13 = _convert_to_numa_tensor(expert_w13, align_size,
|
||||
'COLMAJOR',
|
||||
expert_w13.dtype)
|
||||
pad_expert_w13_shape = pad_expert_w13.shape
|
||||
hw_size = pad_expert_w13_shape[-2] * pad_expert_w13_shape[-1]
|
||||
narrow_data = supa_w13_weight.view_as_usharp(
|
||||
"COLMAJOR", pad_expert_w13_shape, Sbp.ss(0),
|
||||
expert_id * hw_size)
|
||||
narrow_data.copy_(pad_expert_w13)
|
||||
else:
|
||||
expert_1, expert_3 = expert_w13.chunk(2, dim=1)
|
||||
pad_expert_w13 = _convert_to_crossed_numa_tensor(expert_1,
|
||||
expert_3,
|
||||
die_spc_num,
|
||||
dim=1,
|
||||
need_pad=True,
|
||||
layout='COLMAJOR')
|
||||
hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
|
||||
narrow_data = supa_w13_weight.view_as_usharp(
|
||||
"COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
|
||||
expert_id * hw_size)
|
||||
narrow_data.copy_(pad_expert_w13)
|
||||
|
||||
layer.w13_weight.data = supa_w13_weight
|
||||
|
||||
# NOTE: w13_bias
|
||||
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
|
||||
wn = layer.intermediate_size_per_partition * 2
|
||||
supa_w13_bias = torch_br._empty_ut_only(
|
||||
size=(layer.local_num_experts, wn),
|
||||
dtype=torch.float32,
|
||||
is_numa=False,
|
||||
device=cur_device,
|
||||
tensor_type="linear_bias",
|
||||
sbp="BB" if is_dual_die else None)
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
expert_w13_bias = layer.w13_bias[expert_id]
|
||||
# swigluoai activation, no need do interweave
|
||||
if layer.activation and layer.activation == "swigluoai":
|
||||
narrow_data = supa_w13_bias[expert_id]
|
||||
narrow_data.copy_(expert_w13_bias)
|
||||
else:
|
||||
expert_1_bias, expert_3_bias = expert_w13_bias.chunk(2, dim=-1)
|
||||
crossed_expert_w13_bias = cross_weight_32(
|
||||
expert_1_bias,
|
||||
expert_3_bias,
|
||||
die_spc_num,
|
||||
dim=0,
|
||||
need_pad=False,
|
||||
)
|
||||
narrow_data = supa_w13_bias[expert_id]
|
||||
narrow_data.copy_(crossed_expert_w13_bias)
|
||||
layer.w13_bias.data = supa_w13_bias
|
||||
|
||||
# NOTE: w2_weight
|
||||
# after _load_w2, w2_weight is a rowparallel weight, shape
|
||||
# [num_experts, hidden_size, intermediate_size_per_partition]
|
||||
# for SUPA, transform it to a NUMA colmajor weight, shape
|
||||
# [spc_num * num_experts, wk, wn_block]
|
||||
align_size = 32
|
||||
wk = layer.intermediate_size_per_partition
|
||||
wn_block = align_n(layer.hidden_size,
|
||||
align_size=align_size,
|
||||
spc_num=spc_num)
|
||||
|
||||
supa_w2_weight = torch_br._empty_ut_only(
|
||||
size=(die_spc_num * layer.local_num_experts, wk // die_num, wn_block),
|
||||
dtype=torch.bfloat16,
|
||||
is_numa=True,
|
||||
device=cur_device,
|
||||
tensor_type="colmajor",
|
||||
axis=0,
|
||||
sbp="SS" if is_dual_die else None)
|
||||
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
expert_w2 = layer.w2_weight[expert_id].transpose(0, 1).contiguous()
|
||||
pad_expert_w2 = _convert_to_numa_tensor(expert_w2,
|
||||
align_size,
|
||||
'COLMAJOR',
|
||||
expert_w2.dtype,
|
||||
parallel_type="row_parallel")
|
||||
pad_expert_w2_shape = pad_expert_w2.shape
|
||||
hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
|
||||
narrow_data = supa_w2_weight.view_as_usharp("COLMAJOR",
|
||||
pad_expert_w2_shape,
|
||||
Sbp.ss(0),
|
||||
expert_id * hw_size)
|
||||
narrow_data.copy_(pad_expert_w2)
|
||||
|
||||
layer.w2_weight.data = supa_w2_weight
|
||||
|
||||
# NOTE: w2_bias
|
||||
if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
|
||||
wn = layer.hidden_size
|
||||
supa_w2_bias = torch.zeros((layer.local_num_experts, wn),
|
||||
dtype=torch.float32,
|
||||
device=cur_device)
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
expert_w2 = layer.w2_bias[expert_id]
|
||||
narrow_data = supa_w2_bias[expert_id]
|
||||
narrow_data.copy_(expert_w2)
|
||||
|
||||
layer.w2_bias.data = supa_w2_bias
|
||||
|
||||
|
||||
@patch_to(FusedMoE)
|
||||
def forward(self: FusedMoE, hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor):
|
||||
"""
|
||||
! router_logits is a tuple of gate, shared_experts.gate_up_proj,
|
||||
shared_experts.down_proj weights.
|
||||
"""
|
||||
assert self.quant_method is not None
|
||||
assert self.dp_size == 1, 'dp_size > 1 is not supported for now, please refer v0.11.0 moe codes'
|
||||
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
layer=self,
|
||||
x=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=self.top_k,
|
||||
renormalize=self.renormalize,
|
||||
use_grouped_topk=self.use_grouped_topk,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.expert_map,
|
||||
topk_group=self.topk_group,
|
||||
num_expert_group=self.num_expert_group,
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
# NOTE: if using supa-moe-ccl kernel, add property `all_reduced` to the final_hidden_states
|
||||
support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if hidden_states.shape[
|
||||
0] <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and envs.VLLM_BR_QUANT_METHOD != "INT4" and envs.VLLM_BR_USE_FUSED_ALLREDUCE and (
|
||||
envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types:
|
||||
final_hidden_states.all_reduced = True
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
@patch_to(FusedMoE)
|
||||
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str,
|
||||
loaded_weight: torch.Tensor, tp_rank: int):
|
||||
|
||||
# Index the loaded weight for tp sharding.
|
||||
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
|
||||
shard_size = expert_data.shape[shard_dim] // 2
|
||||
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
|
||||
shard_size)
|
||||
# Narrow parameter and load.
|
||||
# w1, gate_proj: Load into first logical weight of w13.
|
||||
if shard_id == "w1":
|
||||
expert_data = expert_data.narrow(shard_dim, 0, shard_size)
|
||||
# w3, up_proj: Load into second logical weight of w13.
|
||||
else:
|
||||
assert shard_id == "w3"
|
||||
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
|
||||
expert_data.copy_(loaded_weight.cpu())
|
||||
|
||||
|
||||
@patch_to(FusedMoE)
|
||||
def _load_w2(self,
|
||||
expert_data: torch.Tensor,
|
||||
shard_dim: int,
|
||||
loaded_weight: torch.Tensor,
|
||||
tp_rank: int,
|
||||
load_full: bool = False):
|
||||
|
||||
# Index the loaded weight for tp sharding.
|
||||
# down_proj: "RowParallel" so tp sharding on input_dim
|
||||
# Narrow parameter and load.
|
||||
shard_size = expert_data.shape[shard_dim]
|
||||
if not load_full:
|
||||
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
|
||||
shard_size)
|
||||
# w2, down_proj: Load into only logical weight of w2.
|
||||
expert_data.copy_(loaded_weight.cpu())
|
||||
|
||||
|
||||
def wrapper_FusedMoE_init(fn):
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
fn(self, *args, **kwargs)
|
||||
|
||||
if self.e_score_correction_bias is not None:
|
||||
self.e_score_correction_bias.data = self.e_score_correction_bias.float(
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
FusedMoE.__init__ = wrapper_FusedMoE_init(FusedMoE.__init__) # noqa: E501
|
||||
518
vllm_br/model_executor/layers/fused_moe/supa_moe.py
Normal file
518
vllm_br/model_executor/layers/fused_moe/supa_moe.py
Normal file
@@ -0,0 +1,518 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_br
|
||||
from torch_br.utils.tensor_methods import Sbp
|
||||
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
# gpt-oss moe forward version
|
||||
def fused_oss_moe_dyn(
|
||||
hidden_states: torch.Tensor,
|
||||
w13: torch.Tensor,
|
||||
w13_bias: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_bias: torch.Tensor,
|
||||
gating_weight: torch.Tensor,
|
||||
topk: int,
|
||||
intermediate_size: int,
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
ep_rank: Optional[int] = None,
|
||||
ep_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
total_expert_num = gating_weight.shape[-2]
|
||||
probs_supa, indices_supa, prob_per_expert, indice_per_expert = torch_br.supa_moe_router_v2_infer(
|
||||
hidden_states,
|
||||
gating_weight,
|
||||
topk,
|
||||
ep_size,
|
||||
ep_rank,
|
||||
gating_bias=e_score_correction_bias)
|
||||
|
||||
cur_device = hidden_states.device
|
||||
probs_supa = probs_supa.cpu().permute(1, 0).contiguous().to(cur_device)
|
||||
indices_supa = indices_supa.cpu().permute(1, 0).contiguous().to(cur_device)
|
||||
indice_per_expert = indice_per_expert.cpu().permute(
|
||||
1, 0).contiguous().to(cur_device)
|
||||
prob_per_expert = prob_per_expert.cpu().permute(
|
||||
1, 0).contiguous().to(cur_device)
|
||||
|
||||
is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
|
||||
local_expert_num = total_expert_num // ep_size # type: ignore
|
||||
b_seq = hidden_states.shape[0]
|
||||
indices_trans_supa = torch_br._empty_ut_only(
|
||||
size=(local_expert_num, (b_seq + 64 - 1) // 64 * 64),
|
||||
dtype=torch.int32,
|
||||
is_numa=False,
|
||||
device=hidden_states.device,
|
||||
tensor_type="colmajor",
|
||||
sbp="BB" if is_dual_die else None)
|
||||
tokens_per_expert_supa = torch.bincount(indices_supa.reshape(-1),
|
||||
minlength=total_expert_num)
|
||||
|
||||
tokens_per_expert_list = tokens_per_expert_supa.cpu().data.numpy().tolist(
|
||||
)[ep_rank * local_expert_num:(ep_rank + 1) * # type: ignore
|
||||
local_expert_num]
|
||||
topk_per_expert = sum(1 for x in tokens_per_expert_list if x != 0)
|
||||
|
||||
if topk_per_expert > 0:
|
||||
expert_tokens = torch_br.supa_permutation_infer(
|
||||
global_hidden_states=hidden_states,
|
||||
indices=indice_per_expert,
|
||||
tokens_per_expert=tokens_per_expert_list,
|
||||
indices_trans=indices_trans_supa)
|
||||
|
||||
assert len(
|
||||
expert_tokens) == local_expert_num, "Number of experts mismatch"
|
||||
|
||||
gate_up_outputs = []
|
||||
down_outputs = []
|
||||
cur_device = expert_tokens[0].device
|
||||
hidden_size = expert_tokens[0].shape[-1]
|
||||
for i in range(local_expert_num):
|
||||
if tokens_per_expert_list[i] == 0:
|
||||
gate_up_outputs.append(
|
||||
torch.empty(size=(0, intermediate_size),
|
||||
dtype=torch.bfloat16,
|
||||
device=cur_device))
|
||||
down_outputs.append(
|
||||
torch.empty(size=(0, hidden_size),
|
||||
dtype=torch.float32,
|
||||
device=cur_device))
|
||||
continue
|
||||
|
||||
gate_up_output = torch_br._empty_ut_only(
|
||||
size=(tokens_per_expert_list[i], intermediate_size),
|
||||
dtype=torch.bfloat16,
|
||||
is_numa=False,
|
||||
device=cur_device,
|
||||
tensor_type="colmajor",
|
||||
sbp="SB" if is_dual_die else None,
|
||||
axis=1)
|
||||
gate_up_outputs.append(gate_up_output)
|
||||
|
||||
down_output = torch_br._empty_ut_only(
|
||||
size=(tokens_per_expert_list[i], hidden_size),
|
||||
dtype=torch.float32,
|
||||
is_numa=False,
|
||||
device=cur_device,
|
||||
tensor_type="colmajor",
|
||||
sbp="BB" if is_dual_die else None)
|
||||
down_outputs.append(down_output)
|
||||
|
||||
torch_br.supa_moe_fused_ffn_dyn_infer(gate_up_outputs,
|
||||
expert_tokens,
|
||||
w13,
|
||||
tokens_per_expert_list,
|
||||
max(tokens_per_expert_list),
|
||||
bias=w13_bias,
|
||||
act_mode="act_swiglu_oai")
|
||||
|
||||
torch_br.supa_moe_fused_ffn_dyn_infer(down_outputs,
|
||||
gate_up_outputs,
|
||||
w2,
|
||||
tokens_per_expert_list,
|
||||
max(tokens_per_expert_list),
|
||||
bias=w2_bias,
|
||||
act_mode="act_default")
|
||||
|
||||
output = torch_br.supa_unpermutation_infer(
|
||||
input_list=down_outputs,
|
||||
indices=indices_trans_supa,
|
||||
probs=prob_per_expert,
|
||||
tokens_per_expert=tokens_per_expert_list)
|
||||
else:
|
||||
output = torch.zeros_like(hidden_states)
|
||||
|
||||
return output.unsqueeze(0)
|
||||
|
||||
|
||||
def fused_moe_quant_dyn(
|
||||
hidden_states: torch.Tensor,
|
||||
shared_gate_up_weight: torch.Tensor,
|
||||
down_weight: torch.Tensor,
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w13_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
gating_weight: torch.Tensor,
|
||||
topk: int,
|
||||
intermediate_size: int,
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
tp_rank: Optional[int] = None,
|
||||
global_rank: Optional[int] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
ep_rank: Optional[int] = None,
|
||||
ep_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
weights, w1 and w2, and top-k gating mechanism.
|
||||
|
||||
Parameters:
|
||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- gating_output (torch.Tensor): The output of the gating operation
|
||||
(before softmax).
|
||||
- topk (int): The number of top-k experts to select.
|
||||
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
|
||||
- inplace (bool): If True, perform the operation in-place.
|
||||
Defaults to False.
|
||||
- override_config (Optional[Dict[str, Any]]): Optional override
|
||||
for the kernel configuration.
|
||||
- num_expert_group: Optional[int]: additional parameter for grouped_topk
|
||||
- topk_group: Optional[int]: additional parameter for grouped_topk
|
||||
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
|
||||
note: Deepseekv2 model uses grouped_topk
|
||||
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
w1.
|
||||
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
w2.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
|
||||
total_expert_num = gating_weight.shape[-1]
|
||||
cur_device = hidden_states.device
|
||||
if use_grouped_topk:
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
shared_output, _, indices_supa, indice_per_expert, prob_per_expert = torch_br.supa_fused_shared_router_prefill_v2_infer(
|
||||
hidden_states,
|
||||
shared_gate_up_weight,
|
||||
down_weight,
|
||||
gating_weight,
|
||||
intermediate_size,
|
||||
topk,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
ep_size,
|
||||
ep_rank,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
if is_dual_die:
|
||||
shared_tmp = torch_br._empty_ut_only(size=shared_output.shape,
|
||||
dtype=shared_output.dtype,
|
||||
is_numa=False,
|
||||
device=shared_output.device,
|
||||
tensor_type="colmajor")
|
||||
shared_tmp.copy_(shared_output)
|
||||
shared_output = shared_tmp
|
||||
else:
|
||||
assert topk_group is None, "Only support non group topk router"
|
||||
assert shared_gate_up_weight is None and down_weight is None
|
||||
_, indices_supa, prob_per_expert, indice_per_expert = torch_br.supa_moe_router_v2_infer(
|
||||
hidden_states,
|
||||
gating_weight.permute(1, 0).contiguous(),
|
||||
topk,
|
||||
ep_size,
|
||||
ep_rank,
|
||||
gating_bias=e_score_correction_bias)
|
||||
shared_output = None
|
||||
|
||||
indices_supa = indices_supa.permute(1, 0).contiguous()
|
||||
indice_per_expert = indice_per_expert.permute(1, 0).contiguous()
|
||||
prob_per_expert = prob_per_expert.permute(1, 0).contiguous()
|
||||
|
||||
local_expert_num = total_expert_num // ep_size # type: ignore
|
||||
b_seq = hidden_states.shape[0]
|
||||
indices_trans_supa = torch_br._empty_ut_only(
|
||||
size=(local_expert_num, (b_seq + 64 - 1) // 64 * 64),
|
||||
dtype=torch.int32,
|
||||
is_numa=False,
|
||||
device=hidden_states.device,
|
||||
tensor_type="colmajor",
|
||||
sbp="BB" if is_dual_die else None)
|
||||
tokens_per_expert_supa = torch.bincount(indices_supa.reshape(-1),
|
||||
minlength=total_expert_num)
|
||||
|
||||
tokens_per_expert_list = tokens_per_expert_supa.cpu().data.numpy().tolist(
|
||||
)[ep_rank * local_expert_num:(ep_rank + 1) * # type: ignore
|
||||
local_expert_num]
|
||||
topk_per_expert = sum(1 for x in tokens_per_expert_list if x != 0)
|
||||
|
||||
if topk_per_expert > 0:
|
||||
expert_tokens = torch_br.supa_permutation_infer(
|
||||
global_hidden_states=hidden_states,
|
||||
indices=indice_per_expert,
|
||||
tokens_per_expert=tokens_per_expert_list,
|
||||
indices_trans=indices_trans_supa)
|
||||
|
||||
assert len(
|
||||
expert_tokens) == local_expert_num, "Number of experts mismatch"
|
||||
|
||||
supa_device = torch.supa.current_device()
|
||||
spc_num = torch_br.supa.get_device_properties(
|
||||
supa_device).max_compute_units
|
||||
|
||||
out_expert_tokens = []
|
||||
use_moe_fused_ffn_dyn = True
|
||||
if not use_moe_fused_ffn_dyn or total_expert_num == 128:
|
||||
w13_hw = w13.shape[-2] * w13.shape[-1]
|
||||
w2_hw = w2.shape[-2] * w2.shape[-1]
|
||||
|
||||
for i in range(local_expert_num):
|
||||
expert_token = expert_tokens[i]
|
||||
if tokens_per_expert_list[i] == 0:
|
||||
out_expert_tokens.append(expert_token)
|
||||
continue
|
||||
|
||||
expert_gate_up_weight = w13.view_as_usharp(
|
||||
"COLMAJOR", (spc_num, w13.shape[-2], w13.shape[-1]),
|
||||
Sbp.ss(0), i * w13_hw)
|
||||
|
||||
down_weight = w2.view_as_usharp(
|
||||
"COLMAJOR", (spc_num, w2.shape[-2], w2.shape[-1]),
|
||||
Sbp.ss(0), i * w2_hw)
|
||||
|
||||
expert_gate_up_scale = w13_scale[
|
||||
i] if w13_scale is not None else None
|
||||
down_scale = w2_scale[i] if w2_scale is not None else None
|
||||
|
||||
gate_up_output = torch_br._empty_ut_only(
|
||||
size=(expert_token.shape[0], intermediate_size),
|
||||
dtype=torch.bfloat16,
|
||||
is_numa=False,
|
||||
device=expert_token.device,
|
||||
tensor_type="colmajor",
|
||||
sbp="SB" if is_dual_die else None,
|
||||
axis=1)
|
||||
torch_br.supa_fused_linear_infer(gate_up_output,
|
||||
expert_token,
|
||||
expert_gate_up_weight,
|
||||
expert_gate_up_scale,
|
||||
act_mode="act_swiglu")
|
||||
|
||||
down_output = torch_br._empty_ut_only(
|
||||
size=expert_token.shape,
|
||||
dtype=torch.float32,
|
||||
is_numa=False,
|
||||
device=gate_up_output.device,
|
||||
tensor_type="colmajor",
|
||||
sbp="BB" if is_dual_die else None)
|
||||
|
||||
torch_br.supa_fused_linear_infer(down_output, gate_up_output,
|
||||
down_weight, down_scale)
|
||||
|
||||
out_expert_tokens.append(down_output)
|
||||
else:
|
||||
gate_up_outputs = []
|
||||
cur_device = expert_tokens[0].device
|
||||
hidden_size = expert_tokens[0].shape[-1]
|
||||
for i in range(local_expert_num):
|
||||
if tokens_per_expert_list[i] == 0:
|
||||
gate_up_outputs.append(
|
||||
torch.empty(size=(0, intermediate_size),
|
||||
dtype=torch.bfloat16,
|
||||
device=cur_device))
|
||||
out_expert_tokens.append(
|
||||
torch.empty(size=(0, hidden_size),
|
||||
dtype=torch.float32,
|
||||
device=cur_device))
|
||||
continue
|
||||
|
||||
gate_up_output = torch_br._empty_ut_only(
|
||||
size=(tokens_per_expert_list[i], intermediate_size),
|
||||
dtype=torch.bfloat16,
|
||||
is_numa=False,
|
||||
device=cur_device,
|
||||
tensor_type="colmajor",
|
||||
sbp="SB" if is_dual_die else None,
|
||||
axis=1)
|
||||
gate_up_outputs.append(gate_up_output)
|
||||
|
||||
down_output = torch_br._empty_ut_only(
|
||||
size=(tokens_per_expert_list[i], hidden_size),
|
||||
dtype=torch.float32,
|
||||
is_numa=False,
|
||||
device=cur_device,
|
||||
tensor_type="colmajor",
|
||||
sbp="BB" if is_dual_die else None)
|
||||
out_expert_tokens.append(down_output)
|
||||
torch_br.supa_moe_fused_ffn_dyn_infer(gate_up_outputs,
|
||||
expert_tokens,
|
||||
w13,
|
||||
tokens_per_expert_list,
|
||||
max(tokens_per_expert_list),
|
||||
scales=w13_scale,
|
||||
act_mode="act_swiglu")
|
||||
torch_br.supa_moe_fused_ffn_dyn_infer(out_expert_tokens,
|
||||
gate_up_outputs,
|
||||
w2,
|
||||
tokens_per_expert_list,
|
||||
max(tokens_per_expert_list),
|
||||
scales=w2_scale,
|
||||
act_mode="act_default")
|
||||
|
||||
out_states = torch_br.supa_unpermutation_infer(
|
||||
input_list=out_expert_tokens,
|
||||
indices=indices_trans_supa,
|
||||
probs=prob_per_expert,
|
||||
tokens_per_expert=tokens_per_expert_list)
|
||||
|
||||
output = out_states if shared_output is None else out_states + shared_output
|
||||
else:
|
||||
output = torch.zeros_like(
|
||||
hidden_states) if shared_output is None else shared_output
|
||||
|
||||
return output.unsqueeze(0)
|
||||
|
||||
|
||||
def fused_moe_quant_device(
|
||||
hidden_states: torch.Tensor,
|
||||
shared_gate_up_weight: torch.Tensor,
|
||||
down_weight: torch.Tensor,
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w13_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
gating_weight: torch.Tensor,
|
||||
topk: int,
|
||||
intermediate_size: int,
|
||||
renormalize: bool,
|
||||
inplace: bool = False,
|
||||
use_grouped_topk: bool = False,
|
||||
num_expert_group: Optional[int] = None,
|
||||
topk_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
tp_rank: Optional[int] = None,
|
||||
global_rank: Optional[int] = None,
|
||||
tp_size: Optional[int] = None,
|
||||
ep_rank: Optional[int] = None,
|
||||
ep_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
weights, w1 and w2, and top-k gating mechanism.
|
||||
|
||||
Parameters:
|
||||
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
|
||||
- w1 (torch.Tensor): The first set of expert weights.
|
||||
- w2 (torch.Tensor): The second set of expert weights.
|
||||
- gating_output (torch.Tensor): The output of the gating operation
|
||||
(before softmax).
|
||||
- topk (int): The number of top-k experts to select.
|
||||
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
|
||||
- inplace (bool): If True, perform the operation in-place.
|
||||
Defaults to False.
|
||||
- override_config (Optional[Dict[str, Any]]): Optional override
|
||||
for the kernel configuration.
|
||||
- num_expert_group: Optional[int]: additional parameter for grouped_topk
|
||||
- topk_group: Optional[int]: additional parameter for grouped_topk
|
||||
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
|
||||
note: Deepseekv2 model uses grouped_topk
|
||||
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
|
||||
products for w1 and w2. Defaults to False.
|
||||
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
w1.
|
||||
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
|
||||
w2.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: The output tensor after applying the MoE layer.
|
||||
"""
|
||||
is_dual_die = (envs.VLLM_BR_DEVICE_SPC_NUM > 16)
|
||||
expert_num = gating_weight.shape[-1]
|
||||
b_seq = hidden_states.shape[-2]
|
||||
if topk_group is None:
|
||||
assert shared_gate_up_weight is None and down_weight is None
|
||||
shared_output, masked_probs, hitted_experts = torch_br.supa_moe_router_decoder_infer(
|
||||
hidden_states, gating_weight, topk, ep_size, ep_rank)
|
||||
else:
|
||||
assert use_grouped_topk is True, "Only support group topk router"
|
||||
assert num_expert_group is not None and topk_group is not None
|
||||
if ep_size > 1: # type: ignore
|
||||
shared_output, masked_probs, hitted_experts = torch_br.supa_fused_shared_router_v2_infer(
|
||||
hidden_states,
|
||||
shared_gate_up_weight,
|
||||
down_weight,
|
||||
gating_weight,
|
||||
intermediate_size,
|
||||
topk,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
ep_size,
|
||||
ep_rank,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias
|
||||
if e_score_correction_bias is not None else torch.empty(
|
||||
(expert_num),
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device))
|
||||
else:
|
||||
shared_output, masked_probs, hitted_experts = torch_br.supa_fused_shared_router_infer(
|
||||
hidden_states,
|
||||
shared_gate_up_weight,
|
||||
down_weight,
|
||||
gating_weight,
|
||||
intermediate_size,
|
||||
topk,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias
|
||||
if e_score_correction_bias is not None else torch.empty(
|
||||
(expert_num),
|
||||
dtype=torch.float32,
|
||||
device=hidden_states.device))
|
||||
if is_dual_die:
|
||||
shared_output = shared_output.view_as_usharp(
|
||||
"COLMAJOR", shared_output.shape, Sbp.bb())
|
||||
|
||||
if w13.dtype == torch.int32:
|
||||
torch_br.supa_moe_fused_ffn_s4_infer(shared_output, hidden_states, w13,
|
||||
w2, hitted_experts, masked_probs,
|
||||
w13_scale, w2_scale)
|
||||
else:
|
||||
support_types = ((16, 4), (16, 8), (32, 2), (32, 4))
|
||||
if envs.VLLM_BR_USE_FUSED_ALLREDUCE and b_seq <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN and (
|
||||
envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types:
|
||||
# ffn+allreduce only support tp 4|8 and 16spc
|
||||
torch_br.supa_moe_fused_ffn_allreduce(shared_output, hidden_states,
|
||||
w13, w2, hitted_experts,
|
||||
masked_probs, tp_rank,
|
||||
tp_size, global_rank, 0,
|
||||
w13_scale, w2_scale)
|
||||
else:
|
||||
torch_br.supa_moe_fused_ffn_infer(shared_output, hidden_states,
|
||||
w13, w2, hitted_experts,
|
||||
masked_probs, w13_scale,
|
||||
w2_scale)
|
||||
|
||||
return shared_output.unsqueeze(0)
|
||||
Reference in New Issue
Block a user