init
This commit is contained in:
0
vllm/lora/__init__.py
Normal file
0
vllm/lora/__init__.py
Normal file
262
vllm/lora/fully_sharded_layers.py
Normal file
262
vllm/lora/fully_sharded_layers.py
Normal file
@@ -0,0 +1,262 @@
|
||||
# pylint: disable=unused-argument
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.distributed.communication_op import (
|
||||
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
||||
MergedColumnParallelLinearWithLoRA,
|
||||
MergedQKVParallelLinearWithLora,
|
||||
RowParallelLinearWithLoRA)
|
||||
from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
def _fully_sharded_can_replace(can_replace):
    """
    Decorator which adds the condition of fully sharded loras.

    Intended to wrap can_replace_layer(): the wrapped predicate only
    returns True when the base predicate matches AND the LoRA config
    (passed as the ``lora_config`` keyword) enables fully sharded loras.
    """
    import functools

    # functools.wraps preserves the wrapped predicate's metadata
    # (__name__, __doc__) so introspection/debugging still sees the
    # original can_replace_layer instead of an anonymous "dec".
    @functools.wraps(can_replace)
    def dec(*args, **kwargs):
        return (can_replace(*args, **kwargs)
                and kwargs['lora_config'].fully_sharded_loras)

    return dec
|
||||
|
||||
|
||||
# these layers are based on the tensor parallelism strategy given in
|
||||
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
|
||||
# https://arxiv.org/abs/2311.03285.
|
||||
|
||||
|
||||
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by also slicing LoRA A.

    Following S-LoRA, the slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        # Each TP rank keeps one contiguous chunk of the rank dimension.
        shard_size = self.lora_a_stacked.shape[2]
        start = get_tensor_model_parallel_rank() * shard_size
        return lora_a[:, start:start + shard_size]

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Column-partitioned output of the base layer.
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        out_orig_shape = output.shape
        output = output.view(-1, output.shape[-1])
        # fp32 accumulator sized to this rank's shard of the LoRA rank dim.
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)

        bgmv(buffer, x, self.lora_a_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        # Gather every rank's shard of the intermediate before applying B.
        buffer = tensor_model_parallel_all_gather(buffer)
        bgmv(output, buffer, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        # output is now column partitioned, matching the base layer.

        return output.view(*out_orig_shape)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Keyword arguments so the decorator can read lora_config by name.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
|
||||
|
||||
|
||||
def _mcp_apply_weights(x, bias, layer):
    """
    Shared LoRA weight application for
    MergedColumnParallelLinearWithShardedLoRA and
    QKVParallelLinearWithShardedLora.

    The only difference between the two callers is the per-slice
    shard_size step for lora_b, which can vary for
    QKVParallelLinearWithShardedLora but is constant for
    MergedColumnParallelLinearWithShardedLoRA.
    """
    # Number of packed sub-modules: 2 for column parallel, 3 for qkv.
    n = len(layer.lora_a_stacked)
    output = layer.base_layer.linear_method.apply_weights(
        layer.base_layer, x, bias)

    x = x.view(-1, x.shape[-1])
    out_orig_shape = output.shape
    output = output.view(-1, output.shape[-1])
    # One fp32 buffer per packed sub-module, sharded along the rank dim.
    buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
                          dtype=torch.float32,
                          device=x.device)
    for slot in range(n):
        bgmv(buffers[slot], x, layer.lora_a_stacked[slot],
             layer.indices[:layer.indices_len[0]], 0, 1.0)

    # Collect every rank's shard of the intermediates before applying B.
    buffers = tensor_model_parallel_all_gather(buffers)
    offset = 0
    for slot in range(n):
        shard_size = layer.lora_b_stacked[slot].shape[2]
        dispatch_bgmv_low_level(output, buffers[slot],
                                layer.lora_b_stacked[slot],
                                layer.indices[:layer.indices_len[0]], 0, 1.0,
                                offset, shard_size)
        offset += shard_size

    # output is now column partitioned and packed.
    return output.view(*out_orig_shape)
|
||||
|
||||
|
||||
class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by also slicing
    the LoRA A's.

    Following S-LoRA, the slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
        # Both packed sub-modules share the same rank shard size.
        shard_size = self.lora_a_stacked[0].shape[2]
        start = self.tp_rank * shard_size
        return [tensor[:, start:start + shard_size] for tensor in lora_a[:2]]

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Weight application is shared with the sharded QKV layer.
        return _mcp_apply_weights(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Keyword arguments so the decorator can read lora_config by name.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
|
||||
|
||||
|
||||
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Differs from QKVParallelLinearWithLora by also slicing the LoRA A's.

    Following S-LoRA, the slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]:
        # q/k/v may each carry a different rank shard size; a missing
        # sub-lora stays None.
        sizes = [stacked.shape[2] for stacked in self.lora_a_stacked]
        starts = [self.tp_rank * size for size in sizes]
        return [
            None if tensor is None else
            tensor[:, starts[i]:starts[i] + sizes[i]]
            for i, tensor in enumerate(lora_a[:3])
        ]

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        # Weight application is shared with the sharded merged-column layer.
        return _mcp_apply_weights(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Keyword arguments so the decorator can read lora_config by name.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
|
||||
|
||||
|
||||
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by also slicing the LoRA B's.

    Following S-LoRA, the slicing happens along the output dim.
    This yields a combined partial sum from the row parallel base
    layer and a column partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        # Keep this rank's slice of the output dimension.
        shard_size = self.lora_b_stacked.shape[2]
        start = self.tp_rank * shard_size
        return lora_b[:, start:start + shard_size]

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        output = self.base_layer.linear_method.apply_weights(
            self.base_layer, x)

        x = x.view(-1, x.shape[-1])
        out_orig_shape = output.shape
        output = output.view(-1, output.shape[-1])
        # fp32 accumulator for this rank's partial LoRA-A product.
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)
        bgmv(buffer, x, self.lora_a_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        buffer = tensor_model_parallel_all_reduce(buffer)

        # Following S-LoRA, this fuses the all_gather and all_reduce by
        # adding the column partitioned lora output to a slice of the
        # output tensor, which is a partial sum due to row parallelism.
        # All that remains is a standard all_reduce. Callers should be
        # aware that the result is not the same as a normal row_parallel
        # output: it must be reduced before being used.
        shard_size = self.lora_b_stacked.shape[2]
        start = self.tp_rank * shard_size
        dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked,
                                self.indices[:self.indices_len[0]], 0, 1.0,
                                start, shard_size)

        return output.view(*out_orig_shape)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # Keyword arguments so the decorator can read lora_config by name.
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
|
||||
1181
vllm/lora/layers.py
Normal file
1181
vllm/lora/layers.py
Normal file
File diff suppressed because it is too large
Load Diff
167
vllm/lora/lora.py
Normal file
167
vllm/lora/lora.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
class LoRALayerWeights:
    """LoRA weights for a layer composed of two low rank matrixes."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alpha: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor] = None,
        scaling: Optional[float] = None,
    ) -> None:
        self.module_name = module_name
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.lora_a = lora_a
        self.lora_b = lora_b
        self.embeddings_tensor = embeddings_tensor
        # Standard LoRA scaling (alpha / rank) unless given explicitly.
        self.scaling = (self.lora_alpha /
                        self.rank) if scaling is None else scaling

    def optimize(self) -> "LoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        if self.scaling != 1:
            self.lora_b *= self.scaling
            self.scaling = 1
        return self

    @property
    def input_dim(self) -> int:
        return self.lora_a.shape[0]

    @property
    def output_dim(self) -> int:
        return self.lora_b.shape[1]

    @property
    def is_packed(self) -> bool:
        return False

    @property
    def extra_vocab_size(self) -> int:
        # Rows of the extra-embeddings tensor, if one is attached.
        if self.embeddings_tensor is None:
            return 0
        return self.embeddings_tensor.shape[0]

    @classmethod
    def create_dummy_lora_weights(
            cls,
            module_name: str,
            input_dim: int,
            output_dim: int,
            rank: int,
            dtype: torch.dtype,
            device: torch.device,
            embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
        """Build zero-filled weights of the requested shapes (for warmup)."""
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        lora_a = torch.zeros([input_dim, rank],
                             dtype=dtype,
                             device=device,
                             pin_memory=pin_memory)
        lora_b = torch.zeros([rank, output_dim],
                             dtype=dtype,
                             device=device,
                             pin_memory=pin_memory)
        if embeddings_tensor_dim:
            embeddings_tensor = torch.rand(10,
                                           embeddings_tensor_dim,
                                           dtype=dtype,
                                           device=device,
                                           pin_memory=pin_memory)
        else:
            embeddings_tensor = None
        return cls(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=lora_a,
            lora_b=lora_b,
            embeddings_tensor=embeddings_tensor,
        )
|
||||
|
||||
|
||||
class PackedLoRALayerWeights(LoRALayerWeights):
    """LoRA used for packed layers (eg. qkv_proj)."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alphas: List[Optional[int]],
        lora_a: List[Optional[torch.Tensor]],
        lora_b: List[Optional[torch.Tensor]],
        scaling: Optional[List[float]] = None,
    ) -> None:
        super().__init__(
            module_name=module_name,
            rank=rank,
            lora_alpha=0,
            lora_a=lora_a,
            lora_b=lora_b,
            scaling=scaling,  # type: ignore
            embeddings_tensor=None,
        )
        self.lora_alphas = lora_alphas
        if scaling is None:
            # One scaling factor per packed sub-lora.
            self.scaling = [  # type: ignore
                alpha / self.rank  # type: ignore # noqa
                for alpha in self.lora_alphas
            ]

    @classmethod
    def pack(
        cls, loras: List[Optional["LoRALayerWeights"]]
    ) -> "PackedLoRALayerWeights":
        """Pack a list of LoRAs into a single LoRA.

        If LoRA is None, it signifies that the submodule does not have a LoRA.
        """
        first_lora = next(lora for lora in loras if lora is not None)
        for lora in loras:
            if lora is not None:
                lora.optimize()
        # rank/module_name are taken from the first present sub-lora; after
        # optimize() every present sub-lora has scaling 1.
        return cls(
            first_lora.module_name,
            first_lora.rank,
            [None if lora is None else lora.lora_alpha for lora in loras],
            [None if lora is None else lora.lora_a for lora in loras],
            [None if lora is None else lora.lora_b for lora in loras],
            scaling=[
                None if lora is None else 1  # type: ignore
                for lora in loras
            ])

    def optimize(self) -> "PackedLoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        for i in range(len(self.lora_b)):
            if self.lora_b[i] is None or self.scaling[i] == 1:  # type: ignore
                continue
            self.lora_b[i] *= self.scaling[i]  # type: ignore
            self.scaling[i] = 1  # type: ignore
        return self

    @property
    def input_dim(self) -> int:
        raise NotImplementedError()

    @property
    def output_dim(self) -> int:
        raise NotImplementedError()

    @property
    def is_packed(self) -> bool:
        return True
|
||||
645
vllm/lora/models.py
Normal file
645
vllm/lora/models.py
Normal file
@@ -0,0 +1,645 @@
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
|
||||
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
||||
parse_fine_tuned_lora_name, replace_submodule)
|
||||
from vllm.utils import LRUCache, is_pin_memory_available
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_GLOBAL_LORA_ID = 0
|
||||
|
||||
|
||||
def convert_mapping(
    mapping: "LoRAMapping", lora_index_to_id: List[Optional[int]],
    max_loras: int, vocab_size: int, extra_vocab_size: int,
    device: str = "cuda"
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List mapping LoRA ids to LoRA indices.
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.
        device: Device the index tensors are created on. Defaults to
            "cuda" (previously hard-coded), so existing callers are
            unaffected; passing "cpu" makes the function usable and
            testable without a GPU.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices.
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for sampler. For generation, this will be the
                same as base_indicies. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] mapping
                requests to LoRA indices for sampler with padding.
                Same as sampler_indicies, but -1 is replaced with
                max_loras.
            embeddings_indices: Tensor of shape [2, batch_size] mapping
                requests to embedding indices. First row is for embeddings
                added by the LoRAs, second row is for the LoRA.lora_a
                embeddings.
            indices_len: List of lengths of the above tensors.
    """
    index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
    embedding_indices = index_mapping_indices.copy()
    lora_indices = index_mapping_indices.copy()
    # id <= 0 means "no LoRA" for this request.
    prompt_mapping: List[int] = [
        lora_index_to_id.index(x) if x > 0 else -1
        for x in mapping.prompt_mapping
    ]
    for i in range(len(index_mapping_indices)):
        # TODO index can be slow. optimize
        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
                    if index_mapping_indices[i] > 0 else -1)
        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
        index_mapping_indices[i] = i
        lora_indices[i] = lora_idx

    indices = torch.tensor(
        [index_mapping_indices, lora_indices, embedding_indices],
        dtype=torch.long,
        device=device)
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         device=device,
                                         dtype=torch.long)
    # Row 0: offsets into the per-LoRA extra embeddings; row 1: offsets
    # into the lora_a embeddings.
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size)
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    # Replace "no LoRA" (-1) with the padding slot, then flatten to
    # per-request offsets for the padded sampler layout.
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    sampler_indices_padded = (
        torch.arange(
            0, len(sampler_indices_padded), device=device, dtype=torch.long) +
        (sampler_indices_padded * len(sampler_indices_padded)))
    indices_len = [
        base_indices.shape[-1], sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
    ]

    return (base_indices, sampler_indices, sampler_indices_padded,
            embeddings_indices, indices_len)
|
||||
|
||||
|
||||
def get_lora_id():
    """Hand out the next globally unique LoRA id (ids start at 1)."""
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID += 1
    return _GLOBAL_LORA_ID
|
||||
|
||||
|
||||
class LoRAModel:
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: Dict[str, LoRALayerWeights],
    ) -> None:
        self.id = lora_model_id
        # Id 0 (and below) is reserved for "no LoRA".
        assert (lora_model_id >
                0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: Dict[str, LoRALayerWeights] = loras

    @property
    def extra_vocab_size(self) -> int:
        # Largest extra vocab across all per-module LoRAs; 0 when empty.
        if not self.loras:
            return 0
        return max(lora.extra_vocab_size for lora in self.loras.values())

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        rank: int,
        lora_alpha: int,
        tensors: Dict[str, torch.Tensor],
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[Dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors."""
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: Dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
            if module_name not in loras:
                # First tensor for this module: create the holder and
                # attach any extra input embeddings belonging to it.
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights(module_name, rank,
                                                      lora_alpha, None, None,
                                                      lora_embeddings_tensor)
            if is_lora_a:
                loras[module_name].lora_a = tensor.to(device=device,
                                                      dtype=dtype).t()
                if pin_memory:
                    loras[module_name].lora_a = loras[
                        module_name].lora_a.pin_memory()
            else:
                loras[module_name].lora_b = tensor.to(device=device,
                                                      dtype=dtype).t()
                assert embedding_padding_modules is not None
                # Pad lm_head-style modules out to the padded vocab size.
                if (target_embedding_padding is not None
                        and any(name in module_name
                                for name in embedding_padding_modules)):
                    lora_b = loras[module_name].lora_b
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    loras[module_name].lora_b = torch.nn.functional.pad(
                        lora_b, (0, addition))
                if pin_memory:
                    loras[module_name].lora_b = loras[
                        module_name].lora_b.pin_memory()

        for lora in loras.values():
            lora.optimize()
        return cls(lora_model_id, rank, loras)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: List[str],
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint."""
        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
        with open(lora_config_path) as f:
            config = json.load(f)

        # The loaded lora's target modules must be a subset of
        # expected_lora_modules. Comparing only the last path component
        # keeps this compatible with names like layers.11.self_attn.k_proj.
        unexpected_modules = [
            module for module in config["target_modules"]
            if module.split(".")[-1] not in expected_lora_modules
        ]
        if unexpected_modules:
            raise ValueError(
                f"While loading {lora_dir}, expected"
                f" target modules in {expected_lora_modules}"
                f" but received {unexpected_modules}."
                f" Please verify that the loaded LoRA module is correct")

        if os.path.isfile(lora_tensor_path):
            tensors = safetensors.torch.load_file(lora_tensor_path)
        elif os.path.isfile(lora_bin_file_path):
            tensors = torch.load(lora_bin_file_path)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(new_embeddings_bin_file_path)

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            rank=config["r"],
            lora_alpha=config["lora_alpha"],
            tensors=tensors,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
        )
|
||||
|
||||
|
||||
class LoRAModelManager:
|
||||
"""A manager that manages multiple LoRA-fine-tuned models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
max_num_seqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
vocab_size: int,
|
||||
lora_config: LoRAConfig,
|
||||
):
|
||||
"""Create a LoRAModelManager and adapter for a given model.
|
||||
|
||||
Args:
|
||||
model: the model to be adapted.
|
||||
max_num_seqs: the maximum number of sequences model can run in a
|
||||
single batch.
|
||||
max_num_batched_tokens: the maximum number of tokens model can run
|
||||
in a single batch.
|
||||
vocab_size: the vocab size of the model.
|
||||
lora_config: the LoRA configuration.
|
||||
"""
|
||||
self.lora_config = lora_config
|
||||
self.max_num_seqs = max_num_seqs
|
||||
assert self.capacity >= self.lora_slots
|
||||
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
|
||||
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
|
||||
self.vocab_size = vocab_size
|
||||
self.base_indices = torch.empty(self.max_num_batched_tokens,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
self.sampler_indices = torch.empty(self.max_num_batched_tokens,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
self.embeddings_indices = torch.empty(2,
|
||||
self.max_num_batched_tokens,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
# 4 is the number of indicies tensors defined above
|
||||
# base_indices, sampler_indices, sampler_indices_padded,
|
||||
# embeddings_indices
|
||||
self.indices_len: List[Optional[int]] = [None] * 4
|
||||
|
||||
self.model: nn.Module = model
|
||||
if hasattr(self.model, "supported_lora_modules"):
|
||||
self.supported_lora_modules = copy.deepcopy(
|
||||
self.model.supported_lora_modules)
|
||||
self.packed_modules_mapping = copy.deepcopy(
|
||||
self.model.packed_modules_mapping)
|
||||
self.packed_modules: Dict[str, List[str]] = {}
|
||||
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
|
||||
self._registered_loras: Dict[int, LoRAModel] = {}
|
||||
# Dict instead of a Set for compatibility with LRUCache.
|
||||
self._active_loras: Dict[int, None] = {}
|
||||
self._last_mapping: Optional[LoRAMapping] = None
|
||||
self._create_lora_modules()
|
||||
self.model.lora_manager = self
|
||||
|
||||
@property
|
||||
def capacity(self) -> int:
|
||||
return self.lora_config.max_cpu_loras
|
||||
|
||||
@property
|
||||
def lora_slots(self) -> int:
|
||||
return self.lora_config.max_loras
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._registered_loras)
|
||||
|
||||
def activate_lora(
|
||||
self,
|
||||
lora_id: int,
|
||||
) -> bool:
|
||||
"""Move LoRA into a GPU buffer to be used in the forward pass."""
|
||||
if lora_id in self._active_loras:
|
||||
return False
|
||||
first_free_slot = next(
|
||||
((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
|
||||
if lora_id is None), None)
|
||||
if first_free_slot is None:
|
||||
raise ValueError("No free lora slots")
|
||||
index, _ = first_free_slot
|
||||
self._active_loras[lora_id] = None
|
||||
lora_model = self._registered_loras[lora_id]
|
||||
logger.debug("Activating LoRA. int id: %d, slot index: %d",
|
||||
lora_model.id, index)
|
||||
self.lora_index_to_id[index] = lora_model.id
|
||||
for module_name, module in self.modules.items():
|
||||
module_lora = lora_model.get_lora(module_name)
|
||||
if module_lora:
|
||||
module_lora.optimize()
|
||||
module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
|
||||
module_lora.embeddings_tensor)
|
||||
else:
|
||||
module.reset_lora(index)
|
||||
return True
|
||||
|
||||
def _deactivate_lora(self, lora_id: int):
|
||||
try:
|
||||
index = self.lora_index_to_id.index(lora_id)
|
||||
self.lora_index_to_id[index] = None
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def deactivate_lora(self, lora_id: int) -> bool:
|
||||
"""Remove a LoRA from a GPU buffer."""
|
||||
if lora_id in self._active_loras:
|
||||
self._deactivate_lora(lora_id)
|
||||
self._active_loras.pop(lora_id)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _add_lora(self, lora: LoRAModel):
|
||||
self._create_merged_loras_inplace(lora)
|
||||
self._registered_loras[lora.id] = lora
|
||||
|
||||
def add_lora(self, lora: LoRAModel) -> bool:
|
||||
"""Add a LoRAModel to the manager CPU cache."""
|
||||
if lora.id not in self._registered_loras:
|
||||
if len(self._registered_loras) >= self.capacity:
|
||||
raise RuntimeError("No free LoRA slots.")
|
||||
self._add_lora(lora)
|
||||
return True
|
||||
return False
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
"""Remove a LoRAModel from the manager CPU cache."""
|
||||
# TODO: should we check active lora?
|
||||
self.deactivate_lora(lora_id)
|
||||
return bool(self._registered_loras.pop(lora_id, None))
|
||||
|
||||
# TODO see if this can be vectorized
|
||||
def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
|
||||
(base_indices, sampler_indices, sampler_indices_padded,
|
||||
embeddings_indices,
|
||||
indices_len) = convert_mapping(mapping, self.lora_index_to_id,
|
||||
self.lora_slots + 1, self.vocab_size,
|
||||
self.lora_config.lora_extra_vocab_size)
|
||||
self.base_indices[:base_indices.shape[0]].copy_(base_indices)
|
||||
self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
|
||||
self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
|
||||
sampler_indices_padded)
|
||||
self.embeddings_indices[:embeddings_indices.
|
||||
shape[0], :embeddings_indices.shape[1]].copy_(
|
||||
embeddings_indices)
|
||||
# Maintain the reference
|
||||
self.indices_len[:] = indices_len
|
||||
|
||||
def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
|
||||
if self._last_mapping != lora_mapping:
|
||||
self._set_lora_mapping(lora_mapping)
|
||||
self._last_mapping = lora_mapping
|
||||
|
||||
def list_loras(self) -> Dict[int, LoRAModel]:
|
||||
"""List all registered LoRAModels."""
|
||||
return dict(self._registered_loras)
|
||||
|
||||
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
|
||||
return self._registered_loras.get(lora_id, None)
|
||||
|
||||
def remove_all_loras(self):
    """Remove all LoRAModels from the manager.

    Clears both the CPU registry and the active (GPU slot) set, and
    frees every slot. NOTE(review): statement order is kept as-is —
    subclasses back these containers with an LRU cache whose removal
    hooks may fire during clear(); confirm before reordering.
    """
    self._registered_loras.clear()
    # Mark every GPU LoRA slot as free.
    self.lora_index_to_id = [None] * self.lora_slots
    self._active_loras.clear()
||||
def _create_lora_modules(self):
    """Replace every supported submodule of the model with its
    LoRA-wrapped equivalent.

    Each matching submodule is swapped in place via replace_submodule()
    and then wired to the manager's shared index tensors through
    set_mapping().
    """
    for module_name, module in self.model.named_modules():
        if not self._match_target_modules(module_name):
            continue
        # NOTE(review): despite the name, `parts` is the last dotted
        # component of the module name (a str), not a list of parts.
        parts = module_name.split(".")[-1]
        packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
        new_module = replace_submodule(
            self.model, module_name,
            from_layer(module, self.lora_slots, self.lora_config,
                       packed_moduled_lst, self.model.config))
        # (yard1): TODO make this more robust
        if "lm_head" in module_name:
            # Also wrap the logits processor so LoRA is applied when
            # logits are computed from the (LoRA-wrapped) lm_head.
            logits_processor_module = self.model.get_submodule(
                "logits_processor")
            new_module = replace_submodule(
                self.model, "logits_processor",
                from_layer_logits_processor(logits_processor_module,
                                            module, self.lora_slots,
                                            self.lora_config,
                                            self.model.config))
        self.register_module(module_name, new_module)
        self._register_packed_modules(module_name)
        # The layers keep references to these tensors;
        # _set_lora_mapping() later updates them in place.
        new_module.set_mapping(self.base_indices, self.sampler_indices,
                               self.sampler_indices_padded,
                               self.embeddings_indices, self.indices_len)
||||
def register_module(self, module_name: str,
                    module: "BaseLayerWithLoRA") -> None:
    """Record a LoRA-wrapped submodule under its dotted module name."""
    assert isinstance(module, BaseLayerWithLoRA)
    self.modules[module_name] = module
||||
def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
    """Create zero-initialized LoRAModel for warmup.

    Builds dummy (all-zero) weights for every LoRA-wrapped submodule so
    memory profiling/warmup sees worst-case shapes without loading a
    real adapter.

    Args:
        lora_id: id assigned to the dummy adapter.
        rank: LoRA rank to allocate for every module.
        embedding_modules: names of embedding-style modules that need
            the vocab-sized treatment; required once any non-packed
            module is reached (see assert below).
    """
    model = LoRAModel(lora_id, rank, {})
    for module_name, module in self.model.named_modules():
        if not self._match_target_modules(module_name) or not isinstance(
                module, BaseLayerWithLoRA):
            continue
        parts = module_name.split(".")
        if module_name not in self.packed_modules:
            assert embedding_modules is not None
            if parts[-1] in embedding_modules:
                # Embedding-like module: input dim is the (possibly
                # LoRA-extended) vocab size when the layer exposes
                # org_vocab_size, otherwise fall back to weight shape.
                input_dim = (module.base_layer.org_vocab_size +
                             self.lora_config.lora_extra_vocab_size if
                             hasattr(module.base_layer, "org_vocab_size")
                             else module.base_layer.weight.shape[1])
                output_dim = module.base_layer.embedding_dim if hasattr(
                    module.base_layer,
                    "embedding_dim") else module.base_layer.weight.shape[0]
                # NOTE(review): falls back to weight.shape[1] here vs
                # shape[0] for output_dim — confirm this asymmetry is
                # intentional for lm_head-style layers.
                embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                         hasattr(module.base_layer,
                                                 "embedding_dim") else
                                         module.base_layer.weight.shape[1])
                lora = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    input_dim,
                    output_dim,
                    rank,
                    module.lora_a_stacked.dtype,
                    "cpu",
                    embeddings_tensor_dim=embeddings_tensor_dim)
            else:
                # Plain linear-like module: dims come from the stacked
                # LoRA buffers already allocated on the wrapped layer.
                lora = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    module.lora_a_stacked.shape[-1],
                    module.lora_b_stacked.shape[-2],
                    rank,
                    module.lora_a_stacked.dtype,
                    "cpu",
                )
            lora.optimize()
        else:
            # Packed module (e.g. qkv_proj): create one dummy sub-LoRA
            # per replacement and pack them into a single entry.
            parts = module_name.split(".")
            replacements = self.packed_modules_mapping[parts[-1]]
            subloras: List[Optional["LoRALayerWeights"]] = []
            for i, r in enumerate(replacements):
                lora = LoRALayerWeights.create_dummy_lora_weights(
                    module_name + "." + r,
                    module.lora_a_stacked[i].shape[-1],
                    module.lora_b_stacked[i].shape[-2],
                    rank,
                    module.lora_a_stacked[i].dtype,
                    "cpu",
                )
                lora.optimize()
                subloras.append(lora)
            lora = PackedLoRALayerWeights.pack(subloras)
        model.loras[module_name] = lora
    return model
||||
def _match_target_modules(self, module_name: str):
|
||||
return any(
|
||||
re.match(
|
||||
r".*\.{target_module}$".format(target_module=target_module),
|
||||
module_name) or target_module == module_name
|
||||
for target_module in self.supported_lora_modules)
|
||||
|
||||
def _register_packed_modules(self, module_full_name: str) -> None:
|
||||
parts = module_full_name.split(".")
|
||||
module_name = parts[-1]
|
||||
replacements = self.packed_modules_mapping.get(module_name, [])
|
||||
# When replacements is less than or equal to 1, it indicates that this
|
||||
# module is not a packed module.
|
||||
if len(replacements) <= 1:
|
||||
return
|
||||
prefix = ".".join(parts[:-1])
|
||||
self.packed_modules[module_full_name] = [
|
||||
prefix + "." + r if prefix else r for r in replacements
|
||||
]
|
||||
|
||||
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
    """Merge the per-submodule LoRAs of each packed module (e.g. the
    q/k/v parts of qkv_proj) into a single PackedLoRALayerWeights entry
    on *lora_model*, mutating it in place."""
    for module_name, new_module_names in self.packed_modules.items():
        replacement_loras: List[Optional[LoRALayerWeights]] = []
        has_replacement = False
        for r in new_module_names:
            lora = lora_model.get_lora(r)
            replacement_loras.append(lora)
            if lora:
                has_replacement = True
        # This adapter has no weights for any part of the packed
        # module — nothing to merge.
        if not has_replacement:
            continue
        # Normalize missing (falsy) entries to None so pack() receives
        # a uniform List[Optional[...]].
        for i in range(len(replacement_loras)):
            if replacement_loras[i]:
                continue
            replacement_loras[i] = None
        lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
            replacement_loras)
||||
class LoRALRUCache(LRUCache[LoRAModel]):
    """LRU cache of LoRAModels that deactivates an entry when it is
    evicted or removed."""

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
                                                                   bool]):
        super().__init__(capacity)
        # Invoked with the evicted LoRA's int id (the cache key).
        self.deactivate_lora_fn = deactivate_lora_fn

    def _on_remove(self, key: int, value: LoRAModel):
        # Hook called by LRUCache on eviction/removal: deactivate the
        # LoRA before letting the base class finish the removal.
        logger.debug("Removing LoRA. int id: %d", key)
        self.deactivate_lora_fn(key)
        return super()._on_remove(key, value)
||||
class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(
        self,
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
    ):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config)
        # Replace the base class's plain dicts with LRU caches:
        # evicting a registered LoRA deactivates it via deactivate_lora;
        # evicting an active one frees its GPU slot via _deactivate_lora.
        self._registered_loras: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_lora)
        self._active_loras: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_lora)

    def list_loras(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_loras.cache)

    def add_lora(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager.

        Returns True if newly added; an existing entry is only
        refreshed in the LRU order.
        """
        if lora.id not in self._registered_loras:
            self._add_lora(lora)
            was_added = True
        else:
            # We always touch to update the LRU cache order
            self._registered_loras.touch(lora.id)
            was_added = False
        return was_added

    def activate_lora(
        self,
        lora_id: int,
    ) -> bool:
        """Activate a LoRA, evicting the least recently used active
        LoRA first if all GPU slots are taken."""
        if lora_id not in self._active_loras and len(
                self._active_loras) >= self.lora_slots:
            self._active_loras.remove_oldest()
        result = super().activate_lora(lora_id)
        # We always touch to update the LRU cache order
        self._active_loras.touch(lora_id)
        return result

    def remove_oldest_lora(self) -> bool:
        """Evict the least recently used registered LoRA, if any.

        Returns True when an eviction happened.
        """
        if len(self._registered_loras) > 0:
            self._registered_loras.remove_oldest()
            return True
        return False
||||
def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model.

    The model must declare `supported_lora_modules`; extra kwargs are
    forwarded to *lora_manager_cls*.
    """
    if not hasattr(model, "supported_lora_modules"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    return lora_manager_cls(model=model,
                            max_num_seqs=max_num_seqs,
                            max_num_batched_tokens=max_num_batched_tokens,
                            vocab_size=vocab_size,
                            lora_config=lora_config,
                            **kwargs)
213
vllm/lora/punica.py
Normal file
213
vllm/lora/punica.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# Based on code from https://github.com/punica-ai/punica
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _raise_import_error(e):
    """Re-raise a punica-kernel ImportError with an actionable message,
    chained from the original exception *e*."""
    if torch.cuda.get_device_capability() >= (8, 0):
        # Capable GPU, so the kernels simply were not built/installed.
        raise ImportError(
            "punica LoRA kernels could not be imported. If you built vLLM "
            "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
            "was set.") from e
    raise ImportError(
        "punica LoRA kernels require compute capability >= 8.0") from e
||||
def bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
    scale: float,
):
    """
    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
        matrices.
      indicies: Shape: `[B]`. Indices of the weight matrices.
      layer_idx: Layer index of the weight matrices.
      scale: Scaling factor.
    """
    # Lazy import: the punica extension is optional; a missing build is
    # converted into an actionable error by _raise_import_error.
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)

    punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
|
||||
def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
                            w_t_all: torch.Tensor, indicies: torch.LongTensor,
                            layer_idx: int, scale: float, y_offset: int,
                            y_slice_size: int):
    """
    Same as `bgmv` but you can operate on slices of y.
    Pass whole y, define y_offset and y_slice_size.

    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
        all of the transposed LoRA matrices.
      indicies: Shape: `[B]`. Indices of the LoRA weights.
      layer_idx: Layer index of LoRA weights.
      scale: Scaling factor.
      y_offset: Offset to apply to the starting column of y.
      y_slice_size: Size of the y column slice.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)
    # NOTE(review): the trailing positional args are
    # (h_in, h_out, y_offset) — x.size(1) is the full input width;
    # order must match the C++ kernel signature, so keep as-is.
    punica_kernels.dispatch_bgmv_low_level(
        y,
        x,
        w_t_all,
        indicies,
        layer_idx,
        scale,
        x.size(1),
        y_slice_size,
        y_offset,
    )
||||
def add_lora(y: torch.Tensor,
             x: torch.Tensor,
             wa_t_all: torch.Tensor,
             wb_t_all: torch.Tensor,
             indicies: torch.LongTensor,
             layer_idx: int,
             scale: float,
             *,
             buffer: Optional[torch.Tensor] = None):
    """Accumulate a full LoRA (shrink + expand) into y in place.

    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
        LoRA A matrices.
      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
        LoRA B matrices.
      indicies: Shape: `[B]`. Indices of the LoRA weights.
      layer_idx: Layer index of LoRA weights.
      scale: Scaling factor, applied on the expand step only.
      buffer: Optional. Shape: `[B, R]`. Temporary buffer.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)

    if buffer is None:
        rank = wb_t_all.size(-1)
        # A float32 intermediate avoids numerical inaccuracies from
        # downcasting between the two matmuls.
        buffer = torch.zeros((x.size(0), rank),
                             dtype=torch.float32,
                             device=x.device)
    punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
    punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
                                 scale)
||||
def add_lora_slice(y: torch.Tensor,
                   x: torch.Tensor,
                   wa_t_all: torch.Tensor,
                   wb_t_all: torch.Tensor,
                   indicies: torch.LongTensor,
                   layer_idx: int,
                   scale: float,
                   y_offset: int,
                   y_slice_size: int,
                   *,
                   buffer: Optional[torch.Tensor] = None):
    """Same as `add_lora` but operating on a column slice of y: pass the
    whole y plus y_offset and y_slice_size.

    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
        LoRA A matrices.
      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
        LoRA B matrices.
      indicies: Shape: `[B]`. Indices of the LoRA weights.
      layer_idx: Layer index of LoRA weights.
      scale: Scaling factor, applied on the expand step only.
      y_offset: Offset to apply to the starting column of y.
      y_slice_size: Size of the y column slice.
      buffer: Optional. Shape: `[B, R]`. Temporary buffer.
    """
    try:
        import vllm._punica_C as punica_kernels
    except ImportError as e:
        _raise_import_error(e)

    if buffer is None:
        rank = wb_t_all.size(-1)
        # A float32 intermediate avoids numerical inaccuracies from
        # downcasting between the two matmuls.
        buffer = torch.zeros((x.size(0), rank),
                             dtype=torch.float32,
                             device=x.device)
    # Shrink: x @ A^T into the rank-sized buffer (no output offset).
    punica_kernels.dispatch_bgmv_low_level(buffer, x, wa_t_all, indicies,
                                           layer_idx, 1.0, x.size(1),
                                           buffer.size(1), 0)
    # Expand: buffer @ B^T into y[:, y_offset:y_offset + y_slice_size].
    punica_kernels.dispatch_bgmv_low_level(y, buffer, wb_t_all, indicies,
                                           layer_idx, scale, buffer.size(1),
                                           y_slice_size, y_offset)
32
vllm/lora/request.py
Normal file
32
vllm/lora/request.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
class LoRARequest:
    """
    Request for a LoRA adapter.

    Note that this class should only be used internally. For online
    serving it is recommended not to expose it directly, but to provide
    another layer of abstraction so users cannot request unauthorized
    LoRA adapters.

    lora_int_id must be globally unique for a given adapter.
    This is currently not enforced in vLLM.
    """

    lora_name: str
    lora_int_id: int
    lora_local_path: str

    def __post_init__(self):
        # Ids are 1-based; zero/negative ids are rejected eagerly.
        if self.lora_int_id >= 1:
            return
        raise ValueError(
            f"lora_int_id must be > 0, got {self.lora_int_id}")

    def __eq__(self, value: object) -> bool:
        # Identity is defined solely by the globally unique int id.
        if not isinstance(value, LoRARequest):
            return False
        return self.lora_int_id == value.lora_int_id

    def __hash__(self) -> int:
        return self.lora_int_id
||||
98
vllm/lora/utils.py
Normal file
98
vllm/lora/utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from typing import List, Optional, Set, Tuple, Type
|
||||
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.fully_sharded_layers import (
|
||||
ColumnParallelLinearWithShardedLoRA,
|
||||
MergedColumnParallelLinearWithShardedLoRA,
|
||||
MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA)
|
||||
# being imported for _all_lora_classes below
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
|
||||
LogitsProcessorWithLoRA,
|
||||
MergedColumnParallelLinearWithLoRA,
|
||||
MergedQKVParallelLinearWithLora,
|
||||
QKVParallelLinearWithLora,
|
||||
RowParallelLinearWithLoRA,
|
||||
VocabParallelEmbeddingWithLoRA)
|
||||
# yapf: enable
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Every LoRA layer implementation that from_layer() may instantiate.
# Candidates are probed via can_replace_layer(); as this is a set, the
# probe order is arbitrary, so the can_replace_layer predicates are
# expected to be mutually exclusive for any given source layer.
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora,
    MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA,
    LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA
}
||||
def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    """Return *layer* wrapped in the first matching LoRA layer class,
    or the layer unchanged when no registered class can replace it."""
    for lora_cls in _all_lora_classes:
        # Keyword arguments so decorators wrapping can_replace_layer()
        # can inspect them by name.
        if not lora_cls.can_replace_layer(
                source_layer=layer,
                lora_config=lora_config,
                packed_modules_list=packed_modules_list,
                model_config=model_config):
            continue
        wrapped = lora_cls(layer)
        wrapped.create_lora_weights(max_loras, lora_config, model_config)
        return wrapped
    return layer
||||
def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Wrap *layer* in a LogitsProcessorWithLoRA sized for lm_head's
    embedding dim and matching its weight dtype/device."""
    wrapped = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                      lm_head.weight.dtype,
                                      lm_head.weight.device)
    wrapped.create_lora_weights(max_loras, lora_config, model_config)
    return wrapped
||||
def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Replace the submodule at dotted path *module_name* with
    *new_module* and return it."""
    # rpartition yields ("", "", name) for a top-level name, and
    # get_submodule("") returns the model itself.
    parent_name, _, target_name = module_name.rpartition(".")
    parent = model.get_submodule(parent_name)
    setattr(parent, target_name, new_module)
    return new_module
||||
def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
    """Parse the name of lora weights.

    args:
        name: the name of the fine-tuned LoRA, e.g.
            base_model.model.dense1.weight
    return:
        Tuple(module_name, is_lora_a):
            module_name: the name of the module, e.g. model.dense1,
            is_lora_a: whether the tensor is lora_a or lora_b.
    """
    parts = name.split(".")
    assert parts[0] == "base_model"
    assert parts[1] == "model"
    last = parts[-1]
    if last == "weight":
        kind = parts[-2]
        assert kind in ("lora_A", "lora_B")
        return ".".join(parts[2:-2]), kind == "lora_A"
    if last in ("lora_embedding_A", "lora_embedding_B"):
        return ".".join(parts[2:-1]), last == "lora_embedding_A"
    raise ValueError(f"{name} is unsupported format")
||||
251
vllm/lora/worker_manager.py
Normal file
251
vllm/lora/worker_manager.py
Normal file
@@ -0,0 +1,251 @@
|
||||
from abc import ABC, abstractmethod, abstractproperty
|
||||
from typing import Any, Dict, List, Set, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.models import (LoRAModel, LoRAModelManager,
|
||||
LRUCacheLoRAModelManager, create_lora_manager)
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AbstractWorkerLoRAManager(ABC):
    """Abstract class for managing LoRA models on the worker side."""

    def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
                 vocab_size: int, lora_config: LoRAConfig,
                 device: torch.device):
        """Store sizing and configuration shared by all worker-side
        LoRA managers.

        Args:
            max_num_seqs: Max number of sequences in a batch.
            max_num_batched_tokens: Max number of tokens in a batch.
            vocab_size: Base model vocabulary size.
            lora_config: LoRA configuration (capacity, ranks, dtype).
            device: Device the LoRA weights will live on.
        """
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.device = device
        self.lora_config = lora_config

    # `abc.abstractproperty` is deprecated since Python 3.3; the
    # supported, equivalent spelling is @property over @abstractmethod.
    @property
    @abstractmethod
    def is_enabled(self) -> bool:
        """Whether LoRA support is active on this worker."""
        ...

    @abstractmethod
    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Wrap *model* with LoRA layers and return the wrapped model."""
        ...

    @abstractmethod
    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Load/activate the requested LoRAs and apply the mapping."""
        ...

    @abstractmethod
    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load a LoRA; return True if it was newly added."""
        ...

    @abstractmethod
    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a zero-initialized LoRA (for warmup/profiling)."""
        ...

    @abstractmethod
    def remove_lora(self, lora_id: int) -> bool:
        """Remove a LoRA; return True if it was present."""
        ...

    @abstractmethod
    def remove_all_loras(self):
        """Remove every loaded LoRA."""
        ...

    @abstractmethod
    def list_loras(self) -> Set[int]:
        """Return the ids of all loaded LoRAs."""
        ...
||||
class WorkerLoRAManager(AbstractWorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    # Subclasses may swap in a different model-manager implementation
    # (see LRUCacheWorkerLoRAManager).
    _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: Dict[str, str],
        embedding_padding_modules: List[str],
        lora_model_cls: Type[LoRAModel] = LoRAModel,
    ):
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        # Lazily initialized by create_lora_manager.
        self._lora_manager: LoRAModelManager
        super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
                         lora_config, device)

    @property
    def is_enabled(self) -> bool:
        # This manager class only exists when LoRA is configured.
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Wrap *model* with LoRA layers and keep the manager handle."""
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            lora_manager_cls=self._lora_manager_cls,
        )
        self._lora_manager = lora_manager
        return lora_manager.model

    def set_active_loras(self, lora_requests: Set[LoRARequest],
                         lora_mapping: LoRAMapping) -> None:
        """Sync loaded LoRAs to *lora_requests*, then apply the mapping."""
        self._apply_loras(lora_requests)
        self._lora_manager.set_lora_mapping(lora_mapping)

    def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
        """Load exactly the requested LoRAs: add missing, drop extras."""
        loras_that_exist = self.list_loras()
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._lora_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._lora_manager.lora_slots}).")

        new_loras = set(loras_map)
        loras_to_add = new_loras - loras_that_exist
        loras_to_remove = loras_that_exist - new_loras

        # Remove first to free slots before loading new adapters.
        for lora_id in loras_to_remove:
            self.remove_lora(lora_id)

        for lora_id in loras_to_add:
            self.add_lora(loras_map[lora_id])

    def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
        """Load a LoRA checkpoint from disk onto the CPU and validate it
        against the current LoRA configuration.

        Raises:
            RuntimeError: when loading the checkpoint fails.
            ValueError: when the adapter's rank or extra vocab size
                exceeds the configured limits.
        """
        try:
            model = self._lora_manager.model
            supported_lora_modules = model.supported_lora_modules
            packed_modules_mapping = model.packed_modules_mapping
            # Expand packed module names (e.g. qkv_proj) into the
            # individual module names the checkpoint is expected to use.
            expected_lora_modules = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)
            lora = self._lora_model_cls.from_local_checkpoint(
                lora_request.lora_local_path,
                expected_lora_modules,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
            )
        except Exception as e:
            raise RuntimeError(
                f"Loading lora {lora_request.lora_local_path} failed") from e
        if lora.rank > self.lora_config.max_lora_rank:
            raise ValueError(
                f"LoRA rank {lora.rank} is greater than max_lora_rank "
                f"{self.lora_config.max_lora_rank}.")
        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "
                             f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a zero-initialized LoRA for warmup/profiling."""
        if lora_request.lora_int_id in self.list_loras():
            return False
        return self._lora_manager.add_lora(
            self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
                                                 rank, self.embedding_modules))

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load, register, and activate the requested LoRA.

        Returns False without loading when the id is already present.
        """
        if lora_request.lora_int_id in self.list_loras():
            return False
        lora = self._load_lora(lora_request)
        loaded = self._lora_manager.add_lora(lora)
        self._lora_manager.activate_lora(lora.id)
        return loaded

    def remove_lora(self, lora_id: int) -> bool:
        return self._lora_manager.remove_lora(lora_id)

    def remove_all_loras(self):
        self._lora_manager.remove_all_loras()

    def list_loras(self) -> Set[int]:
        """Return the ids of all LoRAs currently known to the manager."""
        return set(self._lora_manager.list_loras())
||||
class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU Cache. Every request, the requested LoRAs will be loaded
    (unless they are already loaded) and least recently used LoRAs will
    be unloaded if the cache is above capacity."""

    _lora_manager_cls: Type[
        LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Wrap *model* using the LRU-cache model manager."""
        lora_manager = create_lora_manager(
            model,
            lora_manager_cls=self._lora_manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            max_num_batched_tokens=self.max_num_batched_tokens,
        )
        self._lora_manager = lora_manager
        return lora_manager.model

    def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
        """Ensure every requested LoRA is loaded/refreshed.

        Unlike the base class, nothing is removed eagerly here — the
        LRU cache evicts old entries on demand inside add_lora().
        """
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._lora_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._lora_manager.lora_slots}).")
        for lora in loras_map.values():
            self.add_lora(lora)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Load (or touch) a LoRA and activate it.

        Returns True when a new adapter was loaded and registered.
        """
        if lora_request.lora_int_id not in self.list_loras():
            # Remove before we load the new lora to save memory
            if len(self._lora_manager) + 1 > self._lora_manager.capacity:
                assert isinstance(self._lora_manager, LRUCacheLoRAModelManager)
                self._lora_manager.remove_oldest_lora()
            lora = self._load_lora(lora_request)
            loaded = self._lora_manager.add_lora(lora)
        else:
            # If the lora is already loaded, just touch it to
            # update its position in the caches
            loaded = self._lora_manager.get_lora(
                lora_request.lora_int_id) is not None
        self._lora_manager.activate_lora(lora_request.lora_int_id)
        return loaded
||||
Reference in New Issue
Block a user