init
This commit is contained in:
185
vllm/lora/layers/base_linear.py
Normal file
185
vllm/lora/layers/base_linear.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional, cast
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.distributed.utils import divide
|
||||
# yapf: disable
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
LinearBase, ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .base import BaseLayerWithLoRA
|
||||
from .utils import _get_lora_device
|
||||
|
||||
|
||||
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """Common LoRA plumbing for wrappers around a ``LinearBase`` layer.

    The wrapper keeps a reference to the wrapped linear layer, allocates
    zero-initialised stacked buffers for the LoRA A / B matrices (and an
    optional LoRA bias), copies adapter weights into a given slot, and adds
    the LoRA contribution on top of the base layer output through
    ``self.punica_wrapper``.

    ``output_size`` and ``n_slices`` are only declared here, not assigned;
    presumably subclasses set them before ``create_lora_weights`` is called
    (TODO confirm). ``slice_lora_a``, ``slice_lora_b`` and ``slice_bias``
    are used by ``set_lora`` but are not defined in this class.
    """

    def __init__(self, base_layer: LinearBase):
        """Wrap ``base_layer`` and mirror its size / parallelism attributes."""
        super().__init__()
        self.base_layer = base_layer
        self.input_size = base_layer.input_size
        # Ensure tp_size and tp_rank consistency with the base_layer.
        self.tp_size = base_layer.tp_size
        self.tp_rank = base_layer.tp_rank
        self.device = _get_lora_device(base_layer)
        # Stays None unless create_lora_weights() sees bias_enabled.
        self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None
        # Declared for type checkers only; no value is assigned here.
        self.output_slices: tuple[int, ...]
        self.output_size: int
        self.n_slices: int

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate the zero-filled stacked LoRA buffers.

        Args:
            max_loras: Size of the leading (slot) dimension of every buffer.
            lora_config: Supplies ``max_lora_rank``, ``fully_sharded_loras``,
                ``lora_dtype`` and ``bias_enabled``.
            model_config: Accepted but not used by this implementation.

        Raises:
            NotImplementedError: If the base layer is not a replicated,
                column-parallel or row-parallel linear layer.
        """
        self.lora_config = lora_config

        # Pick the output dimension of the A and B buffers from the kind of
        # base layer. With fully sharded LoRAs the rank (column parallel) or
        # the output size (row parallel) is divided by tp_size.
        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size
        elif isinstance(self.base_layer, ColumnParallelLinear):
            if lora_config.fully_sharded_loras:
                lora_a_out_size = divide(lora_config.max_lora_rank,
                                         self.tp_size)
            else:
                lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size
        elif isinstance(self.base_layer, RowParallelLinear):
            lora_a_out_size = lora_config.max_lora_rank
            if lora_config.fully_sharded_loras:
                lora_b_out_size = divide(self.output_size, self.tp_size)
            else:
                lora_b_out_size = self.output_size
        else:
            raise NotImplementedError

        def _zero_buffers(*shape: int) -> tuple[torch.Tensor, ...]:
            # One independent zero tensor per slice, all of the same shape.
            return tuple(
                torch.zeros(
                    *shape,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for _ in range(self.n_slices))

        self.lora_a_stacked = _zero_buffers(
            max_loras,
            1,
            lora_a_out_size,
            self.input_size,
        )
        self.lora_b_stacked = _zero_buffers(
            max_loras,
            1,
            lora_b_out_size,
            lora_config.max_lora_rank,
        )
        if lora_config.bias_enabled:
            # The bias buffer shares the output dimension of the B buffer.
            self.lora_bias_stacked = _zero_buffers(
                max_loras,
                1,
                lora_b_out_size,
            )
        self.output_slices = (self.lora_b_stacked[0].shape[2], )

    def reset_lora(self, index: int):
        """Zero slot ``index`` in every slice of the A, B and bias buffers."""
        for slice_id in range(self.n_slices):
            self.lora_a_stacked[slice_id][index] = 0
            self.lora_b_stacked[slice_id][index] = 0
            if not self.lora_config.bias_enabled:
                continue
            # Make mypy happy
            bias_buffers = cast(tuple[torch.Tensor, ...],
                                self.lora_bias_stacked)
            bias_buffers[slice_id][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Copy one adapter's weights into slot ``index``.

        The slot is cleared first. When ``tp_size > 1`` the tensors are
        passed through ``slice_lora_a`` / ``slice_lora_b`` / ``slice_bias``
        before being copied into the top-left corner of the stacked buffers.
        ``embeddings_tensor`` is accepted but not used here.
        """
        # Except for QKVParallelLinearWithLoRA and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        num_a = len(self.lora_a_stacked)
        num_b = len(self.lora_b_stacked)
        assert num_a == num_b == self.n_slices == 1

        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        # The buffers may be larger than the adapter tensors, so only the
        # region matching the adapter's first two dimensions is written.
        a_target = self.lora_a_stacked[0][index, 0, :lora_a.shape[0], :lora_a.
                                          shape[1]]
        a_target.copy_(lora_a, non_blocking=True)

        b_target = self.lora_b_stacked[0][index, 0, :lora_b.shape[0], :lora_b.
                                          shape[1]]
        b_target.copy_(lora_b, non_blocking=True)

        if lora_bias is None:
            return

        # The cast only narrows the Optional type for the type checker.
        bias_buffers = cast(tuple[torch.Tensor, ...], self.lora_bias_stacked)
        assert len(bias_buffers)
        bias_target = bias_buffers[0][index, 0, :lora_bias.shape[0]]
        bias_target.copy_(lora_bias, non_blocking=True)

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the base layer and add the LoRA contribution to its output.

        Returns the tensor that ``add_lora_linear`` was given as output when
        the current platform can update in place, otherwise the value
        returned by ``add_lora_linear``. If both ``x`` and the base output
        are 3-D, the result has its first two dimensions flattened.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if (x.ndim, output.ndim) == (3, 3):
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        lora_output: Optional[torch.Tensor] = (
            self.punica_wrapper.add_lora_linear(
                output,
                x,
                self.lora_a_stacked,
                self.lora_b_stacked,
                self.lora_bias_stacked,
                1.0,
                self.output_slices,
            ))

        return (output
                if current_platform.can_update_inplace() else lora_output)

    @property
    def weight(self) -> torch.Tensor:
        """Return the base layer's weight tensor, whatever it is called.

        Raises:
            ValueError: If the base layer has none of the known attributes.
        """
        # Checked in priority order:
        #   "weight"        - unquantizedLinear
        #   "weight_packed" - Compressed Tensor
        #   "qweight"       - GPTQ/AWQ
        #   "B"             - marlin
        #   "W_q"           - HQQ marlin
        for attr_name in ("weight", "weight_packed", "qweight", "B", "W_q"):
            if hasattr(self.base_layer, attr_name):
                return getattr(self.base_layer, attr_name)
        raise ValueError(f"Unsupported base layer: {self.base_layer}")

    @property
    def bias(self) -> Optional[torch.Tensor]:
        """Return the base layer's ``bias`` attribute, or None if absent."""
        return getattr(self.base_layer, "bias", None)
|
||||
Reference in New Issue
Block a user