################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################
import torch
from fastcore.basics import patch_to
from vllm.model_executor.parameter import (PackedColumnParameter,
PackedvLLMParameter,
_ColumnvLLMParameter)
from vllm_br import envs
@patch_to(_ColumnvLLMParameter)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
    """Copy a Q/K/V weight shard from *loaded_weight* into this parameter.

    Mirrors the upstream vLLM ``load_qkv_weight`` flow (packing adjustment,
    per-rank shard selection along ``output_dim``), with one device-specific
    twist: when ``envs.VLLM_BR_DEVICE_SPC_NUM > 16`` the destination shard is
    written as two half-sized pieces placed in the lower and upper halves of
    the parameter along ``output_dim``.

    Keyword args (all read via ``kwargs.get``):
        shard_offset: start index of the destination shard along output_dim.
        shard_size:   length of the destination shard along output_dim.
        shard_id:     "q"/"k"/"v" selector for the source shard.
        num_heads:    head-count divisor used for the k/v rank mapping.
    """
    offset = kwargs.get("shard_offset")
    size = kwargs.get("shard_size")
    which = kwargs.get("shard_id")
    heads = kwargs.get("num_heads")

    # TODO: move these to PackedColumnParameter and PackedvLLMParameter
    is_packed = isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
    if is_packed and self.output_dim == self.packed_dim:
        size, offset = self.adjust_shard_indexes_for_packing(
            shard_offset=offset, shard_size=size)

    dest = self.data

    # Pick which slice of the checkpoint tensor belongs to this TP rank.
    if which == "q":
        src_index = self.tp_rank
    else:
        src_index = self.tp_rank // heads
    loaded_weight = loaded_weight.narrow(self.output_dim, src_index * size,
                                         size)

    if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
        assert isinstance(size, int), "failed to check shard_size type is int"
        assert isinstance(offset,
                          int), "failed to check shard_offset type is int"
        # Split-copy path: halve the shard and drop each half into the
        # corresponding half of the parameter along output_dim.
        half_width = dest.shape[self.output_dim] // 2
        half_size = size // 2
        half_offset = offset // 2
        lower = dest.narrow(self.output_dim, half_offset, half_size)
        upper = dest.narrow(self.output_dim, half_offset + half_width,
                            half_size)
        lower.copy_(loaded_weight.narrow(self.output_dim, 0, half_size))
        upper.copy_(
            loaded_weight.narrow(self.output_dim, half_size, half_size))
    else:
        # Plain path: single contiguous copy into the destination shard.
        dest = dest.narrow(self.output_dim, offset, size)
        assert dest.shape == loaded_weight.shape
        dest.copy_(loaded_weight)