# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
|
import torch
|
|
from transformers import PretrainedConfig
|
|
|
|
from vllm.config.lora import LoRAConfig
|
|
from vllm.distributed.utils import divide
|
|
from vllm.model_executor.layers.linear import (
|
|
ColumnParallelLinear,
|
|
LinearBase,
|
|
ReplicatedLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from vllm.platforms import current_platform
|
|
|
|
from .base import BaseLayerWithLoRA
|
|
from .utils import _get_lora_device
|
|
|
|
|
|
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """Shared LoRA plumbing for linear layers.

    Wraps a ``LinearBase`` layer and maintains stacked LoRA A/B weight
    buffers (one pair per output slice) for up to ``max_loras`` adapters.
    Subclasses for the replicated / column-parallel / row-parallel variants
    set ``output_size``, ``n_slices`` and the slicing helpers; this base
    class implements the common create / reset / set / apply logic.
    """

    def __init__(self, base_layer: LinearBase):
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        # Ensure tp_size and tp_rank consistency with the base_layer.
        self.tp_size = self.base_layer.tp_size
        self.tp_rank = self.base_layer.tp_rank
        self.device = _get_lora_device(self.base_layer)
        # Declared (unassigned) here; subclasses and create_lora_weights()
        # fill them in:
        #   output_slices - per-slice output widths used by the punica kernel
        #   output_size   - output width of the wrapped layer
        #   n_slices      - number of fused sub-projections (1 for plain
        #                   linears; >1 for merged layers like QKV)
        self.output_slices: tuple[int, ...]
        self.output_size: int
        self.n_slices: int

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Allocate zero-filled stacked LoRA weight buffers.

        Creates ``n_slices`` pairs of buffers:
          * ``lora_a_stacked[i]``: (max_loras, 1, lora_a_out_size, input_size)
          * ``lora_b_stacked[i]``: (max_loras, 1, lora_b_out_size, max_lora_rank)

        Args:
            max_loras: Maximum number of adapters held simultaneously.
            lora_config: Supplies ``max_lora_rank``, ``lora_dtype`` and the
                ``fully_sharded_loras`` flag.
            model_config: Unused here; kept for a uniform interface across
                LoRA layer types.
        """
        self.lora_config = lora_config
        # Pick per-layer-type buffer sizes. With fully_sharded_loras, the
        # dimension that is sharded across tensor-parallel ranks shrinks by
        # tp_size: the A-side rank for column-parallel layers, the B-side
        # output for row-parallel layers. Replicated layers are never sharded.
        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, ColumnParallelLinear):
            lora_a_out_size = (
                lora_config.max_lora_rank
                if not lora_config.fully_sharded_loras
                else divide(lora_config.max_lora_rank, self.tp_size)
            )
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, RowParallelLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = (
                self.output_size
                if not lora_config.fully_sharded_loras
                else divide(self.output_size, self.tp_size)
            )
        else:
            raise NotImplementedError

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        # All slices share the same B-buffer width, so a single-entry tuple
        # suffices here; merged-layer subclasses override with one entry per
        # slice.
        self.output_slices = (self.lora_b_stacked[0].shape[2],)

    def reset_lora(self, index: int):
        """Zero out the adapter slot ``index`` in every slice's A/B buffer."""
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: torch.Tensor | None,
    ):
        """Copy one adapter's A/B weights into slot ``index``.

        ``embeddings_tensor`` is unused for linear layers.
        """
        # Except for QKVParallelLinearWithLoRA and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert (
            len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
        )

        self.reset_lora(index)
        if self.tp_size > 1:
            # slice_lora_a/b are provided by the parallel-layer subclasses;
            # they extract this rank's shard of the adapter weights.
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        # Copy into the top-left corner of the (padded) slot: incoming
        # weights may have a smaller rank than max_lora_rank, so only the
        # leading rows/cols are written (the rest stays zero from reset).
        self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
            lora_a, non_blocking=True
        )
        self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
            lora_b, non_blocking=True
        )

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        """Run the base linear layer, then add the LoRA contribution.

        Returns the combined output tensor. NOTE(review): ``punica_wrapper``
        is not defined in this file — presumably attached by the LoRA
        runtime via the BaseLayerWithLoRA machinery; confirm in base.py.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # In Transformers modeling backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear(
            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
        )
        # On platforms that cannot update `output` in place, the wrapper
        # returns a fresh tensor instead; use that as the result.
        if not current_platform.can_update_inplace():
            output = lora_output

        return output

    @property
    def weight(self) -> torch.Tensor:
        """The base layer's weight tensor, across quantization backends.

        Each backend stores its (possibly packed/quantized) weight under a
        different attribute name; probe them in order.

        Raises:
            ValueError: If no known weight attribute is present.
        """
        # UnquantizedLinear
        if hasattr(self.base_layer, "weight"):
            return self.base_layer.weight
        # Compressed Tensor
        elif hasattr(self.base_layer, "weight_packed"):
            return self.base_layer.weight_packed
        # GPTQ/AWQ
        elif hasattr(self.base_layer, "qweight"):
            return self.base_layer.qweight
        # marlin
        elif hasattr(self.base_layer, "B"):
            return self.base_layer.B
        # HQQ marlin
        elif hasattr(self.base_layer, "W_q"):
            return self.base_layer.W_q
        else:
            raise ValueError(f"Unsupported base layer: {self.base_layer}")

    @property
    def bias(self) -> torch.Tensor | None:
        """The base layer's bias, or None if it has no bias attribute."""
        if hasattr(self.base_layer, "bias"):
            return self.base_layer.bias
        else:
            return None