first commit
This commit is contained in:
70
vllm_br/model_executor/parameter.py
Normal file
70
vllm_br/model_executor/parameter.py
Normal file
@@ -0,0 +1,70 @@
|
||||
################################################################################
|
||||
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
################################################################################
|
||||
|
||||
import torch
|
||||
from fastcore.basics import patch_to
|
||||
|
||||
from vllm.model_executor.parameter import (PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
_ColumnvLLMParameter)
|
||||
from vllm_br import envs
|
||||
|
||||
|
||||
@patch_to(_ColumnvLLMParameter)
def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
    """Copy one q/k/v shard of a checkpoint weight into this parameter.

    Keyword Args:
        shard_offset: element offset of this shard along ``output_dim``.
        shard_size: extent of this shard along ``output_dim``.
        shard_id: ``"q"``, ``"k"`` or ``"v"`` — controls how the TP rank is
            mapped to a slice of ``loaded_weight``.
        num_heads: divisor applied to the TP rank for k/v shards
            (k/v are shared across groups of ranks).
    """

    shard_offset = kwargs.get("shard_offset")
    shard_size = kwargs.get("shard_size")
    shard_id = kwargs.get("shard_id")
    num_heads = kwargs.get("num_heads")

    # TODO: move these to PackedColumnParameter and PackedvLLMParameter
    is_packed = isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
    if is_packed and self.output_dim == self.packed_dim:
        # Packed (e.g. quantized) layout: rescale offset/size accordingly.
        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
            shard_offset=shard_offset, shard_size=shard_size)

    param_data = self.data
    # q shards are unique per rank; k/v shards repeat every `num_heads` ranks.
    rank_shard = self.tp_rank if shard_id == "q" else self.tp_rank // num_heads
    loaded_weight = loaded_weight.narrow(self.output_dim,
                                         rank_shard * shard_size, shard_size)

    if envs.VLLM_BR_DEVICE_SPC_NUM > 16:
        # Device-specific split layout: the destination tensor is organized
        # as two halves along output_dim; each half of the loaded shard is
        # written at the same relative offset within its half.
        assert isinstance(shard_size,
                          int), "failed to check shard_size type is int"
        assert isinstance(shard_offset,
                          int), "failed to check shard_offset type is int"
        half_width = param_data.shape[self.output_dim] // 2
        half_size = shard_size // 2
        half_offset = shard_offset // 2
        lower_half = param_data.narrow(self.output_dim, half_offset,
                                       half_size)
        upper_half = param_data.narrow(self.output_dim,
                                       half_offset + half_width, half_size)
        lower_half.copy_(loaded_weight.narrow(self.output_dim, 0, half_size))
        upper_half.copy_(
            loaded_weight.narrow(self.output_dim, half_size, half_size))
    else:
        # Plain layout: write the shard contiguously at its offset.
        dst = param_data.narrow(self.output_dim, shard_offset, shard_size)
        assert dst.shape == loaded_weight.shape
        dst.copy_(loaded_weight)
|
||||
Reference in New Issue
Block a user