# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig

if TYPE_CHECKING:
    # Import only for type checking to avoid an import cycle at runtime.
    from vllm.lora.punica_wrapper import PunicaWrapperBase


class BaseLayerWithLoRA(nn.Module):
    """Abstract interface for a model layer augmented with LoRA adapters.

    Concrete subclasses wrap a base layer and manage per-slot LoRA weight
    matrices (A and B), including slicing them for tensor parallelism.
    Methods with `...` bodies are interface stubs to be overridden.
    """

    def slice_lora_a(
        self, lora_a: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora a if splitting with tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices for up to `max_loras` adapter slots."""
        ...

    def reset_lora(self, index: int) -> None:
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ) -> None:
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(self, punica_wrapper: "PunicaWrapperBase") -> None:
        """Stores the punica wrapper used to apply LoRA during forward.

        The annotation is a string forward reference because
        `PunicaWrapperBase` is only imported under `TYPE_CHECKING`.
        """
        self.punica_wrapper: PunicaWrapperBase = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError