768 lines
32 KiB
Python
768 lines
32 KiB
Python
################################################################################
|
|
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
################################################################################
|
|
|
|
from typing import Literal, Optional, Union
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
import torch_br
|
|
import torch_br.supa._debug as supa_debug
|
|
from fastcore.basics import patch_to
|
|
from torch.nn.parameter import Parameter
|
|
|
|
from vllm.distributed import (get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
get_tp_group, split_tensor_along_last_dim,
|
|
tensor_model_parallel_all_reduce)
|
|
from vllm.distributed.parallel_state import get_pp_group
|
|
from vllm.logger import logger
|
|
from vllm.model_executor.layers.linear import (adjust_bitsandbytes_4bit_shard,
|
|
adjust_marlin_shard,
|
|
adjust_scalar_to_fused_array)
|
|
from vllm_br import envs
|
|
from vllm_br.utils import get_grandparent_pid
|
|
from .br_utils import (_convert_to_crossed_numa_tensor,
|
|
_convert_to_numa_tensor, _convert_to_numa_tensor_vit,
|
|
is_br166_device)
|
|
|
|
from vllm.model_executor.layers.linear import ( # isort:skip
|
|
LinearBase, MergedColumnParallelLinear, QuantizationConfig,
|
|
ReplicatedLinear, RowParallelLinear, UnquantizedLinearMethod,
|
|
QKVParallelLinear)
|
|
|
|
|
|
def _should_skip_linear_post_process(layer, use_ds_mla, use_ds_mla_sparse):
    """Tell whether the generic linear post-processing must skip ``layer``.

    NOTE: SUPA: for MLA linears, we do process in
    MLA.process_weights_after_loading.

    Args:
        layer: Object exposing a string ``prefix`` attribute.
        use_ds_mla: When truthy (and the sparse flag is not set), ``o_proj``
            is added to the list of skipped projections.
        use_ds_mla_sparse: When truthy, only ``kv_b_proj`` is skipped.

    Returns:
        ``True`` if any skipped projection name occurs as a substring of
        ``layer.prefix`` (a debug line is logged in that case), else ``False``.
    """
    # TODO: Hard code for native dsa op
    if use_ds_mla_sparse:
        mla_linear_names = ("kv_b_proj", )
    else:
        # "q_proj" is intentionally not part of the list.
        mla_linear_names = (
            "q_a_proj",
            "q_b_proj",
            "kv_a_proj_with_mqa",
            "kv_b_proj",
        )
        if use_ds_mla:
            mla_linear_names += ("o_proj", )

    for name in mla_linear_names:
        if name in layer.prefix:
            logger.debug(
                f'[SUPA] skip {layer.prefix} UnquantizedLinearMethod.process_weights_after_loading'  # noqa: G004
            )
            return True
    return False
|
|
|
|
|
|
# NOTE: ReplicatedLinear, usually used in MoE as a gate module.
|
|
# In DeepseekV3, it needs to be transposed.
|
|
def process_weights_ReplicatedLinear(
        layer: ReplicatedLinear) -> Literal[True, False]:
    """Replace the layer weight with its contiguous transpose.

    Dimensions 0 and 1 of ``layer.weight.data`` are swapped and the result is
    stored back into ``layer.weight.data``.

    Returns:
        Always ``True``.
    """
    swapped = torch.transpose(layer.weight.data, 0, 1)
    layer.weight.data = swapped.contiguous()
    return True
|
|
|
|
|
|
def _pack_shared_expert_die(gate: torch.Tensor, up: torch.Tensor, *,
                            hidden_size: int, pad_cols: int, block_nums: int,
                            region_size: int, spc_for_shared: int,
                            spc_for_router: int, align_size: int,
                            weight_dtype: torch.dtype) -> torch.Tensor:
    """Interleave one gate/up pair block-wise and append zeroed router regions.

    Both inputs are zero-padded on the right of their last dimension by
    ``pad_cols`` columns, cut into ``block_nums`` blocks of ``align_size``
    columns, and written side by side (gate block, then up block) into a
    buffer allocated on the ``'supa'`` device.

    Returns:
        Tensor of shape
        ``(spc_for_shared + spc_for_router, hidden_size, region_size)`` whose
        trailing ``spc_for_router`` slices are all zeros.
    """
    # F.pad pairs are given last-dimension first: (left, right, top, bottom).
    pad_spec = (0, pad_cols, 0, 0)

    def _as_blocks(weight: torch.Tensor) -> torch.Tensor:
        padded = torch.nn.functional.pad(weight,
                                         pad_spec,
                                         mode='constant',
                                         value=0)
        return padded.reshape(hidden_size, block_nums,
                              align_size).contiguous().to(weight_dtype)

    interleaved = torch.zeros([hidden_size, block_nums, align_size * 2],
                              dtype=weight_dtype,
                              device='supa')
    # First half of every 2*align_size block holds gate, second half holds up.
    interleaved[:, :, :align_size] = _as_blocks(gate)
    interleaved[:, :, align_size:] = _as_blocks(up)

    per_region = interleaved.reshape(
        hidden_size, spc_for_shared,
        region_size).permute(1, 0, 2).contiguous().to(weight_dtype)

    invalid_regions = torch.zeros([spc_for_router, hidden_size, region_size],
                                  dtype=weight_dtype,
                                  device='supa')  # invalid regions
    return torch.cat([per_region, invalid_regions], dim=0)


def process_share_expert_weight(layer: MergedColumnParallelLinear):
    """Repack a shared-expert gate/up weight into a NUMA col-major tensor.

    ``layer.weight`` is transposed, split in two along the last dimension
    into gate and up, and interleaved in 32-column blocks. When
    ``envs.VLLM_BR_DEVICE_SPC_NUM`` is greater than 16, gate and up are each
    split once more along the last dimension and the two halves are packed
    independently and concatenated along dimension 0. The result is copied
    into a tensor obtained from ``torch_br._empty_ut_only`` and stored in
    ``layer.weight.data``.
    """
    gate_up_weight = layer.weight.transpose(1, 0).contiguous()
    gate_weight, up_weight = torch.chunk(gate_up_weight, 2, dim=-1)

    die_spc_num = envs.VLLM_BR_DEVICE_SPC_NUM
    is_br166 = die_spc_num > 16
    spc_num = die_spc_num // 2 if is_br166 else die_spc_num

    # Original note: 2&2 for 4spc, 4&8 for 12spc, 8&8 for 16spc.
    # The code reserves 2 regions for the shared expert when spc_num == 4 and
    # 8 otherwise; the remaining regions are emitted as zeros.
    spc_for_shared = 2 if spc_num == 4 else 8
    spc_for_router = spc_num - spc_for_shared

    align_size = 32
    weight_dtype = gate_weight.dtype
    hidden_size = gate_weight.shape[0]

    if is_br166:
        gate_d0, gate_d1 = torch.chunk(gate_weight, 2, dim=-1)
        up_d0, up_d1 = torch.chunk(up_weight, 2, dim=-1)
        die_pairs = [(gate_d0, up_d0), (gate_d1, up_d1)]
    else:
        die_pairs = [(gate_weight, up_weight)]

    # Geometry is derived once from the first gate tensor and shared by every
    # pair, exactly as the previous per-branch copies did.
    im_size = die_pairs[0][0].shape[-1]
    n_align_size = (align_size * 2) * spc_for_shared
    swiglu_w_aligned = ((
        (im_size * 2) + n_align_size - 1) // n_align_size) * n_align_size
    region_size = swiglu_w_aligned // spc_for_shared
    block_nums = (region_size // (align_size * 2)) * spc_for_shared
    pad_cols = (swiglu_w_aligned // 2) - im_size

    packed = [
        _pack_shared_expert_die(gate,
                                up,
                                hidden_size=hidden_size,
                                pad_cols=pad_cols,
                                block_nums=block_nums,
                                region_size=region_size,
                                spc_for_shared=spc_for_shared,
                                spc_for_router=spc_for_router,
                                align_size=align_size,
                                weight_dtype=weight_dtype)
        for gate, up in die_pairs
    ]
    gate_up_weight_whole = (torch.cat(packed, dim=0)
                            if is_br166 else packed[0])

    gate_up_weight_supa = torch_br._empty_ut_only(
        size=gate_up_weight_whole.shape,
        dtype=gate_weight.dtype,
        is_numa=True,
        device="supa",
        tensor_type="colmajor")
    gate_up_weight_supa.copy_(gate_up_weight_whole)

    layer.weight.data = gate_up_weight_supa
|
|
|
|
|
|
# NOTE: MergedColumnParallelLinear, usually used in MergedGateUpMLPSiluL2
|
|
def process_weights_QuantMergedColumnParallelLinear(
        layer: MergedColumnParallelLinear):
    """Convert a (possibly quantized) merged gate/up layer to NUMA tensors.

    Layers whose prefix contains ``'shared_experts'`` are delegated to
    ``process_share_expert_weight``. For all others:

    * ``layer.qweight`` (if present) is split in two along the last dimension
      and converted with ``_convert_to_crossed_numa_tensor``; otherwise
      ``layer.weight`` is transposed and converted with
      ``_convert_to_numa_tensor``.
    * ``layer.scales`` and ``layer.bias`` (when present and not ``None``) are
      split the same way and converted with layout ``"linear_bias"``.
    """
    if 'shared_experts' in layer.prefix:
        process_share_expert_weight(layer)
        return

    def _crossed(fused: torch.Tensor, **extra):
        # Split the fused tensor into its gate/up halves and cross them.
        first_half, second_half = torch.chunk(fused, 2, dim=-1)
        return _convert_to_crossed_numa_tensor(first_half,
                                               second_half,
                                               envs.VLLM_BR_DEVICE_SPC_NUM,
                                               dim=-1,
                                               do_transpose=False,
                                               **extra)

    # NOTE: normal MLP gate_up, after load weight, convert to supa numa tensor
    if hasattr(layer, "qweight"):
        layer.qweight.data = _crossed(layer.qweight, need_pad=True)
    else:
        transposed = layer.weight.permute(1, 0).contiguous()
        layer.weight.data = _convert_to_numa_tensor(
            transposed,
            32,
            "colmajor",
            transposed.dtype,
            False,
            parallel_type="col_parallel")

    if getattr(layer, "scales", None) is not None:
        layer.scales.data = _crossed(layer.scales,
                                     need_pad=False,
                                     layout="linear_bias")

    if getattr(layer, "bias", None) is not None:
        layer.bias.data = _crossed(layer.bias,
                                   need_pad=False,
                                   layout="linear_bias")
|
|
|
|
|
|
def process_weights_MergedColumnParallelLinear(
        layer: MergedColumnParallelLinear):
    """Convert an unquantized merged gate/up layer to NUMA tensors.

    Layers whose prefix contains ``'shared_experts'`` are delegated to
    ``process_share_expert_weight``. Otherwise the transposed weight is either
    converted as-is (when ``layer.no_need_cross`` exists and is truthy) or
    split in two along the last dimension and converted with
    ``_convert_to_crossed_numa_tensor``. A non-``None`` bias is always split
    and crossed with layout ``"linear_bias"``.
    """
    if 'shared_experts' in layer.prefix:
        # NOTE: by default, gate module and shared_expert(1) module will be
        # involved into calculation in 1 kernel
        process_share_expert_weight(layer)
        return

    transposed = layer.weight.permute(1, 0).contiguous()
    if getattr(layer, "no_need_cross", False):
        layer.weight.data = _convert_to_numa_tensor(transposed,
                                                    align_size=32,
                                                    layout="colmajor",
                                                    dtype=transposed.dtype,
                                                    do_transpose=False)
    else:
        gate_part, up_part = torch.chunk(transposed, 2, dim=-1)
        layer.weight.data = _convert_to_crossed_numa_tensor(
            gate_part,
            up_part,
            envs.VLLM_BR_DEVICE_SPC_NUM,
            dim=-1,
            need_pad=True,
            do_transpose=False)

    if getattr(layer, "bias", None) is not None:
        gate_bias, up_bias = torch.chunk(layer.bias, 2, dim=-1)
        layer.bias.data = _convert_to_crossed_numa_tensor(
            gate_bias,
            up_bias,
            envs.VLLM_BR_DEVICE_SPC_NUM,
            dim=-1,
            need_pad=False,
            layout="linear_bias",
            do_transpose=False)
|
|
|
|
|
|
@patch_to(UnquantizedLinearMethod)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Convert a loaded linear layer's weight and bias to NUMA tensors.

    Nothing is done when ``_should_skip_linear_post_process`` says so or when
    ``self.weight_type`` is not ``"NUMA"``. ``ReplicatedLinear`` and
    ``MergedColumnParallelLinear`` layers first go through their dedicated
    ``process_weights_*`` routine, which may make the generic conversion
    below unnecessary.
    """
    if _should_skip_linear_post_process(
            layer, self.use_ds_mla,
            self.use_ds_mla_sparse) or self.weight_type != "NUMA":
        return

    needs_generic_conversion = True
    do_transpose = True
    parallel_type = "col_parallel"
    # NOTE: all process_weights func should done before
    # process_weights_after_loading
    if isinstance(layer, ReplicatedLinear):
        process_weights_ReplicatedLinear(layer)
        # 160 is used by Glm4-Moe.
        transposed_only_sizes = (64, 160, 128, 256)
        needs_generic_conversion = (
            "indexer" in layer.prefix
            or layer.output_size not in transposed_only_sizes)
        do_transpose = False
    elif isinstance(layer, MergedColumnParallelLinear):
        process_weights_MergedColumnParallelLinear(layer)
        needs_generic_conversion = False
        do_transpose = False
    elif isinstance(layer, RowParallelLinear):
        parallel_type = "row_parallel"

    # weight_type was already verified to be "NUMA" by the entry guard.
    if not needs_generic_conversion:
        return

    # process numa weight and bias
    if hasattr(layer, "weight") and len(layer.weight.shape) == 2:
        if 'vision' in layer.prefix and is_br166_device():
            convert_weight = _convert_to_numa_tensor_vit
        else:
            convert_weight = _convert_to_numa_tensor
        dim0, dim1 = layer.weight.data.shape
        wk, wn = (dim1, dim0) if do_transpose else (dim0, dim1)
        layer.weight.data = convert_weight(layer.weight,
                                           envs.VLLM_BR_DEVICE_WARP_SIZE,
                                           "colmajor",
                                           torch.bfloat16,
                                           do_transpose=do_transpose,
                                           wk=wk,
                                           wn=wn,
                                           parallel_type=parallel_type)

    if hasattr(layer, "bias") and layer.bias is not None:
        pad_zeros = (parallel_type == "row_parallel")
        if 'vision' in layer.prefix and is_br166_device():
            # The bias is left untouched for row-parallel vision layers that
            # reduce their results.
            if pad_zeros and layer.reduce_results:
                return
            convert_bias = _convert_to_numa_tensor_vit
        else:
            convert_bias = _convert_to_numa_tensor
        layer.bias.data = convert_bias(layer.bias,
                                       envs.VLLM_BR_DEVICE_WARP_SIZE,
                                       "linear_bias",
                                       torch.float32,
                                       parallel_type=parallel_type,
                                       pad_zeros=pad_zeros)
|
|
|
|
|
|
def _linear_with_sublas(x: torch.Tensor, weight: torch.Tensor,
                        bias: Optional[torch.Tensor]) -> torch.Tensor:
    """Run ``F.linear`` with the SUBLAS API flag switched on.

    The flag is switched off again in a ``finally`` clause so that an
    exception raised by ``F.linear`` cannot leave it enabled.
    """
    supa_debug.set_enable_sublas_api(True)
    try:
        return F.linear(x, weight, bias)
    finally:
        supa_debug.set_enable_sublas_api(False)


def _numa_output_size_and_act_mode(layer: torch.nn.Module):
    """Return ``(output_size, activation_mode)`` for a 3-dim (NUMA) weight.

    ``output_size_per_partition`` is preferred over ``output_size`` when the
    layer has it. For a ``MergedColumnParallelLinear`` whose ``no_need_cross``
    is missing or falsy, the mode is ``"act_swiglu"`` and the size is halved.
    """
    output_size = (layer.output_size_per_partition if hasattr(
        layer, "output_size_per_partition") else layer.output_size)
    act_mode = "act_default"
    if isinstance(layer, MergedColumnParallelLinear) and not (hasattr(
            layer, "no_need_cross") and layer.no_need_cross):
        act_mode = "act_swiglu"
        output_size //= 2
    return output_size, act_mode


@patch_to(UnquantizedLinearMethod)
def apply(self,
          layer: torch.nn.Module,
          x: torch.Tensor,
          bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply the linear layer to ``x``.

    A 3-dim ``layer.weight`` is treated as a NUMA weight and dispatched to the
    ``torch_br`` kernels; any other weight falls back to ``F.linear`` with the
    SUBLAS API flag enabled for the duration of the call.

    Side effect: for a ``RowParallelLinear`` with a NUMA weight outside the
    vision/br166 path, ``layer.reduce_results`` is overwritten on every call.

    Args:
        layer: Linear layer providing ``weight``, ``prefix`` and size
            attributes.
        x: Input activations.
        bias: Optional bias tensor.

    Returns:
        The layer output.
    """
    # numa weight is 3-dims
    if 'vision' in layer.prefix and is_br166_device():
        if len(layer.weight.shape) != 3:
            return _linear_with_sublas(x, layer.weight, bias)

        is_row = isinstance(layer, RowParallelLinear)
        output_size, act_mode = _numa_output_size_and_act_mode(layer)
        if bias is None or (is_row and layer.reduce_results):
            return torch_br.br_fused_mlp_infer(x, [layer.weight],
                                               output_w=output_size,
                                               activation_mode=act_mode)
        # NOTE(review): act_mode is not forwarded on this path - confirm that
        # a biased swiglu layer never reaches it.
        return torch_br.br_matmul_infer(x, layer.weight, bias, output_size)

    if len(layer.weight.shape) != 3:
        return _linear_with_sublas(x, layer.weight, bias)

    output_size, act_mode = _numa_output_size_and_act_mode(layer)
    bias_list = [bias] if bias is not None else None

    if not isinstance(layer, RowParallelLinear):
        return torch_br.br_fused_mlp_infer(x, [layer.weight],
                                           output_w=output_size,
                                           bias=bias_list,
                                           activation_mode=act_mode)

    seq_len = x.shape[-2]
    tp_size = get_tensor_model_parallel_world_size()
    # TODO(CaoJun): This is WA, delete (16, 8) so that the
    # test_vllm_model_accu_qwen25_72b_instruct can run through
    support_types = ((16, 4), (32, 2), (32, 4))
    # bypass tp8 and tp4pp2 allreduce
    pp_size = get_pp_group().world_size
    all_rank = tp_size * pp_size
    use_fused_allreduce = (
        all_rank <= envs.VLLM_BR_USE_FUSED_ALLREDUCE
        and seq_len <= envs.VLLM_BR_STATIC_MOE_DECODER_MAX_LEN
        and (envs.VLLM_BR_DEVICE_SPC_NUM, tp_size) in support_types)
    layer.reduce_results = not use_fused_allreduce

    if layer.reduce_results:
        return torch_br.br_fused_mlp_infer(x, [layer.weight],
                                           output_w=output_size,
                                           bias=bias_list,
                                           activation_mode=act_mode)

    tp_rank = get_tp_group().rank_in_group
    global_rank = get_tp_group().rank
    rank_i = global_rank % tp_size
    assert rank_i == tp_rank
    # NOTE(review): no bias is passed to the fused linear+allreduce call -
    # confirm callers never supply one on this path.
    return torch_br.supa_fused_linear_allreduce_opt(x, layer.weight,
                                                    output_size, tp_rank,
                                                    tp_size, global_rank, 0)
|
|
|
|
|
|
@patch_to(LinearBase)
def __init__(
    self,
    input_size: int,
    output_size: int,
    skip_bias_add: bool = False,
    params_dtype: Optional[torch.dtype] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
    *,
    return_bias: bool = True,
    disable_tp: bool = False,
):
    """Initialise the attributes shared by all linear layers.

    Args:
        input_size: Stored as ``self.input_size``.
        output_size: Stored as ``self.output_size``.
        skip_bias_add: Stored as ``self.skip_bias_add``.
        params_dtype: Parameter dtype; ``torch.get_default_dtype()`` is used
            when ``None``.
        quant_config: When ``None`` an ``UnquantizedLinearMethod`` is used,
            otherwise ``quant_config.get_quant_method(self, prefix=prefix)``.
        prefix: Stored as ``self.prefix``.
        return_bias: Stored as ``self.return_bias``.
        disable_tp: When true, ``tp_rank``/``tp_size`` are fixed to ``0``/``1``
            instead of being queried from the tensor-parallel group.
    """
    super(LinearBase, self).__init__()

    # Keep input parameters
    self.input_size = input_size
    self.output_size = output_size
    self.skip_bias_add = skip_bias_add
    self.params_dtype = (torch.get_default_dtype()
                         if params_dtype is None else params_dtype)

    # The quant method is resolved before return_bias/prefix/tp_* are set,
    # same as the statement order this replaces.
    self.quant_method = (UnquantizedLinearMethod()
                         if quant_config is None else
                         quant_config.get_quant_method(self, prefix=prefix))
    self.return_bias = return_bias
    self.prefix = prefix

    if disable_tp:
        self.tp_rank = 0
        self.tp_size = 1
    else:
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
|
|
|
|
|
|
@patch_to(RowParallelLinear)
def forward(
    self, input_
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
    """Run the row-parallel linear layer and reduce across TP ranks.

    Returns:
        ``output`` alone when ``self.return_bias`` is false, otherwise
        ``(output, output_bias)`` where ``output_bias`` is ``self.bias`` if
        ``self.skip_bias_add`` is set and ``None`` otherwise.
    """
    # Resolve the grandparent pid once, on the first call, when the CPU
    # all-reduce is enabled.
    if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and not hasattr(
            self, "grandparent_pid"):
        self.grandparent_pid = get_grandparent_pid()

    if self.input_is_parallel:
        local_input = input_
    else:
        rank = get_tensor_model_parallel_rank()
        pieces = split_tensor_along_last_dim(input_,
                                             num_partitions=self.tp_size)
        local_input = pieces[rank].contiguous()

    # Matrix multiply.
    assert self.quant_method is not None
    # Only fuse bias add into GEMM for rank 0 (this ensures that
    # bias will not get added more than once in TP>1 case)
    if self.tp_rank > 0 or self.skip_bias_add:
        fused_bias = None
    else:
        fused_bias = self.bias
    partial_output = self.quant_method.apply(self,
                                             local_input,
                                             bias=fused_bias)

    # self.reduce_results is read only after apply() has run.
    if not (self.reduce_results and self.tp_size > 1):
        output = partial_output
    elif (envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and self.tp_size >= 4
          and partial_output.shape[1] <= 32):
        # CPU all reduce will be applied.
        rank = get_tensor_model_parallel_rank()
        output = torch_br.supa_allreduce_pcie_infer(partial_output, rank,
                                                    self.tp_size,
                                                    self.grandparent_pid)
    else:
        output = tensor_model_parallel_all_reduce(partial_output)

    output_bias = self.bias if self.skip_bias_add else None

    if self.return_bias:
        return output, output_bias
    return output
|
|
|
|
|
|
@patch_to(QKVParallelLinear)
def weight_loader(self,
                  param: Parameter,
                  loaded_weight: torch.Tensor,
                  loaded_shard_id: Optional[str] = None):
    """Copy a checkpoint tensor into the fused QKV parameter.

    Args:
        param: Destination parameter. Optional attributes read from it
            (``output_dim``, ``packed_dim``, ``is_metadata`` and others)
            select the loading path.
        loaded_weight: Tensor read from the checkpoint.
        loaded_shard_id: ``"q"``, ``"k"`` or ``"v"`` for a single projection,
            or ``None`` when the checkpoint tensor is already fused; in that
            case the tensor is split and this loader is re-entered once per
            shard.

    When ``envs.VLLM_BR_DEVICE_SPC_NUM`` is greater than 16 the destination
    slice is taken as two half-width windows, and the loaded shard is split
    in two halves copied into them.
    """

    def _qkv_layout(num_q_heads, num_kv_heads):
        # {shard: (offset, size)} along the fused output dimension.
        q_width = num_q_heads * self.head_size
        kv_width = num_kv_heads * self.head_size
        return {
            "q": (0, q_width),
            "k": (q_width, kv_width),
            "v": (q_width + kv_width, kv_width),
        }

    def _bnb_offsets(layout):
        # Same layout plus the "total" entry expected by
        # adjust_bitsandbytes_4bit_shard.
        total_width = layout["q"][1] + 2 * layout["k"][1]
        return {**layout, "total": (total_width, 0)}

    # Special case for GGUF
    # initialize GGUF param after we know the quantize type
    is_gguf_weight = getattr(param, "is_gguf_weight", False)
    is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
    if is_gguf_weight_type:
        if loaded_shard_id is None:
            param.shard_weight_type = dict.fromkeys(("q", "k", "v"),
                                                    loaded_weight.item())
        else:
            slot = {"q": 0, "k": 1, "v": 2}[loaded_shard_id]
            param.data[slot].copy_(loaded_weight)
            param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
        return

    if is_gguf_weight:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        output_dim = getattr(param, "output_dim", None)
        shard_size = loaded_weight.size(output_dim) // tp_size
        start_idx = tp_rank * shard_size

        if loaded_shard_id is not None:
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)
            param.shard_id.append(loaded_shard_id)
            param.shard_id_map[loaded_shard_id] = len(param.data_container)
            param.data_container.append(loaded_weight)
            return

    param_data = param.data
    output_dim = getattr(param, "output_dim", None)
    # Special case for AQLM codebooks.
    is_metadata = getattr(param, "is_metadata", False)

    # Special case for per-tensor scales in fused case.
    needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)

    if loaded_shard_id is None:
        # Loaded weight is already fused on disk (qkv).
        # (e.g., Phi-3's qkv_proj).
        if output_dim is None:
            if needs_scalar_to_array:
                param_data, loaded_weight = adjust_scalar_to_fused_array(
                    param_data, loaded_weight, 0)

            assert param_data.shape == loaded_weight.shape
            param_data.copy_(loaded_weight)
            return

        full_layout = _qkv_layout(self.total_num_heads,
                                  self.total_num_kv_heads)
        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        packed_dim = getattr(param, "packed_dim", None)

        for shard_id, (shard_offset, shard_size) in full_layout.items():
            # Special case for Quantized Weights.
            # If quantized, we need to adjust the offset and size to account
            # for the packing.
            if packed_dim == output_dim:
                shard_size = shard_size // param.pack_factor
                shard_offset = shard_offset // param.pack_factor

                # Special case for Marlin.
                shard_size, shard_offset = adjust_marlin_shard(
                    param, shard_size, shard_offset)

            if use_bitsandbytes_4bit:
                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                    param, _bnb_offsets(full_layout), shard_id)

            loaded_weight_shard = loaded_weight.narrow(output_dim,
                                                       shard_offset,
                                                       shard_size)
            self.weight_loader(param, loaded_weight_shard, shard_id)
        return

    tp_rank = get_tensor_model_parallel_rank()
    assert loaded_shard_id in ["q", "k", "v"]

    # If output dim is defined, use the default loading process.
    if output_dim is not None:
        rank_layout = _qkv_layout(self.num_heads, self.num_kv_heads)
        shard_offset, shard_size = rank_layout[loaded_shard_id]

        # Special case for Quantized Weights.
        # If quantized, we need to adjust the offset and size to account
        # for the packing.
        packed_dim = getattr(param, "packed_dim", None)
        if packed_dim == output_dim:
            shard_size = shard_size // param.pack_factor
            shard_offset = shard_offset // param.pack_factor

            # Special case for Marlin.
            shard_size, shard_offset = adjust_marlin_shard(
                param, shard_size, shard_offset)

        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
        # bitsandbytes loads the weights of the specific portion
        # no need to narrow
        is_sharded_weight = (getattr(param, "is_sharded_weight", False)
                             or use_bitsandbytes_4bit)

        if use_bitsandbytes_4bit:
            shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                param, _bnb_offsets(rank_layout), loaded_shard_id)

        if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
            # Two half-size windows: one in each half of the output dim.
            half_w = param_data.shape[output_dim] // 2
            half_offset = shard_offset // 2
            half_size = shard_size // 2
            param_data = (param_data.narrow(output_dim, half_offset,
                                            half_size),
                          param_data.narrow(output_dim, half_offset + half_w,
                                            half_size))
        else:
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)

        if loaded_shard_id == "q":
            shard_id = tp_rank
        else:
            shard_id = tp_rank // self.num_kv_head_replicas
        start_idx = shard_id * shard_size

        if not is_sharded_weight:
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                 shard_size)

    # Special case for AQLM codebooks.
    elif is_metadata:
        # metadata indicates fixed size concatenated along dim 0
        shard_size = loaded_weight.shape[0]
        shard_index = ["q", "k", "v"].index(loaded_shard_id)
        param_data = param_data.narrow(0, shard_index * shard_size, shard_size)
    # Special case for per-tensor scales in fused case.
    elif needs_scalar_to_array:
        param_data, loaded_weight = adjust_scalar_to_fused_array(
            param_data, loaded_weight, loaded_shard_id)
    else:
        ignore_warning = getattr(param, "ignore_warning", False)
        if not ignore_warning:
            logger.warning(
                "Loading a weight without `output_dim` attribute in "
                "QKVParallelLinear, assume the weight is the same "
                "for all partitions.")

    if isinstance(param_data, tuple):
        # Split destination: copy each half of the loaded shard separately.
        half_w = loaded_weight.shape[output_dim] // 2
        first_half, second_half = param_data
        first_half.copy_(loaded_weight.narrow(output_dim, 0, half_w))
        second_half.copy_(loaded_weight.narrow(output_dim, half_w, half_w))
    else:
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)
|