414 lines
17 KiB
Python
414 lines
17 KiB
Python
|
|
################################################################################
|
||
|
|
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
|
# you may not use this file except in compliance with the License.
|
||
|
|
# You may obtain a copy of the License at
|
||
|
|
#
|
||
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
|
#
|
||
|
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
|
# See the License for the specific language governing permissions and
|
||
|
|
# limitations under the License.
|
||
|
|
#
|
||
|
|
################################################################################
|
||
|
|
|
||
|
|
from functools import wraps
|
||
|
|
from typing import Callable, Optional
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import torch_br
|
||
|
|
from fastcore.basics import patch_to
|
||
|
|
from torch_br.utils.tensor_methods import Sbp
|
||
|
|
|
||
|
|
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||
|
|
from vllm.model_executor.layers.fused_moe.layer import (
|
||
|
|
FusedMoE, UnquantizedFusedMoEMethod)
|
||
|
|
from vllm.model_executor.utils import set_weight_attrs
|
||
|
|
from vllm_br import envs
|
||
|
|
from ..br_utils import (_convert_to_crossed_numa_tensor,
|
||
|
|
_convert_to_numa_tensor, align_n, cross_weight_32)
|
||
|
|
from .supa_moe import (fused_moe_quant_device, fused_moe_quant_dyn,
|
||
|
|
fused_oss_moe_dyn)
|
||
|
|
|
||
|
|
|
||
|
|
@patch_to(UnquantizedFusedMoEMethod)
def forward_oot(
    self,
    layer: FusedMoE,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
):
    """Forward for UnquantizedFusedMoEMethod with SUPA out-of-tree support.

    Dispatches to one of three fused SUPA MoE kernels:

    * ``fused_oss_moe_dyn``     -- "swigluoai" activation (uses w13/w2 bias).
    * ``fused_moe_quant_dyn``   -- prefill path (batch larger than the
                                   static decoder cap).
    * ``fused_moe_quant_device``-- decoder path (small batch).

    NOTE(review): on the non-swigluoai path ``router_logits`` is not a
    logits tensor but a 3-tuple of ``(gating_weight,
    shared_gate_up_weight, shared_down_weight)`` — see the unpacking
    below; callers must pass it in that form.

    Several parameters (``global_num_experts``, ``expert_map``,
    ``routed_scaling_factor``, ``apply_router_weight_on_input``,
    ``enable_eplb``, ``expert_load_view``, ``logical_to_physical_map``,
    ``logical_replica_count``) are accepted only for interface
    compatibility with vLLM and are unused here.
    """
    if activation == "swigluoai":
        return fused_oss_moe_dyn(
            x,
            layer.w13_weight,
            layer.w13_bias,
            layer.w2_weight,
            layer.w2_bias,
            router_logits,
            top_k,
            layer.intermediate_size_per_partition,
            renormalize=renormalize,
            inplace=True,
            use_grouped_topk=use_grouped_topk,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            ep_rank=layer.ep_rank,
            ep_size=layer.ep_size)

    b_seq = x.shape[0]
    gating_weight, shared_gate_up_weight, shared_down_weight = router_logits

    # Both kernels take the exact same argument list; only the kernel
    # itself differs between the prefill (dyn) and decoder (device) paths,
    # so select it once instead of duplicating the call.
    if b_seq > envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN:
        moe_kernel = fused_moe_quant_dyn  # prefill
    else:
        moe_kernel = fused_moe_quant_device  # decoder

    return moe_kernel(
        x,
        shared_gate_up_weight,
        shared_down_weight,
        layer.w13_weight,
        layer.w2_weight,
        None,
        None,
        gating_weight,
        top_k,
        layer.intermediate_size_per_partition,
        renormalize=renormalize,
        inplace=True,
        use_grouped_topk=use_grouped_topk,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        tp_rank=get_tp_group().rank_in_group,
        global_rank=get_tp_group().rank,
        tp_size=get_tensor_model_parallel_world_size(),
        ep_rank=layer.ep_rank,
        ep_size=layer.ep_size)
|
||
|
|
|
||
|
|
|
||
|
|
@patch_to(UnquantizedFusedMoEMethod)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, **extra_weight_attrs):
    """Allocate the fused MoE expert weights (and optional biases).

    All parameters are staged on the CPU with ``requires_grad=False``;
    they are later converted to the SUPA device layout by
    ``process_weights_after_loading``.
    """

    def _register(name: str, factory, *shape) -> None:
        # Build a frozen CPU parameter, attach it to the layer under
        # `name`, and propagate the loader's weight attributes.
        param = torch.nn.Parameter(factory(*shape,
                                           device="cpu",
                                           dtype=params_dtype),
                                   requires_grad=False)
        layer.register_parameter(name, param)
        set_weight_attrs(param, extra_weight_attrs)

    # Fused gate_up_proj (column parallel): gate and up are merged along
    # the output dim, hence the factor of 2.
    _register("w13_weight", torch.empty, num_experts,
              2 * intermediate_size_per_partition, hidden_size)
    if self.moe.has_bias:
        _register("w13_bias", torch.zeros, num_experts,
                  2 * intermediate_size_per_partition)

    # down_proj (row parallel).
    _register("w2_weight", torch.empty, num_experts, hidden_size,
              intermediate_size_per_partition)
    if self.moe.has_bias:
        _register("w2_bias", torch.zeros, num_experts, hidden_size)
|
||
|
|
|
||
|
|
|
||
|
|
@patch_to(UnquantizedFusedMoEMethod)
def process_weights_after_loading(self: UnquantizedFusedMoEMethod,
                                  layer: FusedMoE) -> None:
    """Convert CPU-staged MoE weights into SUPA device-resident layouts.

    Rewrites ``layer.w13_weight``, ``layer.w2_weight`` and (when present)
    the matching biases in place: the weights become NUMA column-major
    tensors padded/sliced per SPC, the biases become float32 tensors.
    """
    cur_device = torch.supa.current_device()
    # die_spc_num: SPC count for the whole device; >16 implies a dual-die
    # part, in which case weights are split across the two dies ("SS" sbp).
    die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    die_num = 1 if die_spc_num <= 16 else 2
    spc_num = die_spc_num // die_num
    # swigluoai uses a 32-wide alignment for w13; otherwise 64.
    # NOTE(review): presumably kernel-imposed tiling constraints — confirm.
    align_size = 32 if layer.activation == "swigluoai" else 64
    is_dual_die = (die_spc_num > 16)

    # NOTE: w13_weight
    # after _load_w13, w13_weight is a colparallel weight, shape
    # [num_experts, 2 * intermediate_size_per_partition, hidden_size]
    # for SUPA, transform it to a NUMA colmajor weight, shape
    # [spc_num * num_experts, wk, wn_block] (wn = aligned(2 * intermediate_size_per_partition, align_size=64))
    wk = layer.hidden_size
    wn_block = align_n((layer.intermediate_size_per_partition * 2) // die_num,
                       align_size=align_size,
                       spc_num=spc_num)

    # Destination buffer; allocated empty and filled expert by expert below.
    supa_w13_weight = torch_br._empty_ut_only(
        size=(die_spc_num * layer.local_num_experts, wk, wn_block),
        dtype=torch.bfloat16,
        is_numa=True,
        device=cur_device,
        tensor_type="colmajor",
        axis=0,
        sbp="SS" if is_dual_die else None)

    for expert_id in range(layer.local_num_experts):
        # Transpose to [hidden_size, 2 * intermediate] before conversion.
        expert_w13 = layer.w13_weight[expert_id].transpose(0, 1).contiguous()
        # swigluoai activation, no need do interweave
        if layer.activation and layer.activation == "swigluoai":
            pad_expert_w13 = _convert_to_numa_tensor(expert_w13, align_size,
                                                     'COLMAJOR',
                                                     expert_w13.dtype)
            pad_expert_w13_shape = pad_expert_w13.shape
            # Flat per-expert element count of the padded 2-D tile; used as
            # the offset stride into the destination buffer.
            hw_size = pad_expert_w13_shape[-2] * pad_expert_w13_shape[-1]
            narrow_data = supa_w13_weight.view_as_usharp(
                "COLMAJOR", pad_expert_w13_shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)
        else:
            # Default (silu) path: interleave gate (w1) and up (w3) halves
            # across SPCs via the crossed-NUMA conversion.
            expert_1, expert_3 = expert_w13.chunk(2, dim=1)
            pad_expert_w13 = _convert_to_crossed_numa_tensor(expert_1,
                                                             expert_3,
                                                             die_spc_num,
                                                             dim=1,
                                                             need_pad=True,
                                                             layout='COLMAJOR')
            hw_size = pad_expert_w13.shape[-2] * pad_expert_w13.shape[-1]
            narrow_data = supa_w13_weight.view_as_usharp(
                "COLMAJOR", pad_expert_w13.shape, Sbp.ss(0),
                expert_id * hw_size)
            narrow_data.copy_(pad_expert_w13)

    # Swap the parameter's storage for the device-layout tensor.
    layer.w13_weight.data = supa_w13_weight

    # NOTE: w13_bias
    if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
        wn = layer.intermediate_size_per_partition * 2
        # Biases stay float32 and are broadcast ("BB") across dies.
        supa_w13_bias = torch_br._empty_ut_only(
            size=(layer.local_num_experts, wn),
            dtype=torch.float32,
            is_numa=False,
            device=cur_device,
            tensor_type="linear_bias",
            sbp="BB" if is_dual_die else None)
        for expert_id in range(layer.local_num_experts):
            expert_w13_bias = layer.w13_bias[expert_id]
            # swigluoai activation, no need do interweave
            if layer.activation and layer.activation == "swigluoai":
                narrow_data = supa_w13_bias[expert_id]
                narrow_data.copy_(expert_w13_bias)
            else:
                # Cross the gate/up bias halves to match the crossed
                # weight layout built above.
                expert_1_bias, expert_3_bias = expert_w13_bias.chunk(2, dim=-1)
                crossed_expert_w13_bias = cross_weight_32(
                    expert_1_bias,
                    expert_3_bias,
                    die_spc_num,
                    dim=0,
                    need_pad=False,
                )
                narrow_data = supa_w13_bias[expert_id]
                narrow_data.copy_(crossed_expert_w13_bias)
        layer.w13_bias.data = supa_w13_bias

    # NOTE: w2_weight
    # after _load_w2, w2_weight is a rowparallel weight, shape
    # [num_experts, hidden_size, intermediate_size_per_partition]
    # for SUPA, transform it to a NUMA colmajor weight, shape
    # [spc_num * num_experts, wk, wn_block]
    # w2 always uses 32-alignment regardless of activation.
    align_size = 32
    wk = layer.intermediate_size_per_partition
    wn_block = align_n(layer.hidden_size,
                       align_size=align_size,
                       spc_num=spc_num)

    supa_w2_weight = torch_br._empty_ut_only(
        size=(die_spc_num * layer.local_num_experts, wk // die_num, wn_block),
        dtype=torch.bfloat16,
        is_numa=True,
        device=cur_device,
        tensor_type="colmajor",
        axis=0,
        sbp="SS" if is_dual_die else None)

    for expert_id in range(layer.local_num_experts):
        # Transpose to [intermediate, hidden_size] before conversion.
        expert_w2 = layer.w2_weight[expert_id].transpose(0, 1).contiguous()
        pad_expert_w2 = _convert_to_numa_tensor(expert_w2,
                                                align_size,
                                                'COLMAJOR',
                                                expert_w2.dtype,
                                                parallel_type="row_parallel")
        pad_expert_w2_shape = pad_expert_w2.shape
        hw_size = pad_expert_w2_shape[-2] * pad_expert_w2_shape[-1]
        narrow_data = supa_w2_weight.view_as_usharp("COLMAJOR",
                                                    pad_expert_w2_shape,
                                                    Sbp.ss(0),
                                                    expert_id * hw_size)
        narrow_data.copy_(pad_expert_w2)

    layer.w2_weight.data = supa_w2_weight

    # NOTE: w2_bias
    if hasattr(layer, "w2_bias") and layer.w2_bias is not None:
        wn = layer.hidden_size
        # Plain device tensor (no special SUPA layout needed for w2 bias).
        supa_w2_bias = torch.zeros((layer.local_num_experts, wn),
                                   dtype=torch.float32,
                                   device=cur_device)
        for expert_id in range(layer.local_num_experts):
            expert_w2 = layer.w2_bias[expert_id]
            narrow_data = supa_w2_bias[expert_id]
            narrow_data.copy_(expert_w2)

        layer.w2_bias.data = supa_w2_bias
|
||
|
|
|
||
|
|
|
||
|
|
@patch_to(FusedMoE)
def forward(self: FusedMoE, hidden_states: torch.Tensor,
            router_logits: torch.Tensor):
    """
    ! router_logits is a tuple of gate, shared_experts.gate_up_proj,
    shared_experts.down_proj weights.
    """
    assert self.quant_method is not None
    assert self.dp_size == 1, 'dp_size > 1 is not supported for now, please refer v0.11.0 moe codes'

    output = self.quant_method.apply(
        layer=self,
        x=hidden_states,
        router_logits=router_logits,
        top_k=self.top_k,
        renormalize=self.renormalize,
        use_grouped_topk=self.use_grouped_topk,
        global_num_experts=self.global_num_experts,
        expert_map=self.expert_map,
        topk_group=self.topk_group,
        num_expert_group=self.num_expert_group,
        custom_routing_function=self.custom_routing_function,
        scoring_func=self.scoring_func,
        e_score_correction_bias=self.e_score_correction_bias,
        activation=self.activation,
        apply_router_weight_on_input=self.apply_router_weight_on_input,
    )

    # NOTE: if using supa-moe-ccl kernel, add property `all_reduced` to the
    # output so downstream code can skip the explicit tensor-parallel
    # all-reduce. Only decoder-sized batches on supported
    # (spc_num, tp_size) combinations qualify.
    supported_combos = ((16, 4), (16, 8), (32, 2), (32, 4))
    world = get_tensor_model_parallel_world_size()
    if (hidden_states.shape[0] <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
            and envs.VLLM_BR_QUANT_METHOD != "INT4"
            and envs.VLLM_BR_USE_FUSED_ALLREDUCE
            and (envs.VLLM_BR_DEVICE_SPC_NUM, world) in supported_combos):
        output.all_reduced = True

    return output
|
||
|
|
|
||
|
|
|
||
|
|
@patch_to(FusedMoE)
def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, shard_id: str,
              loaded_weight: torch.Tensor, tp_rank: int):
    """Copy this rank's TP shard of a gate/up projection into w13.

    gate_up_proj is "MergedColumnParallel": the merged w13 parameter holds
    w1 (gate_proj) in the first half of `shard_dim` and w3 (up_proj) in the
    second half, and tensor-parallel sharding happens along the output dim.
    """
    half = expert_data.shape[shard_dim] // 2
    # Select this rank's slice of the checkpoint weight.
    shard = loaded_weight.narrow(shard_dim, half * tp_rank, half)

    assert shard_id in ("w1", "w3")
    # w1 lands in the first logical half, w3 in the second.
    offset = 0 if shard_id == "w1" else half
    dest = expert_data.narrow(shard_dim, offset, half)
    # Parameters are CPU-staged at load time (see create_weights).
    dest.copy_(shard.cpu())
|
||
|
|
|
||
|
|
|
||
|
|
@patch_to(FusedMoE)
def _load_w2(self,
             expert_data: torch.Tensor,
             shard_dim: int,
             loaded_weight: torch.Tensor,
             tp_rank: int,
             load_full: bool = False):
    """Copy this rank's TP shard of a down projection into w2.

    down_proj is "RowParallel", so tensor-parallel sharding happens along
    the input dim; when `load_full` is set the checkpoint weight is copied
    unsharded.
    """
    size = expert_data.shape[shard_dim]
    source = (loaded_weight if load_full else
              loaded_weight.narrow(shard_dim, size * tp_rank, size))
    # w2 has a single logical weight; parameters are CPU-staged at load time.
    expert_data.copy_(source.cpu())
|
||
|
|
|
||
|
|
|
||
|
|
def wrapper_FusedMoE_init(fn):
    """Wrap a FusedMoE-style ``__init__`` so that, after construction,
    ``e_score_correction_bias`` (when set) is upcast to float32."""

    @wraps(fn)
    def _init_then_upcast_bias(self, *args, **kwargs):
        fn(self, *args, **kwargs)
        bias = self.e_score_correction_bias
        if bias is not None:
            # Rebind .data so the parameter object identity is preserved
            # while its storage becomes float32.
            bias.data = bias.float()

    return _init_then_upcast_bias
|
||
|
|
|
||
|
|
|
||
|
|
# Patch FusedMoE's constructor at import time: every FusedMoE instance gets
# its e_score_correction_bias upcast to float32 right after __init__ runs.
FusedMoE.__init__ = wrapper_FusedMoE_init(FusedMoE.__init__)  # noqa: E501
|