# --- Repository-viewer listing residue (preserved as comments so the file parses) ---
# enginex-biren-vllm/vllm_br/model_executor/layers/fused_moe/layer.py
# 414 lines, 17 KiB, Python — Raw / Normal View / History
# Retrieved 2026-03-10 13:31:25 +08:00
################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
from functools import wraps
from typing import Callable, Optional
import torch
import torch_br
from fastcore.basics import patch_to
from torch_br.utils.tensor_methods import Sbp
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod)
from vllm.model_executor.utils import set_weight_attrs
from vllm_br import envs
from ..br_utils import (_convert_to_crossed_numa_tensor,
_convert_to_numa_tensor, align_n, cross_weight_32)
from .supa_moe import (fused_moe_quant_device, fused_moe_quant_dyn,
fused_oss_moe_dyn)
@patch_to(UnquantizedFusedMoEMethod)
def forward_oot(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
):
    """Forward for UnquantizedFusedMoEMethod with SUPA out-of-tree support.

    Dispatches to one of three fused SUPA MoE kernels:

    * ``fused_oss_moe_dyn`` when ``activation == "swigluoai"`` (uses the
      layer's w13/w2 biases and treats ``router_logits`` as a plain tensor);
    * ``fused_moe_quant_dyn`` for prefill-sized batches
      (``x.shape[0] > VLLM_BR_STATIC_MOE_DECODER_MAX_LEN``);
    * ``fused_moe_quant_device`` for decode-sized batches.

    For the non-swigluoai paths ``router_logits`` is a 3-tuple of
    ``(gating_weight, shared_gate_up_weight, shared_down_weight)`` — see the
    patched ``FusedMoE.forward`` in this file, which passes the tuple in.

    Args accepted only for upstream interface compatibility and unused on
    these paths: ``global_num_experts``, ``expert_map``,
    ``routed_scaling_factor``, ``apply_router_weight_on_input``,
    ``enable_eplb``, ``expert_load_view``, ``logical_to_physical_map``,
    ``logical_replica_count``.

    Returns the fused-MoE output tensor produced by the selected kernel
    (``inplace=True``: the kernel may reuse ``x``'s storage).
    """
    if activation == "swigluoai":
        return fused_oss_moe_dyn(
            x,
            layer.w13_weight,
            layer.w13_bias,
            layer.w2_weight,
            layer.w2_bias,
            router_logits,
            top_k,
            layer.intermediate_size_per_partition,
            renormalize=renormalize,
            inplace=True,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            ep_rank=layer.ep_rank,
            ep_size=layer.ep_size)
    b_seq = x.shape[0]
    gating_weight, shared_gate_up_weight, shared_down_weight = router_logits
    # Prefill (long batch) uses the dynamic kernel; decode (short batch) uses
    # the static/device kernel. Both take identical arguments, so pick the
    # kernel once instead of duplicating the call site.
    if b_seq > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN:
        moe_kernel = fused_moe_quant_dyn  # prefill
    else:
        moe_kernel = fused_moe_quant_device  # decoder
    return moe_kernel(
        x,
        shared_gate_up_weight,
        shared_down_weight,
        layer.w13_weight,
        layer.w2_weight,
        None,
        None,
        gating_weight,
        top_k,
        layer.intermediate_size_per_partition,
        renormalize=renormalize,
        inplace=True,
        use_grouped_topk=use_grouped_topk,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        tp_rank=get_tp_group().rank_in_group,
        global_rank=get_tp_group().rank,
        tp_size=get_tensor_model_parallel_world_size(),
        ep_rank=layer.ep_rank,
        ep_size=layer.ep_size)
@patch_to(UnquantizedFusedMoEMethod)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Allocate CPU-resident expert weights (and optional biases) on *layer*.

    Creates ``w13_weight`` (fused gate_up_proj, column parallel) and
    ``w2_weight`` (down_proj, row parallel), plus ``w13_bias``/``w2_bias``
    when ``self.moe.has_bias`` is set. All tensors live on CPU in
    ``params_dtype``; they are repacked for the device later by
    ``process_weights_after_loading``.
    """

    def _register(name: str, data: torch.Tensor) -> None:
        # Wrap as a frozen Parameter, register it, and attach loader attrs.
        param = torch.nn.Parameter(data, requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)

    gate_up_out = 2 * intermediate_size_per_partition

    # Fused gate_up_proj (column parallel)
    _register(
        "w13_weight",
        torch.empty(num_experts, gate_up_out, hidden_size,
                    device="cpu", dtype=params_dtype))
    if self.moe.has_bias:
        _register(
            "w13_bias",
            torch.zeros(num_experts, gate_up_out,
                        device="cpu", dtype=params_dtype))

    # down_proj (row parallel)
    _register(
        "w2_weight",
        torch.empty(num_experts, hidden_size,
                    intermediate_size_per_partition,
                    device="cpu", dtype=params_dtype))
    if self.moe.has_bias:
        _register(
            "w2_bias",
            torch.zeros(num_experts, hidden_size,
                        device="cpu", dtype=params_dtype))
@patch_to(UnquantizedFusedMoEMethod)
def process_weights_after_loading(self: UnquantizedFusedMoEMethod,
                                  layer: FusedMoE) -> None:
    """Repack the loaded MoE weights into SUPA device layouts.

    After loading, ``w13_weight``/``w2_weight`` are plain dense CPU tensors;
    this rewrites them (and the biases, when present) expert by expert into
    the NUMA column-major layouts consumed by the fused SUPA MoE kernels, and
    installs the repacked tensors back onto ``layer`` via ``.data``.
    """
    cur_device = torch.supa.current_device()
    die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    # NOTE(review): >16 SPCs is treated as a dual-die part, splitting the SPC
    # count across two dies — confirm against the hardware documentation.
    die_num = 1 if die_spc_num <= 16 else 2
    spc_num = die_spc_num // die_num
    # swigluoai experts align columns to 32; all others to 64.
    align_size = 32 if layer.activation == "swigluoai" else 64
    is_dual_die = (die_spc_num > 16)
    # NOTE: w13_weight
    # after _load_w13, w13_weight is a colparallel weight, shape
    # [num_experts, 2 * intermediate_size_per_partition, hidden_size]
    # for SUPA, transform it to a NUMA colmajor weight, shape
    # [spc_num * num_experts, wk, wn_block] (wn = aligned(2 * intermediate_size_per_partition, align_size=64))
    wk = layer.hidden_size
    wn_block = align_n((layer.intermediate_size_per_partition * 2) // die_num,
                       align_size=align_size,
                       spc_num=spc_num)
    supa_w13_weight = torch_br._empty_ut_only(
        size=(die_spc_num * layer.local_num_experts, wk, wn_block),
        dtype=torch.bfloat16,
        is_numa=True,
        device=cur_device,
        tensor_type="colmajor",
        axis=0,
        sbp="SS" if is_dual_die else None)
    for expert_id in range(layer.local_num_experts):
        # [2*I, H] -> [H, 2*I]: the kernel consumes column-major weights.
        expert_w13 = layer.w13_weight[expert_id].transpose(0, 1).contiguous()
        # swigluoai activation: gate/up halves are NOT interleaved.
        if layer.activation and layer.activation == "swigluoai":
            pad_expert_w13 = _convert_to_numa_tensor(expert_w13, align_size,
                                                     'COLMAJOR',
                                                     expert_w13.dtype)
            pad_expert_w13_shape = pad_expert_w13.shape
            hw_size = pad_expert_w13_shape[-2] * pad_expert_w13_shape[-1]
            # View the per-expert slot inside the big NUMA buffer and fill it.
            narrow_data = supa_w13_weight.view_as_usharp(
                "COLMAJOR", pad_expert_w13_shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)
        else:
            # Split into gate (w1) and up (w3) halves, then interleave them
            # across SPCs to match the fused kernel's crossed layout.
            expert_1, expert_3 = expert_w13.chunk(2, dim=1)
            pad_expert_w13 = _convert_to_crossed_numa_tensor(expert_1,
                                                             expert_3,
                                                             die_spc_num,
                                                             dim=1,
                                                             need_pad=True,
                                                             layout='COLMAJOR')
            hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
            narrow_data = supa_w13_weight.view_as_usharp(
                "COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)
    layer.w13_weight.data = supa_w13_weight
    # NOTE: w13_bias — stored fp32, broadcast ("BB") across dies when dual-die.
    if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
        wn = layer.intermediate_size_per_partition * 2
        supa_w13_bias = torch_br._empty_ut_only(
            size=(layer.local_num_experts, wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="linear_bias",
            sbp="BB" if is_dual_die else None)
        for expert_id in range(layer.local_num_experts):
            expert_w13_bias = layer.w13_bias[expert_id]
            # swigluoai activation: no interleave needed for the bias either.
            if layer.activation and layer.activation == "swigluoai":
                narrow_data = supa_w13_bias[expert_id]
                narrow_data.copy_(expert_w13_bias)
            else:
                # Cross the gate/up bias halves to mirror the crossed w13
                # weight layout built above.
                expert_1_bias, expert_3_bias = expert_w13_bias.chunk(2, dim=-1)
                crossed_expert_w13_bias = cross_weight_32(
                    expert_1_bias,
                    expert_3_bias,
                    die_spc_num,
                    dim=0,
                    need_pad=False,
                )
                narrow_data = supa_w13_bias[expert_id]
                narrow_data.copy_(crossed_expert_w13_bias)
        layer.w13_bias.data = supa_w13_bias
    # NOTE: w2_weight
    # after _load_w2, w2_weight is a rowparallel weight, shape
    # [num_experts, hidden_size, intermediate_size_per_partition]
    # for SUPA, transform it to a NUMA colmajor weight, shape
    # [spc_num * num_experts, wk, wn_block]
    align_size = 32  # w2 always uses 32-column alignment
    wk = layer.intermediate_size_per_partition
    wn_block = align_n(layer.hidden_size,
                       align_size=align_size,
                       spc_num=spc_num)
    supa_w2_weight = torch_br._empty_ut_only(
        size=(die_spc_num * layer.local_num_experts, wk // die_num, wn_block),
        dtype=torch.bfloat16,
        is_numa=True,
        device=cur_device,
        tensor_type="colmajor",
        axis=0,
        sbp="SS" if is_dual_die else None)
    for expert_id in range(layer.local_num_experts):
        # [H, I] -> [I, H] column-major, padded for row-parallel layout.
        expert_w2 = layer.w2_weight[expert_id].transpose(0, 1).contiguous()
        pad_expert_w2 = _convert_to_numa_tensor(expert_w2,
                                                align_size,
                                                'COLMAJOR',
                                                expert_w2.dtype,
                                                parallel_type="row_parallel")
        pad_expert_w2_shape = pad_expert_w2.shape
        hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
        narrow_data = supa_w2_weight.view_as_usharp("COLMAJOR",
                                                    pad_expert_w2_shape,
                                                    Sbp.ss(0),
                                                    expert_id * hw_size)
        narrow_data.copy_(pad_expert_w2)
    layer.w2_weight.data = supa_w2_weight
    # NOTE: w2_bias — plain fp32 device tensor; no NUMA layout required.
    if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
        wn = layer.hidden_size
        supa_w2_bias = torch.zeros((layer.local_num_experts, wn),
                                   dtype=torch.float32,
                                   device=cur_device)
        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_bias[expert_id]
            narrow_data = supa_w2_bias[expert_id]
            narrow_data.copy_(expert_w2)
        layer.w2_bias.data = supa_w2_bias
@patch_to(FusedMoE)
def forward(self: FusedMoE, hidden_states: torch.Tensor,
            router_logits: torch.Tensor):
    """
    ! router_logits is a tuple of gate, shared_experts.gate_up_proj,
    shared_experts.down_proj weights.
    """
    assert self.quant_method is not None
    assert self.dp_size == 1, 'dp_size > 1 is not supported for now, please refer v0.11.0 moe codes'
    out = self.quant_method.apply(
        layer=self,
        x=hidden_states,
        router_logits=router_logits,
        top_k=self.top_k,
        renormalize=self.renormalize,
        use_grouped_topk=self.use_grouped_topk,
        global_num_experts=self.global_num_experts,
        expert_map=self.expert_map,
        topk_group=self.topk_group,
        num_expert_group=self.num_expert_group,
        custom_routing_function=self.custom_routing_function,
        scoring_func=self.scoring_func,
        e_score_correction_bias=self.e_score_correction_bias,
        activation=self.activation,
        apply_router_weight_on_input=self.apply_router_weight_on_input,
    )
    # NOTE: if using supa-moe-ccl kernel, add property `all_reduced` to the
    # output so downstream code skips the redundant tensor-parallel allreduce.
    # Only specific (SPC count, tp_size) pairs have the fused kernel.
    supported_configs = ((16, 4), (16, 8), (32, 2), (32, 4))
    tp_size = get_tensor_model_parallel_world_size()
    is_decode_batch = (hidden_states.shape[0]
                       <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN)
    if (is_decode_batch and envs.VLLM_BR_QUANT_METHOD != "INT4"
            and envs.VLLM_BR_USE_FUSED_ALLREDUCE
            and (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in supported_configs):
        out.all_reduced = True
    return out
@patch_to(FusedMoE)
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str,
              loaded_weight: torch.Tensor, tp_rank: int):
    """Copy this rank's TP shard of gate_proj/up_proj into the fused w13.

    gate_up_proj is "MergedColumnParallel": tensor-parallel sharding happens
    along the output dim (``shard_dim``). ``shard_id`` selects which logical
    half of w13 receives the data: "w1" (gate_proj) is the first half,
    "w3" (up_proj) the second.
    """
    # Each logical projection occupies half of w13 along shard_dim.
    half = expert_data.shape[shard_dim] // 2
    # Select this TP rank's slice of the full loaded weight.
    src = loaded_weight.narrow(shard_dim, half * tp_rank, half)
    if shard_id == "w1":
        dst = expert_data.narrow(shard_dim, 0, half)
    else:
        assert shard_id == "w3"
        dst = expert_data.narrow(shard_dim, half, half)
    # w13 lives on CPU until process_weights_after_loading repacks it.
    dst.copy_(src.cpu())
@patch_to(FusedMoE)
def _load_w2(self,
             expert_data: torch.Tensor,
             shard_dim: int,
             loaded_weight: torch.Tensor,
             tp_rank: int,
             load_full: bool = False):
    """Copy this rank's TP shard of down_proj into w2.

    down_proj is "RowParallel": tensor-parallel sharding happens along the
    input dim (``shard_dim``). With ``load_full=True`` the whole weight is
    copied without narrowing (e.g. when no TP sharding applies).
    """
    src = loaded_weight
    if not load_full:
        # Select this TP rank's slice along the sharded (input) dimension.
        shard = expert_data.shape[shard_dim]
        src = src.narrow(shard_dim, shard * tp_rank, shard)
    # w2 is the only logical weight — copy straight in (CPU until repack).
    expert_data.copy_(src.cpu())
def wrapper_FusedMoE_init(fn):
    """Decorate ``FusedMoE.__init__`` to force the routing-correction bias
    (``e_score_correction_bias``) to fp32 right after construction."""

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        bias = self.e_score_correction_bias
        if bias is not None:
            # Upcast in place; keeps the same Parameter/tensor object.
            bias.data = bias.float()

    return wrapper
# Patch FusedMoE.__init__ once at import time so every constructed layer gets
# an fp32 e_score_correction_bias (see wrapper_FusedMoE_init above).
FusedMoE.__init__ = wrapper_FusedMoE_init(FusedMoE.__init__)  # noqa: E501