Sync from v0.13
This commit is contained in:
66
vllm/lora/layers/base.py
Normal file
66
vllm/lora/layers/base.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.lora.punica_wrapper import PunicaWrapperBase
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
    """Interface for layers that can carry LoRA adapter weights.

    Concrete subclasses wrap a base model layer and implement the hooks
    below so that per-slot LoRA matrices can be allocated, filled,
    sliced for tensor parallelism, and reset.
    """

    def slice_lora_a(
        self, lora_a: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Return the shard of ``lora_a`` for this tensor-parallel rank.

        Stub: subclasses that shard weights override this.
        """
        ...

    def slice_lora_b(
        self, lora_b: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Return the shard of ``lora_b`` for this tensor-parallel rank.

        Stub: subclasses that shard weights override this.
        """
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Allocate LoRA weight buffers for up to ``max_loras`` adapters.

        Stub: subclasses override to create their weight storage.
        """
        ...

    def reset_lora(self, index: int):
        """Zero out the LoRA weights stored at slot ``index``.

        Stub: subclasses override.
        """
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Write ``lora_a``/``lora_b`` into the slot at ``index``.

        Stub: subclasses override.
        """
        ...

    def set_mapping(
        self,
        punica_wrapper,
    ):
        """Attach the punica wrapper used to apply LoRA at runtime."""
        # NOTE(review): the annotation names a TYPE_CHECKING-only import;
        # presumably it is not evaluated on the supported Python versions —
        # confirm before relying on calling this at runtime.
        self.punica_wrapper: PunicaWrapperBase = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError
|
||||
Reference in New Issue
Block a user