# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional, cast

import torch
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
from vllm.distributed.utils import divide
# yapf: disable
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               LinearBase, ReplicatedLinear,
                                               RowParallelLinear)
from vllm.platforms import current_platform

from .base import BaseLayerWithLoRA
from .utils import _get_lora_device


class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    """LoRA wrapper shared by all linear-layer variants.

    Wraps a vLLM ``LinearBase`` layer and holds per-adapter LoRA A/B weight
    stacks (plus optional LoRA biases) that are applied on top of the base
    layer's output via the punica wrapper.

    Subclasses are expected to populate ``self.output_size``,
    ``self.n_slices`` and (directly or via ``create_lora_weights``)
    ``self.output_slices`` — they are only declared here as annotations.
    """

    def __init__(self, base_layer: LinearBase):
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        # Ensure tp_size and tp_rank consistency with the base_layer.
        self.tp_size = self.base_layer.tp_size
        self.tp_rank = self.base_layer.tp_rank
        # LoRA tensors are placed on the same device as the base layer's
        # weights (resolved by the project helper).
        self.device = _get_lora_device(self.base_layer)
        # Allocated lazily in create_lora_weights() only when
        # lora_config.bias_enabled is set; None otherwise.
        self.lora_bias_stacked: Optional[tuple[torch.Tensor, ...]] = None

        # Declarations only — concrete values come from subclasses.
        self.output_slices: tuple[int, ...]
        self.output_size: int
        self.n_slices: int

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate zeroed stacked LoRA buffers for up to ``max_loras`` slots.

        Per slice, the buffers are shaped:
          * lora_a_stacked: (max_loras, 1, lora_a_out_size, input_size)
          * lora_b_stacked: (max_loras, 1, lora_b_out_size, max_lora_rank)
          * lora_bias_stacked (optional): (max_loras, 1, lora_b_out_size)

        With fully-sharded LoRA, the rank dim of A (column-parallel case) or
        the output dim of B (row-parallel case) is divided across TP ranks;
        otherwise full-size buffers are kept on every rank.

        Args:
            max_loras: number of adapter slots to allocate.
            lora_config: supplies max rank, dtype, bias/sharding flags.
            model_config: unused here; kept for interface compatibility.

        Raises:
            NotImplementedError: for base layer types other than
                Replicated/ColumnParallel/RowParallel linear.
        """
        self.lora_config = lora_config

        if isinstance(self.base_layer, ReplicatedLinear):
            # No TP sharding: full-rank A, full-output B on every rank.
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size
        elif isinstance(self.base_layer, ColumnParallelLinear):
            # Fully-sharded LoRA splits the rank dimension of A across TP
            # ranks; B already works on the column-sharded output_size.
            lora_a_out_size = (lora_config.max_lora_rank
                               if not lora_config.fully_sharded_loras else
                               divide(lora_config.max_lora_rank,
                                      self.tp_size))
            lora_b_out_size = self.output_size
        elif isinstance(self.base_layer, RowParallelLinear):
            # Fully-sharded LoRA splits the output dimension of B across TP
            # ranks; A keeps the full rank.
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = (self.output_size
                               if not lora_config.fully_sharded_loras else
                               divide(self.output_size, self.tp_size))
        else:
            raise NotImplementedError

        # One buffer per slice (n_slices > 1 for merged/QKV variants).
        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            ) for _ in range(self.n_slices))
        if lora_config.bias_enabled:
            lora_bias_out_size = lora_b_out_size
            self.lora_bias_stacked = tuple(
                torch.zeros(
                    max_loras,
                    1,
                    lora_bias_out_size,
                    dtype=lora_config.lora_dtype,
                    device=self.device,
                ) for _ in range(self.n_slices))
        # Single-slice layout by default; multi-slice subclasses override
        # this with their own per-slice sizes.
        self.output_slices = (self.lora_b_stacked[0].shape[2], )

    def reset_lora(self, index: int):
        """Zero out adapter slot ``index`` in every slice's stacked buffers."""
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0
            if self.lora_config.bias_enabled:
                # Make mypy happy
                self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                              self.lora_bias_stacked)
                self.lora_bias_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        lora_bias: Optional[torch.Tensor] = None,
    ):
        """Copy one adapter's A/B (and optional bias) into slot ``index``.

        Under tensor parallelism (tp_size > 1) the incoming tensors are first
        narrowed to this rank's shard via the subclass-provided
        slice_lora_a / slice_lora_b / slice_bias hooks. The copy targets only
        the leading sub-region of the stacked buffer, so adapters with rank
        below max_lora_rank leave the remainder zeroed (by reset_lora).

        ``embeddings_tensor`` is unused by plain linear layers.
        """
        # Except for QKVParallelLinearWithLoRA and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert (len(self.lora_a_stacked) == len(self.lora_b_stacked) ==
                self.n_slices == 1)

        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)
            if lora_bias is not None:
                lora_bias = self.slice_bias(lora_bias)

        self.lora_a_stacked[0][index,
                               0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
                                   lora_a, non_blocking=True)
        self.lora_b_stacked[0][index,
                               0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
                                   lora_b, non_blocking=True)
        if lora_bias is not None:
            self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                          self.lora_bias_stacked)
            assert len(self.lora_bias_stacked)
            self.lora_bias_stacked[0][index, 0, :lora_bias.shape[0]].copy_(
                lora_bias, non_blocking=True)

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the base linear layer, then add the LoRA contribution.

        On platforms that update ``output`` in place the punica call mutates
        it directly; otherwise the (non-None) returned tensor replaces it.
        NOTE(review): when a 3-D input is flattened below, the result is
        returned flattened — callers on that path appear to expect this;
        confirm against the transformers-backend caller.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # In transformers backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        lora_output: Optional[
            torch.Tensor] = self.punica_wrapper.add_lora_linear(
                output, x, self.lora_a_stacked, self.lora_b_stacked,
                self.lora_bias_stacked, 1.0, self.output_slices)
        if not current_platform.can_update_inplace():
            output = lora_output

        return output

    @property
    def weight(self) -> torch.Tensor:
        """Resolve the base layer's weight tensor across quant backends.

        Raises:
            ValueError: if no known weight attribute is found.
        """
        # UnquantizedLinear
        if hasattr(self.base_layer, "weight"):
            return self.base_layer.weight
        # Compressed Tensor
        elif hasattr(self.base_layer, "weight_packed"):
            return self.base_layer.weight_packed
        # GPTQ/AWQ
        elif hasattr(self.base_layer, "qweight"):
            return self.base_layer.qweight
        # marlin
        elif hasattr(self.base_layer, "B"):
            return self.base_layer.B
        # HQQ marlin
        elif hasattr(self.base_layer, "W_q"):
            return self.base_layer.W_q
        else:
            raise ValueError(f"Unsupported base layer: {self.base_layer}")

    @property
    def bias(self) -> Optional[torch.Tensor]:
        """Return the base layer's bias, or None if it has none."""
        if hasattr(self.base_layer, "bias"):
            return self.base_layer.bias
        else:
            return None