68 lines
1.8 KiB
Python
68 lines
1.8 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from transformers import PretrainedConfig
|
|
|
|
from vllm.config.lora import LoRAConfig
|
|
|
|
if TYPE_CHECKING:
|
|
from vllm.lora.punica_wrapper import PunicaWrapperBase
|
|
|
|
|
|
class BaseLayerWithLoRA(nn.Module):
    """Interface for ``nn.Module`` layer wrappers that support LoRA adapters.

    Subclasses implement the hooks below to manage per-slot LoRA weights
    and to declare which base layers they can replace. Note the split in
    default behavior visible here: the weight-management methods are
    ``...`` stubs (no-ops returning ``None`` when not overridden), while
    ``can_replace_layer`` raises ``NotImplementedError`` and therefore
    must be provided by every concrete subclass.
    """

    def slice_lora_a(
        self, lora_a: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int) -> None:
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: torch.Tensor | None,
    ) -> None:
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        punica_wrapper: "PunicaWrapperBase",
    ) -> None:
        # Store the wrapper that applies LoRA at runtime. The parameter
        # annotation is a string forward reference because
        # PunicaWrapperBase is imported only under TYPE_CHECKING; the
        # attribute annotation below is likewise not evaluated at runtime.
        self.punica_wrapper: PunicaWrapperBase = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError
|