This commit is contained in:
2025-08-13 19:46:19 +08:00
commit 5d2e7edf78
1232 changed files with 361215 additions and 0 deletions

0
lora/__init__.py Normal file
View File

View File

@@ -0,0 +1,355 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# pylint: disable=unused-argument
from typing import TYPE_CHECKING, Optional, Union, cast
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.distributed.communication_op import (
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA)
from vllm.platforms import current_platform
if TYPE_CHECKING:
pass
def _fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
return (can_replace(*args, **kwargs)
and kwargs["lora_config"].fully_sharded_loras)
return dec
def _mcp_apply(x, bias, layer: ColumnParallelLinearWithLoRA):
    """
    For `ColumnParallelLinearWithLoRA` or classes that inherit from
    `ColumnParallelLinearWithLoRA`, they share the same `apply` logic:
    shrink (lora_a GEMM) -> all-gather across TP ranks -> expand
    (lora_b GEMM) added onto the base layer's output.
    """
    # One lora_a / lora_b / output slice per packed sub-module.
    assert (layer.n_slices == len(layer.lora_a_stacked) == len(
        layer.lora_b_stacked) == len(layer.output_slices))
    if layer.lora_bias_stacked is not None:
        assert layer.n_slices == len(layer.lora_bias_stacked)

    # Base (possibly quantized) linear forward; the LoRA contribution is
    # accumulated onto this result below.
    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    # Flatten to 2D for the punica kernels; original shape restored at the end.
    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
    # Since communication is needed, the buffer is directly initialized as a
    # tensor rather than a tuple of tensor.
    buffers = torch.zeros(
        (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
        dtype=torch.float32,
        device=x.device,
    )
    shrunk_buffers: Optional[torch.Tensor] = layer.punica_wrapper.add_shrink(
        buffers, x, layer.lora_a_stacked, 1.0)
    # Platforms that cannot mutate tensors in place return the result
    # instead; pick it up explicitly there.
    if not current_platform.can_update_inplace():
        buffers = shrunk_buffers
    # Gather the rank-sharded low-rank activations from all TP ranks.
    buffers = tensor_model_parallel_all_gather(buffers)
    lora_output: Optional[torch.Tensor] = layer.punica_wrapper.add_expand(
        output,
        buffers,
        layer.lora_b_stacked,
        layer.lora_bias_stacked,
        layer.output_slices,
        offset_start=0,
        add_input=True)
    if not current_platform.can_update_inplace():
        output = lora_output
    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
    return output
# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.
    Based on S-LoRA, slicing happens along the rank dim.
    """

    # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`,
    # their `lora_a` and `lora_b` have different sharding patterns. After
    # completing the `lora_a` GEMM , a gather operation is performed.
    # Therefore, the sharding of `lora_a` only needs to correspond with the
    # gather operation.
    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        # Keep only this rank's chunk along the rank dimension.
        rank_chunk = self.lora_a_stacked[0].shape[2]
        begin = get_tensor_model_parallel_rank() * rank_chunk
        return lora_a[:, begin:begin + rank_chunk]

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Shared shrink -> gather -> expand path.
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.
    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: list[Union[torch.Tensor, None]]
    ) -> list[Union[torch.Tensor, None]]:
        # NOTE: lora_a contains 2 subloras, and each sublora could be None.
        chunk = self.lora_a_stacked[0].shape[2]
        begin = self.tp_rank * chunk
        return [
            None if sub is None else sub[:, begin:begin + chunk]
            for sub in lora_a
        ]

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Shared shrink -> gather -> expand path.
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
    """
    Differs from QKVParallelLinearWithLoRA by slicing the
    LoRA A's also.
    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        # Keep only this rank's chunk along the rank dimension.
        rank_chunk = self.lora_a_stacked[0].shape[2]
        begin = get_tensor_model_parallel_rank() * rank_chunk
        return lora_a[:, begin:begin + rank_chunk]

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Shared shrink -> gather -> expand path.
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: list,
                          model_config: Optional[PretrainedConfig]) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
    """
    Differs from MergedQKVParallelLinearWithLoRA by slicing the
    LoRA A's also.
    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: list[Union[torch.Tensor, None]]
    ) -> list[Union[torch.Tensor, None]]:
        # NOTE: lora_a contains 3 subloras, and each sublora could be None.
        sliced: list[Union[torch.Tensor, None]] = []
        for sub, stacked in zip(lora_a, self.lora_a_stacked):
            if sub is None:
                sliced.append(None)
            else:
                width = stacked.shape[2]
                begin = self.tp_rank * width
                sliced.append(sub[:, begin:begin + width])
        return sliced

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Shared shrink -> gather -> expand path.
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.
    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row parallel base
    layer and column partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        # Keep only this rank's shard of lora_b along the output dim.
        shard_size = self.lora_b_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = lora_b[:, start_idx:end_idx]
        return lora_b

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        # Bias may be absent; otherwise mirror lora_b's output-dim sharding.
        if bias is None:
            return bias
        self.lora_bias_stacked = cast(tuple[torch.Tensor, ...],
                                      self.lora_bias_stacked)
        shard_size = self.lora_bias_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        bias = bias[start_idx:end_idx]
        return bias

    def apply(self,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Base row-parallel forward produces a partial sum on each rank.
        output = self.base_layer.quant_method.apply(self.base_layer, x)

        # Flatten to 2D for the punica kernels; shape restored at the end.
        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros(
            (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
            dtype=torch.float32,
            device=x.device,
        )
        shrunk_buffer: Optional[torch.Tensor] = self.punica_wrapper.add_shrink(
            buffer, x, self.lora_a_stacked, 1.0)
        # Platforms that cannot mutate tensors in place return the result.
        if not current_platform.can_update_inplace():
            buffer = shrunk_buffer
        # Sum the partial low-rank activations across TP ranks.
        buffer = tensor_model_parallel_all_reduce(buffer)

        # following S-LoRA, allows the fusing of all_gather and all_reduce
        # by adding the column partitioned lora output to a slice of output
        # tensor, which is a partial sum due to row parallel. All that
        # remains is a standard all_reduce. User should be aware though that
        # the output is not the same as a normal row_parallel, it should be
        # reduced before being used
        # NOTE offset are based on the rank.
        shard_size = self.lora_b_stacked[0].shape[2]
        offset_start = self.tp_rank * shard_size
        lora_output: Optional[torch.Tensor] = self.punica_wrapper.add_expand(
            output,
            buffer,
            self.lora_b_stacked,
            self.lora_bias_stacked,
            self.output_slices,
            offset_start=offset_start,
            add_input=True,
        )
        if not current_platform.can_update_inplace():
            output = lora_output
        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )

1285
lora/layers.py Normal file

File diff suppressed because it is too large.

199
lora/lora.py Normal file
View File

@@ -0,0 +1,199 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence as GenericSequence
from typing import Optional
import torch
import torch.types
from vllm.lora.peft_helper import PEFTHelper
from vllm.utils import is_pin_memory_available
class LoRALayerWeights:
    """LoRA weights for a layer composed of two low rank matrixes."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alpha: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        embeddings_tensor: Optional[torch.Tensor] = None,
        scaling: Optional[float] = None,
    ) -> None:
        self.module_name = module_name
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.lora_a = lora_a
        self.lora_b = lora_b
        self.bias = bias
        self.embeddings_tensor = embeddings_tensor
        # Default scaling follows the standard LoRA formula alpha / r.
        self.scaling = (self.lora_alpha /
                        self.rank) if scaling is None else scaling

    def optimize(self) -> "LoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        if self.scaling != 1:
            self.lora_b *= self.scaling
            self.scaling = 1
        return self

    @property
    def input_dim(self) -> int:
        return self.lora_a.shape[0]

    @property
    def output_dim(self) -> int:
        return self.lora_b.shape[1]

    @property
    def is_packed(self) -> bool:
        return False

    @property
    def extra_vocab_size(self) -> int:
        # Extra rows carried by the embeddings tensor, if any.
        tensor = self.embeddings_tensor
        return 0 if tensor is None else tensor.shape[0]

    @classmethod
    def from_config(
        cls,
        module_name: str,
        peft_helper: PEFTHelper,
        embeddings_tensor: Optional[torch.Tensor] = None,
    ) -> "LoRALayerWeights":
        """Build a weight container from PEFT config (weights filled later)."""
        return cls(module_name,
                   rank=peft_helper.r,
                   lora_alpha=peft_helper.lora_alpha,
                   lora_a=None,
                   lora_b=None,
                   bias=None,
                   embeddings_tensor=embeddings_tensor,
                   scaling=peft_helper.vllm_lora_scaling_factor)

    @classmethod
    def create_dummy_lora_weights(
            cls,
            module_name: str,
            input_dim: int,
            output_dim: int,
            rank: int,
            dtype: torch.dtype,
            device: torch.types.Device,
            embeddings_tensor_dim: Optional[int] = None,
            bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
        """Create zero-filled LoRA weights (used for warmup)."""
        pin_memory = str(device) == "cpu" and is_pin_memory_available()

        def _zeros(shape: list[int]) -> torch.Tensor:
            # Shared allocator so dtype/device/pinning stay consistent.
            return torch.zeros(shape,
                               dtype=dtype,
                               device=device,
                               pin_memory=pin_memory)

        lora_a = _zeros([input_dim, rank])
        lora_b = _zeros([rank, output_dim])
        bias = _zeros([output_dim]) if bias_enabled else None
        if embeddings_tensor_dim:
            embeddings_tensor = torch.rand(10,
                                           embeddings_tensor_dim,
                                           dtype=dtype,
                                           device=device,
                                           pin_memory=pin_memory)
        else:
            embeddings_tensor = None
        return cls(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=lora_a,
            lora_b=lora_b,
            bias=bias,
            embeddings_tensor=embeddings_tensor,
        )
class PackedLoRALayerWeights(LoRALayerWeights):
    """LoRA used for packed layers (eg. qkv_proj)."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alphas: list[Optional[int]],
        lora_a: list[Optional[torch.Tensor]],
        lora_b: list[Optional[torch.Tensor]],
        bias: Optional[list[Optional[torch.Tensor]]] = None,
        scaling: Optional[list[float]] = None,
    ) -> None:
        super().__init__(
            module_name=module_name,
            rank=rank,
            lora_alpha=0,
            lora_a=lora_a,
            lora_b=lora_b,
            bias=bias,
            scaling=scaling,  # type: ignore
            embeddings_tensor=None,
        )
        self.lora_alphas = lora_alphas
        if scaling is None:
            # Per-sublora scaling: alpha_i / r for each packed sub-module.
            self.scaling = [  # type: ignore
                lora_alpha / self.rank  # type: ignore # noqa
                for lora_alpha in self.lora_alphas
            ]

    @classmethod
    def pack(
        cls, loras: GenericSequence[Optional["LoRALayerWeights"]]
    ) -> "PackedLoRALayerWeights":
        """Pack a list of LoRAs into a single LoRA.

        If LoRA is None, it signifies that the submodule does not have a LoRA.
        """
        first_lora = next(lora for lora in loras if lora is not None)
        # Merge each sublora's scaling into its lora_b up-front, so the
        # packed scaling list can be all ones / None.
        for lora in loras:
            if lora is not None:
                lora.optimize()

        def _collect(attr: str) -> list:
            return [
                getattr(lora, attr) if lora is not None else None
                for lora in loras
            ]

        return cls(
            first_lora.module_name,
            first_lora.rank,
            _collect("lora_alpha"),
            _collect("lora_a"),
            _collect("lora_b"),
            _collect("bias"),
            scaling=[
                1 if lora is not None else None  # type: ignore
                for lora in loras
            ])

    def optimize(self) -> "PackedLoRALayerWeights":
        """Optimize the LoRA by merging the scaling into lora_b."""
        for i, sub_b in enumerate(self.lora_b):
            if sub_b is None or self.scaling[i] == 1:  # type: ignore
                continue
            self.lora_b[i] *= self.scaling[i]  # type: ignore
            self.scaling[i] = 1  # type: ignore
        return self

    @property
    def input_dim(self) -> int:
        raise NotImplementedError()

    @property
    def output_dim(self) -> int:
        raise NotImplementedError()

    @property
    def is_packed(self) -> bool:
        return True

818
lora/models.py Normal file
View File

@@ -0,0 +1,818 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import os
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, Union
import regex as re
import safetensors.torch
import torch
from torch import nn
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
AdapterModelManager)
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
get_adapter, list_adapters,
remove_adapter, set_adapter_mapping)
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import (BaseLayerWithLoRA,
LinearScalingRotaryEmbeddingWithLoRA,
LoRAMapping)
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
get_supported_lora_modules,
is_regex_target_modules,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.model_executor.utils import get_packed_modules_mapping
from vllm.utils import is_pin_memory_available
logger = init_logger(__name__)

# Monotonic counter used by get_lora_id() to hand out unique adapter ids.
_GLOBAL_LORA_ID = 0
@dataclass
class LongContextLoRAContext:
    """Context for lora adapters that support long context."""
    # The scaling factors to support long context lora fine tuned models.
    scaling_factors: list[float]
    # dimension to apply rotary embedding.
    rot_dim: int
    # offsets to the sin_cos_cache for each lora_id loaded.
    # This value is dynamically modified.
    offsets_by_lora_id: dict[int, int] = field(default_factory=dict)
def get_lora_id():
    """Return a new unique (monotonically increasing) LoRA id."""
    global _GLOBAL_LORA_ID
    _GLOBAL_LORA_ID += 1
    return _GLOBAL_LORA_ID
class LoRAModel(AdapterModel):
    """A LoRA fine-tuned model: per-module LoRA weights plus metadata."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: dict[str, LoRALayerWeights],
        scaling_factor: Optional[float] = None,
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
            scaling_factor: Scaling factor to support long context lora model.
                None if the lora is not tuned for long context support.
        """
        self.id = lora_model_id
        # Scaling factor for long context lora model. None if it is not
        # fine tuned for the long context.
        self.scaling_factor = scaling_factor
        assert (
            lora_model_id
            > 0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors.
        NOTE(review): scaling_factor is not propagated to the clone —
        confirm that is intentional for long-context adapters.
        """
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    @property
    def extra_vocab_size(self) -> int:
        # Largest extra vocab size among all module LoRAs; 0 when empty.
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    def check_lora_name(self, lora_name: str) -> bool:
        """Return True if this adapter has weights for `lora_name`."""
        return lora_name in self.loras

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        tensors: dict[str, torch.Tensor],
        peft_helper: PEFTHelper,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[dict[str, str]] = None,
        embedding_padding_modules: Optional[list[str]] = None,
        weights_mapper: Optional[WeightsMapper] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors."""
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
                tensor_name, weights_mapper)
            if module_name not in loras:
                # First tensor seen for this module: create the container,
                # attaching the extra-embeddings tensor if one applies.
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights.from_config(
                    module_name, peft_helper, lora_embeddings_tensor)
            if is_bias:
                # Checkpoint tensors are transposed relative to vLLM's
                # layout. (Fixed: the tensor was previously converted and
                # assigned twice here.)
                bias = tensor.to(device=device, dtype=dtype).t()
                if pin_memory:
                    bias = bias.pin_memory()
                loras[module_name].bias = bias
            elif is_lora_a:
                loras[module_name].lora_a = tensor.to(device=device,
                                                      dtype=dtype).t()
                if pin_memory:
                    loras[module_name].lora_a = loras[
                        module_name].lora_a.pin_memory()
            else:
                loras[module_name].lora_b = tensor.to(device=device,
                                                      dtype=dtype).t()
                assert embedding_padding_modules is not None
                if any(name in module_name
                       for name in embedding_padding_modules
                       ) and target_embedding_padding is not None:
                    # Pad lora_b's output dim up to the padded vocab size.
                    lora_b = loras[module_name].lora_b
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    loras[module_name].lora_b = torch.nn.functional.pad(
                        lora_b, (0, addition))
                if pin_memory:
                    loras[module_name].lora_b = loras[
                        module_name].lora_b.pin_memory()
        for lora in loras.values():
            lora.optimize()
        return cls(lora_model_id,
                   peft_helper.r,
                   loras,
                   scaling_factor=peft_helper.vllm_long_context_scaling_factor)

    @classmethod
    def from_local_checkpoint(
            cls,
            lora_dir: str,
            expected_lora_modules: list[str],
            peft_helper: PEFTHelper,
            *,
            lora_model_id: Optional[int] = None,
            device: str = "cuda",
            dtype: Optional[torch.dtype] = None,
            target_embedding_padding: Optional[int] = None,
            embedding_modules: Optional[dict[str, str]] = None,
            embedding_padding_modules: Optional[list[str]] = None,
            weights_mapper: Optional[WeightsMapper] = None,
            tensorizer_config_dict: Optional[dict] = None) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            peft_helper: Loaded lora configuration information.
            lora_model_id: LoRA model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.

        Returns:
            Loaded LoRA Model.
        """
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
        tensors: dict[str, torch.Tensor] = {}
        unexpected_modules: list[Union[list[str], str]] = []

        def check_unexpected_modules(modules: dict):
            # Reject adapters targeting modules this model cannot replace.
            for lora_module in modules.keys():  # noqa
                module_name, _, _ = parse_fine_tuned_lora_name(
                    lora_module, weights_mapper)
                part_name = module_name.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module_name)
            if unexpected_modules:
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")

        if tensorizer_config_dict:
            from tensorizer import TensorDeserializer

            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            lora_tensor_path = os.path.join(tensorizer_config.tensorizer_dir,
                                            "adapter_model.tensors")
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            tensors = TensorDeserializer(
                lora_tensor_path,
                dtype=tensorizer_config.dtype,
                **tensorizer_args.deserializer_params)
            check_unexpected_modules(tensors)
        elif os.path.isfile(lora_tensor_path):
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it wont error and model will be trained with A, B
            # loraified. C wont exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            unexpected_modules = []
            with safetensors.safe_open(lora_tensor_path,
                                       framework="pt") as f:  # type: ignore
                # Load tensors if there are only expected modules.
                check_unexpected_modules(f)
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path):
            # When a bin file is provided, we rely on config to find unexpected
            # modules.
            unexpected_modules = []
            target_modules = peft_helper.target_modules
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                    peft_helper.target_modules, expected_lora_modules):
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")
            tensors = torch.load(lora_bin_file_path,
                                 map_location=device,
                                 weights_only=True)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        # Optional extra input/output embeddings shipped with the adapter.
        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            embeddings = torch.load(new_embeddings_bin_file_path,
                                    map_location=device,
                                    weights_only=True)

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            tensors=tensors,
            peft_helper=peft_helper,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
            weights_mapper=weights_mapper)
class LoRAModelManager(AdapterModelManager):
"""A manager that manages multiple LoRA-fine-tuned models."""
def __init__(
self,
model: SupportsLoRA,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
):
"""Create a LoRAModelManager and adapter for a given model.
Args:
model: the model to be adapted.
max_num_seqs: the maximum number of sequences model can run in a
single batch.
max_num_batched_tokens: the maximum number of tokens model can run
in a single batch.
vocab_size: the vocab size of the model.
lora_config: the LoRA configuration.
"""
self.lora_config = lora_config
self.device = device
self.max_num_seqs = max_num_seqs
assert self.capacity >= self.lora_slots
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: list[Optional[int]] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.long_lora_context: Optional[LongContextLoRAContext] = None
self.punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
max_loras=self.lora_config.max_loras)
# Scaling factor -> offset to the sin_cos_cache to it.
# Used for long context lora.
self.scaling_factor_to_offset: dict[float, int] = {}
super().__init__(model)
self.supported_lora_modules = get_supported_lora_modules(self.model)
assert self.supported_lora_modules, "No supported LoRA modules found in"
f" {self.model.__class__.__name__}."
if lora_config.long_lora_scaling_factors:
# We need to replace rotary emb layer to do batch computation
# for long lora.
self.supported_lora_modules.append("rotary_emb")
self.packed_modules_mapping = get_packed_modules_mapping(self.model)
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = (
supports_multimodal(self.model)
# In case the model only supports LoRA for
# text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping"))
self.is_pooling_model = is_pooling_model(self.model)
self.packed_modules: dict[str, list[str]] = {}
self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache.
self._last_mapping: Optional[LoRAMapping] = None
self._create_lora_modules()
self.model.lora_manager = self
self.adapter_type = 'LoRA'
    @property
    def capacity(self) -> int:
        # Total number of adapters that may be held (CPU-side cache bound).
        return self.lora_config.max_cpu_loras
    @property
    def lora_slots(self) -> int:
        # Number of GPU slots available for simultaneously active adapters.
        return self.lora_config.max_loras
    @property
    def adapter_slots(self) -> int:
        # Adapter-commons alias for lora_slots.
        return self.lora_slots
    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns False if the adapter is already active; True once its
        weights are copied into a free slot. Raises ValueError when no
        slot is free, or when the adapter carries a bias while
        lora_config.bias_enabled is off.
        """
        if lora_id in self._active_adapters:
            return False
        # Find the first empty slot in the slot table.
        first_free_slot = next(
            ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
             if lora_id is None), None)
        if first_free_slot is None:
            raise ValueError("No free lora slots")
        index, _ = first_free_slot
        self._active_adapters[lora_id] = None
        lora_model = self._registered_adapters[lora_id]
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = self._get_lora_layer_weights(lora_model, module_name)
            if module_lora:
                module_lora.optimize()
                # Bias is not explicitly enabled with the flag enable_lora_bias.
                bias = module_lora.bias
                if ((torch.is_tensor(bias) or
                     (isinstance(bias, Sequence) and any(b is not None
                                                         for b in bias)))
                        and not self.lora_config.bias_enabled):
                    module_lora.bias = None
                    raise ValueError(
                        f"Adapter bias cannot be used for {module_name}"
                        " without --enable-lora-bias.")
                module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                                module_lora.embeddings_tensor,
                                module_lora.bias)
            else:
                # Adapter has no weights for this module: clear the slot.
                module.reset_lora(index)
        return True
def _deactivate_adapter(self, lora_id: int):
try:
index = self.lora_index_to_id.index(lora_id)
self.lora_index_to_id[index] = None
except ValueError:
pass
    def _set_long_lora_context(self, lora: LoRAModel):
        """Record the sin/cos cache offset for a long-context LoRA."""
        if self.long_lora_context is None:
            return

        if lora.scaling_factor is None:
            return

        if (lora.scaling_factor not in self.scaling_factor_to_offset):
            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                             " has not been initialized.")

        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
        # NOTE(review): a falsy offset (0) is silently dropped by this
        # truthiness check — confirm offset 0 can never occur here.
        if offsets:
            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets
    def _add_adapter(self, lora: LoRAModel):
        """Register `lora`, merging packed-module sub-LoRAs in place first."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora
        self._set_long_lora_context(lora)
    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        # The base manager has no eviction cache, so pinning is meaningless.
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore
    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        # update lora states: push the new token->adapter mapping into the
        # punica wrapper so the kernels select the right slot per token.
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
            self.long_lora_context,
        )
def remove_all_adapters(self):
"""Remove all LoRAModels from the manager."""
self._registered_adapters.clear()
self.lora_index_to_id = [None] * self.lora_slots
self._active_adapters.clear()
    def _create_lora_modules(self):
        """Replace every supported submodule with its LoRA-capable wrapper."""
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_moduled_lst, self.model.config))
            # LinearScalingRotaryEmbeddingWithLoRA is used to handle
            # long context lora. Register relevant metadata.
            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLoRA):
                self.long_lora_context = LongContextLoRAContext(
                    new_module.scaling_factors, new_module.rotary_dim)
                self.scaling_factor_to_offset = \
                    new_module.scaling_factor_to_offset
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                # Also wrap the logits processor so lm_head LoRA applies
                # to logit computation.
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))
            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module,
                                                   BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
    """Record an already-wrapped LoRA layer in the manager's registry
    under its fully-qualified module name."""
    assert isinstance(module, BaseLayerWithLoRA)
    self.modules[module_name] = module
def create_dummy_lora(
        self,
        lora_id: int,
        rank: int,
        scaling_factor: Optional[float],
        embedding_modules: Optional[dict[str, str]] = None) -> LoRAModel:
    """Create zero-initialized LoRAModel for warmup.

    Builds dummy CPU LoRA weights for every LoRA-enabled module of
    ``self.model`` so a forward pass can be profiled without loading a
    real adapter.

    Args:
        lora_id: id assigned to the returned LoRAModel.
        rank: LoRA rank of the dummy weights.
        scaling_factor: optional (long-context) scaling factor stored on
            the LoRAModel.
        embedding_modules: leaf-name mapping of embedding modules;
            asserted non-None when a non-packed target module matches.
    """
    model = LoRAModel(lora_id, rank, {}, scaling_factor)
    for module_name, module in self.model.named_modules():
        bias_enabled = self.lora_config.bias_enabled
        # Only wrapped LoRA layers are populated; rotary-embedding
        # wrappers and unsupported multimodal modules are skipped.
        if (not self._match_target_modules(module_name)
                or not isinstance(module, BaseLayerWithLoRA)
                or isinstance(module, LinearScalingRotaryEmbeddingWithLoRA)
                or self._filter_unsupported_mm_module(module_name)):
            continue
        parts = module_name.split(".")
        if module_name not in self.packed_modules:
            assert embedding_modules is not None
            if parts[-1] in embedding_modules:
                # Embedding-style module: dims come from the vocab /
                # embedding sizes (falls back to the raw weight shape —
                # presumably [vocab, hidden]; TODO confirm).
                input_dim = (module.base_layer.org_vocab_size +
                             self.lora_config.lora_extra_vocab_size if
                             hasattr(module.base_layer, "org_vocab_size")
                             else module.base_layer.weight.shape[1])
                output_dim = module.base_layer.embedding_dim if hasattr(
                    module.base_layer,
                    "embedding_dim") else module.base_layer.weight.shape[0]
                embeddings_tensor_dim = (module.base_layer.embedding_dim if
                                         hasattr(module.base_layer,
                                                 "embedding_dim") else
                                         module.base_layer.weight.shape[1])
                lora = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    input_dim,
                    output_dim,
                    rank,
                    module.lora_a_stacked[0].dtype,
                    "cpu",
                    embeddings_tensor_dim=embeddings_tensor_dim,
                    bias_enabled=bias_enabled)
            else:
                # Linear-style module: dims read off the stacked LoRA
                # buffers allocated on the wrapped layer.
                lora = LoRALayerWeights.create_dummy_lora_weights(
                    module_name,
                    module.lora_a_stacked[0].shape[-1],
                    module.lora_b_stacked[0].shape[-2],
                    rank,
                    module.lora_a_stacked[0].dtype,
                    "cpu",
                    bias_enabled=bias_enabled,
                )
            lora.optimize()
        else:
            # Packed module (e.g. qkv_proj): one dummy sub-LoRA per
            # replacement, packed into a single entry.
            parts = module_name.split(".")
            replacements = self.packed_modules_mapping[parts[-1]]
            subloras: list[Optional[LoRALayerWeights]] = []
            for i, r in enumerate(replacements):
                lora = LoRALayerWeights.create_dummy_lora_weights(
                    module_name + "." + r,
                    module.lora_a_stacked[i].shape[-1],
                    module.lora_b_stacked[i].shape[-2],
                    rank,
                    module.lora_a_stacked[i].dtype,
                    "cpu",
                    bias_enabled=bias_enabled,
                )
                lora.optimize()
                subloras.append(lora)
            lora = PackedLoRALayerWeights.pack(subloras)
        model.loras[module_name] = lora
    return model
def _match_target_modules(self, module_name: str):
    """Return True when *module_name* equals, or is a dotted suffix of,
    one of ``self.supported_lora_modules``."""

    def _matches(target: str) -> bool:
        suffix_pattern = r".*\.{target_module}$".format(target_module=target)
        return bool(re.match(suffix_pattern,
                             module_name)) or target == module_name

    return any(_matches(target) for target in self.supported_lora_modules)
def _filter_unsupported_mm_module(self, module_name: str) -> bool:
    """
    Regarding multimodal models, vLLM currently only supports adding LoRA to
    language model. LoRA for other modules, such as the vision tower, will
    be filtered out.
    """
    # Non-multimodal models never filter anything.
    if not self.supports_mm:
        return False
    mapping: MultiModelKeys = self.model.get_mm_mapping()
    unsupported_prefixes = mapping.connector + mapping.tower_model
    return any(
        module_name.startswith(prefix) for prefix in unsupported_prefixes)
def _register_packed_modules(self, module_full_name: str) -> None:
    """If *module_full_name* is a packed module (e.g. qkv_proj), record
    the fully-qualified names of its constituent sub-modules in
    ``self.packed_modules``."""
    path = module_full_name.split(".")
    leaf = path[-1]
    sub_names = self.packed_modules_mapping.get(leaf, [])
    # When replacements is less than or equal to 1, it indicates that this
    # module is not a packed module.
    if len(sub_names) <= 1:
        return
    parent = ".".join(path[:-1])
    self.packed_modules[module_full_name] = [
        f"{parent}.{name}" if parent else name for name in sub_names
    ]
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
    """Fuse the per-sub-module LoRA weights of every packed module
    (e.g. qkv_proj) in *lora_model* into a single packed entry.

    Sub-module entries that were merged are removed from
    ``lora_model.loras``; missing sub-modules are packed as ``None``.
    """
    for module_name, new_module_names in self.packed_modules.items():
        replacement_loras: list[Optional[LoRALayerWeights]] = []
        replaced_module: set[str] = set()
        has_replacement = False
        for r in new_module_names:
            lora = self._get_lora_layer_weights(lora_model, r)
            replacement_loras.append(lora)
            if lora:
                has_replacement = True
                replaced_module.add(r)
        if not has_replacement:
            # No sub-module of this packed module carries LoRA weights.
            continue
        # Normalize falsy entries to explicit None so pack() sees gaps.
        for i in range(len(replacement_loras)):
            if replacement_loras[i]:
                continue
            replacement_loras[i] = None
        # HACK Temporary solution for the pooling model: its LoRA weights
        # may be stored without the leading "model." prefix, so retarget
        # the packed entry at the stripped name when that name exists.
        if self.is_pooling_model and not lora_model.check_lora_name(
                module_name):
            replaced_module_name = module_name.replace("model.", "")
            # BUGFIX: check the stripped name. The previous code
            # re-checked `module_name`, which was already known to be
            # absent, so the rename could never take effect.
            if lora_model.check_lora_name(replaced_module_name):
                module_name = replaced_module_name
        lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
            replacement_loras)
        # Remove the modules that have been replaced.
        for module in replaced_module:
            lora_model.loras.pop(module, None)
def _get_lora_layer_weights(
        self, lora_model: LoRAModel,
        module_name: str) -> Optional[LoRALayerWeights]:
    """Look up *module_name*'s LoRA weights in *lora_model*, retrying
    without the leading 'model.' prefix for pooling models."""
    lookup_name = module_name
    if self.is_pooling_model and not lora_model.check_lora_name(lookup_name):
        # If it's a pool model, and the layer name is not found,
        # remove the prefix 'model.' and search again.
        stripped_name = module_name.replace("model.", "")
        if lora_model.check_lora_name(stripped_name):
            lookup_name = stripped_name
            logger.info_once(
                "For the pool model, successfully loaded the LoRA weights "
                "after removing the prefix 'model.'.")
    return lora_model.get_lora(lookup_name)
def deactivate_adapter(self, adapter_id: int) -> bool:
    """Deactivate *adapter_id*, delegating to the module-level
    ``deactivate_adapter`` helper with this manager's active-adapter
    table and slot-release callback."""
    return deactivate_adapter(adapter_id, self._active_adapters,
                              self._deactivate_adapter)
def add_adapter(self, adapter: LoRAModel) -> bool:
    """Register *adapter* with this manager, delegating capacity checks
    to the module-level ``add_adapter`` helper."""
    logger.debug(
        "Adding lora. Model id: %d, "
        "int id: %d, "
        "scaling factor: %s", adapter.id, adapter.id,
        adapter.scaling_factor)
    return add_adapter(adapter, self._registered_adapters, self.capacity,
                       self._add_adapter)
def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
    """Apply *mapping* via the module-level ``set_adapter_mapping``
    helper and cache its result as the last applied mapping."""
    self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                             self._set_adapter_mapping)
def remove_adapter(self, adapter_id: int) -> bool:
    """Remove *adapter_id* from the registered adapters, delegating to
    the module-level ``remove_adapter`` helper (which is given
    ``self.deactivate_adapter`` to release any active slot)."""
    return remove_adapter(adapter_id, self._registered_adapters,
                          self.deactivate_adapter)
def list_adapters(self) -> dict[int, Any]:
    """Return the registered adapters, keyed by adapter id."""
    return list_adapters(self._registered_adapters)
def get_adapter(self, adapter_id: int) -> Optional[Any]:
    """Return the registered adapter for *adapter_id*, or None."""
    return get_adapter(adapter_id, self._registered_adapters)
class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """LRU cache of ``LoRAModel``s. ``deactivate_lora_fn`` is handed to
    the ``AdapterLRUCache`` base — presumably invoked on eviction;
    confirm against the adapter_commons base class."""

    def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int],
                                                                   bool]):
        super().__init__(capacity, deactivate_lora_fn)
class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(self, model: nn.Module, max_num_seqs: int,
                 max_num_batched_tokens: int, vocab_size: int,
                 lora_config: LoRAConfig, device: torch.device):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config, device)
        # Registered adapters evict through deactivate_adapter (public,
        # full removal path); active ones through _deactivate_adapter
        # (slot release only).
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter)

    def list_adapters(self) -> dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager."""
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
        if lora.id in self._registered_adapters:
            # Already known: only refresh its LRU recency.
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate *lora_id*, evicting the least-recently-used active
        adapter first when all slots are taken."""
        if (lora_id not in self._active_adapters
                and len(self._active_adapters) >= self.lora_slots):
            self._active_adapters.remove_oldest()
        activated = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return activated

    def remove_oldest_adapter(self) -> bool:
        """Evict the least-recently-used registered adapter, if any."""
        if len(self._registered_adapters) == 0:
            return False
        self._registered_adapters.remove_oldest()
        return True

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        if lora_id not in self._active_adapters:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)
        self._active_adapters.pin(lora_id)
def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model.

    Raises:
        ValueError: if *model* lacks a ``packed_modules_mapping``
            attribute, i.e. it has not been prepared for LoRA.
    """
    if not hasattr(model, "packed_modules_mapping"):
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    return lora_manager_cls(model=model,
                            max_num_seqs=max_num_seqs,
                            max_num_batched_tokens=max_num_batched_tokens,
                            vocab_size=vocab_size,
                            lora_config=lora_config,
                            device=device,
                            **kwargs)

0
lora/ops/__init__.py Normal file
View File

View File

@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.ops.torch_ops.lora_ops import bgmv_expand # noqa: F401
from vllm.lora.ops.torch_ops.lora_ops import (bgmv_expand_slice, bgmv_shrink,
sgmv_expand, sgmv_expand_slice,
sgmv_shrink)
__all__ = [
"bgmv_expand",
"bgmv_expand_slice",
"bgmv_shrink",
"sgmv_expand",
"sgmv_expand_slice",
"sgmv_shrink",
]

View File

@@ -0,0 +1,119 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
def sgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                b_seq_start_loc: torch.Tensor,
                seq_len_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                batches: int,
                max_seq_length: int,
                token_nums: int,
                add_inputs: bool = False):
    """Sequence-grouped expand: broadcast each sequence's LoRA index to
    all of its tokens, then delegate to the per-token ``bgmv_expand``.

    The extra sgmv metadata (b_seq_start_loc, batches, max_seq_length,
    token_nums) is unused by this torch reference implementation; it is
    kept for kernel-signature parity.
    """
    per_token_indices = torch.repeat_interleave(lora_indices_tensor,
                                                seq_len_tensor)
    bgmv_expand(inputs, lora_b_weights, output_tensor, per_token_indices,
                add_inputs)
def bgmv_expand(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                add_inputs: bool = True):
    """Per-token LoRA expand: output[b] (+)= B[idx[b]] @ inputs[b].

    Args:
        inputs: [num_tokens, rank] shrunk activations.
        lora_b_weights: stacked LoRA B matrices, [num_lora, out, rank]
            or [num_lora, 1, out, rank].
        output_tensor: destination buffer; only the overlapping columns
            are written.
        lora_indices_tensor: [num_tokens] LoRA index per token.
        add_inputs: accumulate into output_tensor instead of overwriting.
    """
    per_token_b = lora_b_weights[lora_indices_tensor].to(
        dtype=output_tensor.dtype)
    # Stacked weights may carry a singleton slice dimension; drop it.
    if per_token_b.ndim == 4:
        per_token_b = per_token_b.squeeze(dim=1)
    cast_inputs = inputs.to(dtype=output_tensor.dtype)
    expanded = torch.einsum("bi, boi -> bo", cast_inputs, per_token_b)
    # A single produced row may broadcast across a larger output.
    if expanded.shape[0] == 1 and output_tensor.shape[0] != 1:
        row_limit = 1
    else:
        row_limit = output_tensor.shape[0]
    # LoRA adapter and model may add different amounts of padding to output
    width = min(expanded.shape[1], output_tensor.shape[1])
    if add_inputs:
        output_tensor[:, :width] += expanded[:row_limit, :width]
    else:
        output_tensor[:, :width] = expanded[:row_limit, :width]
def sgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
):
    """Sequence-grouped shrink: broadcast each sequence's LoRA index to
    all of its tokens, then delegate to the per-token ``bgmv_shrink``.

    The extra sgmv metadata (b_seq_start_loc, batches, max_seq_length,
    token_nums) is unused by this torch reference implementation; it is
    kept for kernel-signature parity.
    """
    per_token_indices = torch.repeat_interleave(lora_indices_tensor,
                                                seq_len_tensor)
    bgmv_shrink(inputs, lora_a_weights, output_tensor, per_token_indices,
                scaling)
def bgmv_shrink(inputs: torch.Tensor,
                lora_b_weights: torch.Tensor,
                output_tensor: torch.Tensor,
                lora_indices_tensor: torch.Tensor,
                scaling: float = 1.0):
    """Per-token LoRA shrink: output[b] = scaling * W[idx[b]] @ inputs[b].

    NOTE(review): despite the parameter name, callers pass the LoRA A
    (shrink) weights here — presumably kept to mirror the expand op's
    signature; confirm before renaming.
    """
    per_token_w = lora_b_weights[lora_indices_tensor].to(
        dtype=output_tensor.dtype)
    # Stacked weights may carry a singleton slice dimension; drop it.
    if per_token_w.ndim == 4:
        per_token_w = per_token_w.squeeze(dim=1)
    shrunk = torch.einsum("bi, boi -> bo",
                          inputs.to(dtype=output_tensor.dtype), per_token_w)
    output_tensor[:, :shrunk.shape[1]] = scaling * shrunk
def sgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      b_seq_start_loc: torch.Tensor,
                      seq_len_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      batches: int,
                      max_seq_length: int,
                      token_nums: int,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = False):
    """Sequence-grouped sliced expand: broadcast each sequence's LoRA
    index to all of its tokens, then delegate to ``bgmv_expand_slice``.

    The extra sgmv metadata (b_seq_start_loc, batches, max_seq_length,
    token_nums) is unused by this torch reference implementation; it is
    kept for kernel-signature parity.
    """
    per_token_indices = torch.repeat_interleave(lora_indices_tensor,
                                                seq_len_tensor)
    bgmv_expand_slice(inputs, lora_b_weights, output_tensor,
                      per_token_indices, slice_offset, slice_size,
                      add_inputs)
def bgmv_expand_slice(inputs: torch.Tensor,
                      lora_b_weights: torch.Tensor,
                      output_tensor: torch.Tensor,
                      lora_indices_tensor: torch.Tensor,
                      slice_offset: int,
                      slice_size: int,
                      add_inputs: bool = True):
    """Like ``bgmv_expand`` but writes only into the column window
    ``output_tensor[:, slice_offset:slice_offset + slice_size]``."""
    per_token_b = lora_b_weights[lora_indices_tensor].to(
        dtype=output_tensor.dtype)
    cast_inputs = inputs.to(dtype=output_tensor.dtype)
    # Stacked weights may carry a singleton slice dimension; drop it.
    if per_token_b.ndim == 4:
        per_token_b = per_token_b.squeeze(dim=1)
    expanded = torch.einsum("bi, boi -> bo", cast_inputs, per_token_b)
    window = slice(slice_offset, slice_offset + slice_size)
    if add_inputs:
        output_tensor[:, window] += expanded
    else:
        output_tensor[:, window] = expanded

View File

@@ -0,0 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand
from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink
__all__ = [
"lora_expand",
"lora_shrink",
"LoRAKernelMeta",
]

View File

@@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Utilities for Punica kernel construction.
"""
from vllm.triton_utils import tl, triton
@triton.jit
def mm_k(a_ptr, b_ptr, ak_stride, bk_stride, offset_k, K: tl.constexpr,
         BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
         EVEN_K: tl.constexpr, SPLIT_K: tl.constexpr, CAST_TYPE: tl.constexpr,
         b_dtype: tl.constexpr):
    """
    Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
    B (k x n), iterate, through the K dimension to compute the partial/complete
    matrix block product.
    If SPLIT_K == 1, the output m x n product is complete.
    If SPLIT_K > 1, the thread block computes partial outputs. The partial
    outputs are then atomically summed in the caller code.
    Args:
        a_ptr: Array of pointers, identifying rows of A
        b_ptr: Array of pointers, identifying columns of B
        ak_stride: K dimension stride of the A matrix
        bk_stride: K dimension stride of the B matrix
        K: Length of the K dimension
        BLOCK_M: M dimension of the output block m x n
        BLOCK_N: N dimension of the output block m x n
        BLOCK_K: K dimension atom
        EVEN_K: True if the blocks of A and B can be loaded without any
        masking.
        SPLIT_K: Parameter signifying parallelism in the K dimension.
        CAST_TYPE: if True, cast the values from the A matrix to the B
        matrix dtype.
        b_dtype: datatype of the B matrix
    Returns:
        The accumulated (possibly partial, if SPLIT_K > 1) m x n block
        product, in float32.
    """
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            tiled_a = tl.load(a_ptr)
            tiled_b = tl.load(b_ptr)
        else:
            # Mask off the K tail so partial blocks read zeros.
            tiled_a = tl.load(a_ptr,
                              mask=offset_k[None, :]
                              < K - k * (BLOCK_K * SPLIT_K),
                              other=0)
            tiled_b = tl.load(b_ptr,
                              mask=offset_k[:, None]
                              < K - k * (BLOCK_K * SPLIT_K),
                              other=0)
        if CAST_TYPE:
            tiled_a = tiled_a.to(b_dtype)
        accumulator += tl.dot(
            tiled_a,
            tiled_b,
        )
        # Advance A and B to the next K block owned by this program.
        a_ptr += BLOCK_K * SPLIT_K * ak_stride
        b_ptr += BLOCK_K * SPLIT_K * bk_stride
    return accumulator
@triton.jit
def do_expand_kernel(
    pid_n,
    lora_index,
    slice_id,
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    M_LEN,
    ram,  # array identifying the rows of Input ptr to operate on
    slice_start_loc,
    # input ptr strides
    input_d0_stride,
    input_d1_stride,
    input_d2_stride,
    # lora ptr strides
    ls_d0_ptr,
    ls_d1_ptr,
    ls_d2_ptr,
    # out ptr strides
    output_d0_stride,
    output_d1_stride,
    # constants
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SAME_STRIDE: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    EVEN_K: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
):
    """
    Given an array of integers that identifies the rows of A, ram,
    a lora index that identifies which LoRA to use from lora_ptr, lora_index,
    a slice_id that identifies the input/output slice,
    compute the matrix product and store in the appropriate output location.
    Given that this is an expand kernel, we don't perform any split-K reduction
    as the K dimension is assumed to be small.
    """
    # ls_d*_ptr can be either an integer or a pointer
    if SAME_STRIDE:
        # integer
        cur_lora_d0_stride = ls_d0_ptr
        cur_lora_d1_stride = ls_d1_ptr
        cur_lora_d2_stride = ls_d2_ptr
    else:
        # pointer: per-slice strides are looked up by slice_id
        cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
        cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
        cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)
    # Identify the input_ptr and lora_ptr from slice_id.
    if SLICE_NUM == 1:
        cur_input_ptr = input_ptr
        cur_lora_ptr = lora_ptr
    else:
        cur_input_ptr = input_ptr + slice_id * input_d0_stride
        # lora_ptr holds per-slice tensor addresses; reinterpret the
        # loaded integer as a pointer of the output element type.
        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
            tl.pointer_type(out_ptr.dtype.element_ty))

    # Identify the column indices of B to process.
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
    # Identify A and B block pointers
    offset_k = tl.arange(0, BLOCK_K)
    a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride +
             offset_k[None, :] * input_d2_stride)
    b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
             offset_k[:, None] * cur_lora_d2_stride +
             rbn[None, :] * cur_lora_d1_stride)

    # Compute the block matrix product. No split-K here (see docstring).
    SPLIT_K = 1
    accumulator = mm_k(a_ptr, b_ptr, input_d2_stride, cur_lora_d2_stride,
                       offset_k, K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K,
                       CAST_TYPE, cur_lora_ptr.dtype.element_ty)

    tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
    if SLICE_NUM == 1:
        cur_slice_start = slice_start_loc
    else:
        cur_slice_start = tl.load(slice_start_loc + slice_id)

    # Identify the C output pointers to store the results of the accumulator.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
    offset_cm = tl.arange(0, BLOCK_M)
    c_ptr = (out_ptr + ram[:, None] * output_d0_stride +
             offset_cn[None, :] * output_d1_stride)
    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :]
                                             < (cur_slice_start + N))
    if ADD_INPUTS:
        # Fused accumulate: read the current output tile and add to it.
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)
@triton.jit
def do_shrink_kernel(
    pid_n,
    pid_sk,
    slice_id,
    lora_index,
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    M_LEN,
    ram,  # array identifying the rows of Input ptr to operate on
    # input strides
    input_d0_stride,
    input_d1_stride,
    # lora strides
    lora_d0_stride,
    lora_d1_stride,
    lora_d2_stride,
    # output strides
    output_d0_stride,
    output_d1_stride,
    output_d2_stride,
    scaling,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
    SLICE_NUM: tl.constexpr,
):
    """
    Given an array of integers that identifies the rows of A, ram,
    a lora index that identifies which LoRA to use from lora_ptr, lora_index,
    a slice_id that identifies the input/output slice, compute the
    matrix product and store in the appropriate output location.

    Unlike the expand kernel, K here is large, so the K dimension may be
    split across SPLIT_K programs whose partial results are combined via
    atomic adds.
    """
    # Identify the lora_ptr from slice_id.
    if SLICE_NUM == 1:
        # current lora ptr
        cur_lora_ptr = lora_ptr
    else:
        # current lora ptr: reinterpret the loaded per-slice address as a
        # pointer of the input element type.
        cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
            tl.pointer_type(input_ptr.dtype.element_ty))

    # Identify the column indices of B to process.
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
    # Identify A and B block pointers; pid_sk selects this program's
    # K partition.
    offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
    a_ptr = (input_ptr + ram[:, None] * input_d0_stride +
             offset_k[None, :] * input_d1_stride)
    b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index +
             rbn[None, :] * lora_d1_stride +
             offset_k[:, None] * lora_d2_stride)

    # Compute partial/complete block matrix product.
    accumulator = mm_k(a_ptr, b_ptr, input_d1_stride, lora_d2_stride, offset_k,
                       K, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, False,
                       cur_lora_ptr.dtype.element_ty)

    # Identify the C output pointers to store the results of the accumulator.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_cm = tl.arange(0, BLOCK_M)
    cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr +
                   slice_id * output_d0_stride)
    c_ptr = cur_out_ptr + ram[:, None] * output_d1_stride + offset_cn[
        None, :] * output_d2_stride
    c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)
    # LoRA scaling is folded into the shrink result.
    accumulator *= scaling
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(c_ptr, accumulator, mask=c_mask)
    else:
        tl.atomic_add(c_ptr, accumulator, mask=c_mask)

View File

@@ -0,0 +1,290 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr
from vllm.utils import direct_register_custom_op
@triton.jit
def _lora_expand_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    M,
    N,
    K,
    token_indices_sorted_by_lora_ids,
    num_tokens_per_lora,
    lora_token_start_loc,
    lora_ids,
    slice_start_loc,
    input_d0_stride,
    input_d1_stride,
    input_d2_stride,  # 1
    ls_d0_ptr,
    ls_d1_ptr,
    ls_d2_ptr,  # 1
    output_d0_stride,
    output_d1_stride,  # 1
    output_hs_ptr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
    SLICE_NUM: tl.constexpr,
    SAME_STRIDE: tl.constexpr):
    """Top-level expand kernel.

    Grid layout: axis 0 enumerates (M-tile, N-tile) pairs, axis 1 the
    output slices, axis 2 the LoRA adapters; programs whose LoRA has no
    tokens (or whose tile falls outside the work) exit early.
    """
    cta_n_num = tl.cdiv(N, BLOCK_N)
    cta_m_num = tl.cdiv(M, BLOCK_M)

    # Decompose the flattened axis-0 program id into (pid_m, pid_n).
    pid_mn = tl.program_id(axis=0)
    pid_m = pid_mn % cta_m_num
    pid_n = (pid_mn // cta_m_num) % cta_n_num

    slice_id = tl.program_id(axis=1)
    lora_idx = tl.program_id(axis=2)

    lora_id = tl.load(lora_ids + lora_idx)
    if lora_id == -1:
        # Early exit for the no-lora case.
        return

    lora_m_size = tl.load(num_tokens_per_lora + lora_idx)

    cta_m_offset = pid_m * BLOCK_M
    if cta_m_offset >= lora_m_size:
        # Early exit CTA.
        return

    # When the output dimensions of each slice are the same,cur_n=N, otherwise
    # cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's
    # qkv linear.
    curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
    if pid_n * BLOCK_N >= curr_N:
        # Early exit CTA.
        return

    # num rows this CTA should process.
    cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)

    # Identify all rows that this CTA should process.
    lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
    cta_lora_seq_indices = (token_indices_sorted_by_lora_ids +
                            lora_m_indices_start + cta_m_offset)

    # Load all relevant row indices.
    offset_m = tl.arange(0, BLOCK_M) % cta_m_len
    ram = tl.load(cta_lora_seq_indices + offset_m)

    do_expand_kernel(
        pid_n,
        lora_id,
        slice_id,
        input_ptr,
        lora_ptr,
        out_ptr,
        curr_N,
        K,
        cta_m_len,
        ram,  # array identifying the rows of Input ptr to operate on
        slice_start_loc,
        # input ptr strides
        input_d0_stride,
        input_d1_stride,
        input_d2_stride,
        # lora ptr strides
        ls_d0_ptr,
        ls_d1_ptr,
        ls_d2_ptr,
        # out ptr strides
        output_d0_stride,
        output_d1_stride,
        # constants
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        SAME_STRIDE,
        SLICE_NUM,
        EVEN_K,
        CAST_TYPE,
        ADD_INPUTS)
@torch.inference_mode()
def _lora_expand(
    inputs: torch.Tensor,  # shape [num_slices, num_tokens, lora_rank]
    lora_b_weights: list[
        torch.Tensor],  # shape [num_lora, hidden_size, lora_rank]
    output_tensor: torch.
    Tensor,  # shape [num_tokens, hidden_size * num_slices]
    token_lora_mapping: torch.Tensor,  # shape [num_tokens]
    token_indices_sorted_by_lora_ids: torch.Tensor,  # shape [num_tokens]
    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
    lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
    lora_ids: torch.Tensor,  # shape [max-loras + 1]
    no_lora_flag_cpu: torch.Tensor,  # shape [1]
    offset_start: int = 0,
    add_inputs: bool = False,
) -> None:
    """
    Args:
        inputs (torch.Tensor): input tensor
        lora_b_weights (list[torch.Tensor]): lora'b weight
        output_tensor (torch.Tensor): output tensor
        token_lora_mapping (torch.Tensor): A tensor mapping each input token
            to the lora-id related to that token. A value of -1 indicates that
            LoRA doesn't apply to that token.
        token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
            the A matrix grouped by LoRA IDs.
        num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
            of tokens that are to be processed by LoRA ID lora_ids[i]
        lora_token_start_loc (torch.Tensor): A cumulative sum of
            num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
            lora_token_start_loc[i], along with num_tokens_per_lora[i]
            identifies the region in token_indices_sorted_by_lora_ids that
            LoRA lora_ids[i] should process.
        lora_ids (torch.Tensor): LoRA ids to process.
        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
            if there are any requests that require LoRA.
        offset_start (int, optional): Offset start for output_tensor.
            Defaults to 0.
        add_inputs (bool, optional): Whether to add the input tensor to the
            output tensor. Defaults to False.
    """
    assert no_lora_flag_cpu.numel() == 1
    if no_lora_flag_cpu.item():
        # None of the inputs require LoRA.
        return

    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    for weight in lora_b_weights:
        assert weight.dtype in [torch.float16, torch.bfloat16]

    assert inputs.size(0) == len(lora_b_weights)
    assert output_tensor.is_contiguous()

    # metadata sanity check.
    M = inputs.size(1)
    assert token_lora_mapping.size(0) == M
    assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
        0)
    assert lora_ids.size(0) == num_tokens_per_lora.size(0)
    assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1

    # Flatten the per-slice weight tensors into device-side pointer /
    # stride arrays consumable by the kernel.
    (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor,
     lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor,
     same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start,
                                           inputs.device)

    K = lora_b_weights[0].shape[-1]  # K= rank

    ADD_INPUTS = add_inputs
    MAX_LORAS = lora_ids.size(0)
    CAST_TYPE = False
    NUM_SLICES = len(lora_b_weights)

    # Triton kernel configs.
    BLOCK_M = 64
    BLOCK_N = 128
    BLOCK_K = 16
    NUM_WARPS = 4
    NUM_CTAS = 1
    NUM_STAGES = 2

    # EVEN_K lets the kernel skip K-tail masking entirely.
    EVEN_K = K % BLOCK_K == 0  # type: ignore

    # float32 inputs are cast to the (fp16/bf16) weight dtype inside the
    # kernel before the dot product.
    if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
            torch.float16,
            torch.bfloat16,
    ]:
        CAST_TYPE = True

    # TODO (varun): This grid formulation maximizes parallelization at the
    # cost of wasteful thread block launch when only a few input tokens require
    # LoRA. This might not be the best in all cases.
    grid = (
        triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
        NUM_SLICES,
        # Each LoRA receives its own set of thread blocks for output
        # computation. If some LoRA doesn't have any tokens to process, its
        # thread blocks simply exit.
        MAX_LORAS,
    )

    _lora_expand_kernel[grid](
        inputs,
        lora_ptr_tensor,
        output_tensor,
        M,
        MAX_N,
        K,
        token_indices_sorted_by_lora_ids,
        num_tokens_per_lora,
        lora_token_start_loc,
        lora_ids,
        slice_start_tensor,
        inputs.stride(0),
        inputs.stride(1),
        inputs.stride(2),
        lora_strides_d0_tensor,
        lora_strides_d1_tensor,
        lora_strides_d2_tensor,
        output_tensor.stride(0),
        output_tensor.stride(1),
        hidden_sizes_tensor,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        ADD_INPUTS,
        CAST_TYPE,
        NUM_SLICES,
        same_stride,
        num_warps=NUM_WARPS,
        num_ctas=NUM_CTAS,
        num_stages=NUM_STAGES,
    )
    return
def _lora_expand_fake(
    inputs: torch.Tensor,
    lora_b_weights: list[torch.Tensor],
    output_tensor: torch.Tensor,
    token_lora_mapping: torch.Tensor,
    token_indices_sorted_by_lora_ids: torch.Tensor,
    num_tokens_per_lora: torch.Tensor,
    lora_token_start_loc: torch.Tensor,
    lora_ids: torch.Tensor,
    no_lora_flag_cpu: torch.Tensor,
    offset_start: int = 0,
    add_inputs: bool = False,
) -> None:
    """Fake (meta) implementation used for custom-op registration: the
    real op mutates ``output_tensor`` in place and returns nothing, so
    the fake is a no-op with a matching signature."""
    return None
# Register _lora_expand as a torch custom op (declaring that it mutates
# output_tensor) so it composes with torch.compile tracing; expose the
# registered op as `lora_expand`. On failure, fall back to calling the
# python function directly.
# NOTE(review): the AttributeError presumably comes from the
# torch.ops.vllm lookup on builds without custom-op support — confirm.
try:
    direct_register_custom_op(
        op_name="lora_expand",
        op_func=_lora_expand,
        mutates_args=["output_tensor"],
        fake_impl=_lora_expand_fake,
    )
    lora_expand = torch.ops.vllm.lora_expand

except AttributeError:
    lora_expand = _lora_expand

View File

@@ -0,0 +1,148 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
LoRA kernels metadata preparation utilities.
"""
from dataclasses import dataclass
from typing import Union
import torch
@dataclass
class LoRAKernelMeta:
    """Per-forward-pass metadata tensors consumed by the lora_shrink /
    lora_expand Triton ops."""
    token_lora_mapping: torch.Tensor
    token_indices_sorted_by_lora_ids: torch.Tensor
    active_lora_ids: torch.Tensor
    num_tokens_per_lora: torch.Tensor
    lora_token_start_loc: torch.Tensor

    # V1 executes traced torch.compile graphs. Relevant tracing facts:
    #  1. python scalars are frozen into constants,
    #  2. dynamic control flow cannot be expressed (still experimental
    #     in pytorch), and
    #  3. the internals of torch.ops functions are not traced.
    # The "no lora" flag therefore lives in a CPU tensor: fact 3 lets
    # the lora_expand / lora_shrink ops read it and exit early at
    # runtime even inside a traced graph.
    no_lora_flag_cpu: torch.Tensor

    @staticmethod
    def make(max_loras: int, max_num_tokens: int,
             device: Union[torch.device, str]) -> "LoRAKernelMeta":
        """Allocate metadata buffers sized for *max_loras* adapters and
        *max_num_tokens* tokens on *device*."""
        mapping = torch.empty(max_num_tokens,
                              dtype=torch.int32,
                              device=device)
        sorted_token_indices = torch.empty(max_num_tokens,
                                           dtype=torch.int32,
                                           device=device)
        # max_loras + 1 slots: "no lora" (-1) can appear alongside every
        # real id, e.g. [-1, 0, 2, 1] when max_loras == 3.
        ids = torch.empty(max_loras + 1, dtype=torch.int32, device=device)
        # Continuing the example: [3, 10, 5, 2] is a possible count vector.
        counts = torch.zeros(max_loras + 1, dtype=torch.int32, device=device)
        # max_loras + 2 slots since index 0 is always 0; continuing the
        # example, lora_token_start_loc is [0, 3, 13, 18, 20].
        start_loc = torch.zeros(max_loras + 2,
                                dtype=torch.int32,
                                device=device)
        # Pinned on CPU so torch.ops internals can branch on it (see
        # class comment above).
        flag = torch.tensor([False], dtype=torch.bool, device='cpu')
        return LoRAKernelMeta(
            token_lora_mapping=mapping,
            token_indices_sorted_by_lora_ids=sorted_token_indices,
            active_lora_ids=ids,
            num_tokens_per_lora=counts,
            lora_token_start_loc=start_loc,
            no_lora_flag_cpu=flag)

    def _reset(self):
        """Clear per-pass state before refilling it."""
        self.active_lora_ids.fill_(-1)
        self.num_tokens_per_lora.fill_(0)
        self.lora_token_start_loc.fill_(0)
        self.no_lora_flag_cpu.fill_(False)

    def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
        """
        Prepare kernel metadata tensors for the current forward pass.
        Args:
            token_lora_mapping: per-token LoRA indices (-1 = no LoRA).
        """
        self._reset()

        # Record the no-LoRA case and skip all other work for it.
        no_lora = torch.all(token_lora_mapping == -1)
        self.no_lora_flag_cpu[0] = no_lora
        if no_lora:
            return

        num_tokens = token_lora_mapping.size(0)

        # Stage the raw per-token mapping.
        self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping,
                                                   non_blocking=True)

        # Token indices grouped by LoRA id (stable sort keeps the
        # original token order within each group).
        _, grouped_indices = torch.sort(token_lora_mapping, stable=True)
        self.token_indices_sorted_by_lora_ids[:num_tokens].copy_(
            grouped_indices, non_blocking=True)

        # Active ids and their token counts.
        ids, counts = torch.unique(token_lora_mapping,
                                   sorted=True,
                                   return_counts=True)
        self.active_lora_ids[:ids.size(0)].copy_(ids, non_blocking=True)
        self.num_tokens_per_lora[:counts.size(0)].copy_(counts,
                                                        non_blocking=True)

        # Exclusive prefix sum of the counts; slot 0 stays 0.
        cumulative = torch.cumsum(counts, dim=0)
        self.lora_token_start_loc[1:1 + cumulative.size(0)].copy_(
            cumulative, non_blocking=True)

    def meta_args(
        self, token_nums: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor, torch.Tensor]:
        """
        Return the kernel metadata for the current forward pass, in the
        exact positional order expected by the lora_shrink / lora_expand
        calls (so the tuple can be unpacked directly).
        Args:
            token_nums: number of input tokens in the current pass.
        """
        return (
            self.token_lora_mapping[:token_nums],
            self.token_indices_sorted_by_lora_ids[:token_nums],
            self.num_tokens_per_lora,
            self.lora_token_start_loc,
            self.active_lora_ids,
            self.no_lora_flag_cpu,
        )

View File

@@ -0,0 +1,244 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
import triton
import triton.language as tl
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr
from vllm.utils import direct_register_custom_op
@triton.jit
def _lora_shrink_kernel(input_ptr, lora_ptr, out_ptr, M, N, K,
                        token_indices_sorted_by_lora_ids, num_tokens_per_lora,
                        lora_token_start_loc, lora_ids, scaling,
                        input_d0_stride, input_d1_stride, lora_d0_stride,
                        lora_d1_stride, lora_d2_stride, output_d0_stride,
                        output_d1_stride, output_d2_stride,
                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                        BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr,
                        SPLIT_K: tl.constexpr, SLICE_NUM: tl.constexpr):
    # One program computes one (BLOCK_M, BLOCK_N) output tile for a single
    # (slice, lora) pair; the actual tile math is in do_shrink_kernel.
    cta_n_num = tl.cdiv(N, BLOCK_N)
    cta_m_num = tl.cdiv(M, BLOCK_M)
    # Decompose the flattened axis-0 program id into (split_k, m-tile, n-tile).
    pid_sk_m_n = tl.program_id(axis=0)
    pid_sk = pid_sk_m_n % SPLIT_K
    pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
    pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num
    # Axis 1 selects the weight slice, axis 2 the lora.
    slice_id = tl.program_id(axis=1)
    lora_idx = tl.program_id(axis=2)
    lora_id = tl.load(lora_ids + lora_idx)
    if lora_id == -1:
        # Early exit for the no-lora case.
        return
    lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
    cta_m_offset = pid_m * BLOCK_M
    if cta_m_offset >= lora_m_size:
        # Early exit CTA.
        return
    # num rows this CTA should process.
    cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
    # Identify all rows that this CTA should process.
    lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
    cta_lora_seq_indices = (token_indices_sorted_by_lora_ids +
                            lora_m_indices_start + cta_m_offset)
    # Load all relevant row indices.
    # The modulo wraps the tail tile back onto valid rows rather than
    # masking out-of-range lanes.
    offset_m = tl.arange(0, BLOCK_M) % cta_m_len
    ram = tl.load(cta_lora_seq_indices + offset_m)
    do_shrink_kernel(
        pid_n,
        pid_sk,
        slice_id,
        lora_id,
        input_ptr,
        lora_ptr,
        out_ptr,
        N,
        K,
        cta_m_len,
        ram,  # array identifying the rows of Input ptr to operate on
        # input strides
        input_d0_stride,
        input_d1_stride,
        # lora strides
        lora_d0_stride,
        lora_d1_stride,
        lora_d2_stride,
        # output strides
        output_d0_stride,
        output_d1_stride,
        output_d2_stride,
        scaling,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        SLICE_NUM)
@torch.inference_mode()
def _lora_shrink(
    inputs: torch.Tensor,  # shape [num_tokens, hidden_size]
    lora_a_weights: list[
        torch.Tensor],  # shape [num_loras, lora_rank, hidden_size]
    output_tensor: torch.Tensor,  # shape [num_slices, num_tokens, lora_rank]
    token_lora_mapping: torch.Tensor,  # shape [num_tokens]
    token_indices_sorted_by_lora_ids: torch.Tensor,  # shape [num_tokens]
    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
    lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
    lora_ids: torch.Tensor,  # shape [max-loras + 1]
    no_lora_flag_cpu: torch.Tensor,  # shape [1]
    scaling: float,
) -> None:
    """
    Args:
        inputs (torch.Tensor): Input tensor
        lora_a_weights (list[torch.Tensor]): LoRA weights
        output_tensor (torch.Tensor): output tensor
        token_lora_mapping (torch.Tensor): A tensor mapping each input token
            to the lora-id related to that token. A value of -1 indicates that
            LoRA doesn't apply to that token.
        token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
            the A matrix grouped by LoRA IDs.
        num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
            of tokens that are to be processed by LoRA ID lora_ids[i]
        lora_token_start_loc (torch.Tensor): A cumulative sum of
            num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
            lora_token_start_loc[i], along with num_tokens_per_lora[i]
            identifies the region in token_indices_sorted_by_lora_ids that
            LoRA lora_ids[i] should process.
        lora_ids (torch.Tensor): LoRA ids to process.
        no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
            if there are any requests that require LoRA.
        scaling (float): Scaling factor.
    """
    assert no_lora_flag_cpu.numel() == 1
    if no_lora_flag_cpu.item():
        # None of the inputs require LoRA.
        return

    # Input/weight sanity checks: the kernel only supports fp16/bf16 and
    # contiguous layouts.
    assert inputs.dtype == lora_a_weights[0].dtype
    assert inputs.dtype in [torch.float16, torch.bfloat16]
    for weight in lora_a_weights:
        assert weight.dtype in [torch.float16, torch.bfloat16]
    assert inputs.size(1) == lora_a_weights[0].size(-1)
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    # metadata sanity check
    M = inputs.size(0)
    assert token_lora_mapping.size(0) == M
    assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
        0)
    assert lora_ids.size(0) == num_tokens_per_lora.size(0)
    assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1

    (lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
     lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
    N, K = lora_a_weights[0].shape[-2:]  # K=hidden_size,N=rank
    NUM_SLICES = len(lora_a_weights)
    MAX_LORAS = lora_ids.size(0)

    # Triton kernel configs
    # NOTE(review): tile/split sizes appear hand-tuned for small vs large M;
    # confirm against benchmarks before changing.
    BLOCK_M = 32
    BLOCK_N = 16
    BLOCK_K = 256 if M < 128 else 32
    SPLIT_K = 64 if M < 128 else 8
    NUM_WARPS = 4
    NUM_CTAS = 1
    NUM_STAGES = 2
    # True when K divides evenly into the split-K tiling — presumably lets
    # the kernel skip K-bounds masking; confirm in do_shrink_kernel.
    EVEN_K = K % (BLOCK_K * SPLIT_K) == 0  # type: ignore

    # TODO (varun): This grid formulation maximizes parallelization at the
    # cost of wasteful thread block launch when only few of the input tokens
    # require LoRA. This might not be the best in all cases.
    grid = (
        SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        NUM_SLICES,
        # Each LoRA receives its own set of thread blocks for output
        # computation. If some LoRA doesn't have any tokens to process, its
        # thread blocks exit early.
        MAX_LORAS,
    )
    _lora_shrink_kernel[grid](
        inputs,
        lora_ptr_tensor,
        output_tensor,
        M,
        N,
        K,
        token_indices_sorted_by_lora_ids,
        num_tokens_per_lora,
        lora_token_start_loc,
        lora_ids,
        scaling,
        inputs.stride(0),
        inputs.stride(1),
        lora_strides_d0,
        lora_strides_d1,
        lora_strides_d2,
        output_tensor.stride(0),
        output_tensor.stride(1),
        output_tensor.stride(2),
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
        NUM_SLICES,
        num_warps=NUM_WARPS,
        num_ctas=NUM_CTAS,
        num_stages=NUM_STAGES,
    )
    return
def _lora_shrink_fake(
inputs: torch.Tensor,
lora_a_weights: list[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
scaling: float,
) -> None:
return
try:
    # Register lora_shrink as a torch custom op so compilation/graph capture
    # can see it; the real impl mutates output_tensor in place.
    direct_register_custom_op(
        op_name="lora_shrink",
        op_func=_lora_shrink,
        mutates_args=["output_tensor"],
        fake_impl=_lora_shrink_fake,
    )
    lora_shrink = torch.ops.vllm.lora_shrink
except AttributeError:
    # NOTE(review): presumably raised when custom-op registration is
    # unavailable in this torch build — fall back to the plain Python callable.
    lora_shrink = _lora_shrink

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
# Per-process caches keyed by the tuple of weight data_ptrs; populated during
# profile_run and treated as a constant lookup table afterwards.
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.Tensor, ...]] = {}
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.Tensor, ...]] = {}
def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
    """
    Collect pointer/stride metadata for the lora_a weight tensors.

    `_LORA_A_PTR_DICT` collects the required information during `profile_run`,
    After this, it remains constant and subsequent usage is through LUT.
    Refer to:
    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
    """
    key = tuple(w.data_ptr() for w in lora_a_weights)
    cached = _LORA_A_PTR_DICT.get(key)
    if cached:
        return cached

    ptrs: list[int] = []
    strides: tuple[list[int], list[int], list[int]] = ([], [], [])
    for weight in lora_a_weights:
        if weight.ndim == 4:  # shape:(lora_num,1,size,rank)
            assert weight.size(1) == 1
            weight = weight.squeeze(dim=1)
        else:
            assert weight.ndim == 3  # shape:(lora_num,size,rank)
        assert weight.is_contiguous()
        ptrs.append(weight.data_ptr())
        for dim, acc in enumerate(strides):
            acc.append(weight.stride(dim))

    # The kernel receives a single stride per dimension, so all slices must
    # agree.
    if any(len(set(dim_strides)) > 1 for dim_strides in strides):
        raise ValueError("All LoRA weights must have the same stride.")

    if len(lora_a_weights) == 1:
        # Single slice: pass the weight tensor itself as the pointer carrier.
        lora_ptr_tensor = lora_a_weights[0]
    else:
        lora_ptr_tensor = torch.tensor(ptrs, device=device)

    _LORA_A_PTR_DICT[key] = (
        lora_ptr_tensor,
        strides[0][0],
        strides[1][0],
        strides[2][0],
    )
    return _LORA_A_PTR_DICT.get(key)
def _get_lora_b_ptr(lora_weights: list[torch.Tensor], offset_start: int,
                    device: torch.device):
    """
    Collect pointer/stride/slice metadata for the lora_b weight tensors.

    `_LORA_B_PTR_DICT` collects the required information during `profile_run`,
    After this, it remains constant and subsequent usage is through LUT.
    Refer to:
    https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py

    Args:
        lora_weights: one lora_b weight per output slice, each of shape
            (lora_num, size, rank) or (lora_num, 1, size, rank).
        offset_start: starting column of the first slice in the fused output.
        device: device on which to materialize the metadata tensors.
    """
    key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
    if values := _LORA_B_PTR_DICT.get(key):
        return values
    slice_offset_lst = []
    tensor_ptrs = []
    lora_strides_d0 = []
    lora_strides_d1 = []
    lora_strides_d2 = []
    hidden_sizes = []
    slice_offset = offset_start
    for lora_b_weight in lora_weights:
        if lora_b_weight.ndim == 4:  # shape:(lora_num,1,size,rank)
            assert lora_b_weight.size(1) == 1
            lora_b_weight = lora_b_weight.squeeze(dim=1)
        else:
            assert lora_b_weight.ndim == 3  # shape:(lora_num,size,rank)
        assert lora_b_weight.is_contiguous()
        tensor_ptrs.append(lora_b_weight.data_ptr())
        lora_strides_d0.append(lora_b_weight.stride(0))
        lora_strides_d1.append(lora_b_weight.stride(1))
        lora_strides_d2.append(lora_b_weight.stride(2))
        slice_offset_lst.append(slice_offset)
        slice_offset += lora_b_weight.size(1)
        hidden_sizes.append(lora_b_weight.size(1))
    if len(lora_weights) > 1:
        # note these are device tensors
        lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device)
        slice_start_tensor = torch.tensor(slice_offset_lst, device=device)
    else:
        slice_start_tensor = slice_offset_lst[0]
        # Bug fix: pass the whole weight tensor (mirroring _get_lora_a_ptr),
        # not `lora_b_weight[0]` — indexing with [0] dropped the lora_num
        # dimension and yielded metadata for a single lora only.
        lora_ptr_tensor = lora_weights[0]

    # If each lora has the same stride, there's no need to use a
    # tensor for storage.
    if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 and
            len(set(lora_strides_d2)) == 1) and len(set(hidden_sizes)) == 1:
        lora_strides_d0_tensor = lora_strides_d0[0]
        lora_strides_d1_tensor = lora_strides_d1[0]
        lora_strides_d2_tensor = lora_strides_d2[0]
        hidden_sizes_tensor = hidden_sizes[0]
        same_stride = True
    else:
        lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device)
        lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device)
        lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device)
        hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device)
        same_stride = False
    # MAX_N is the maximum hidden size among all the lora_b weights
    MAX_N = max(hidden_sizes)
    _LORA_B_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor,
                             lora_strides_d0_tensor, lora_strides_d1_tensor,
                             lora_strides_d2_tensor, hidden_sizes_tensor,
                             same_stride, MAX_N)
    return _LORA_B_PTR_DICT.get(key)

View File

@@ -0,0 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.ops.xla_ops.lora_ops import (bgmv_expand, bgmv_expand_slice,
                                            bgmv_shrink)

# Public API of the XLA (TPU) LoRA ops package.
__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]

View File

@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F
import torch_xla.core.xla_builder as xb
from torch.library import impl
from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard
@jax.jit
def bgmv_jax(inputs, loras, idxs):
    """Batched grouped matrix-vector multiply: for each token t, apply the
    lora selected by idxs[t], i.e. out[t] = loras[idxs[t]] @ inputs[t]."""
    # One-hot selector of shape (tokens, num_loras) picks each token's lora
    # inside a single einsum contraction.
    lora_selector = jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype)
    return jnp.einsum(
        "td,tX,Xld->tl",
        inputs,
        lora_selector,
        loras,
    )
# Register the op schema with torch_xla's library; the XLA and non-XLA
# (tracing fallback) implementations are provided below via @impl.
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")
@impl(XLA_LIB, "bgmv", "XLA")
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
jax_import_guard()
return xb.call_jax(bgmv_jax, (inputs, loras, idxs))
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor,
idxs: torch.IntTensor):
T, _ = inputs.shape
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
_, L, _ = loras.shape
return torch.empty((T, L), device=inputs.device)
def bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
):
    """
    Args:
        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].
        output_tensor (torch.Tensor): output tensor of shape
            [num_tokens, hidden_size * num_slices].
        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
        add_inputs (bool): Whether or not to add the input tensor to the output
            tensor.
    """
    outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)

    # When bgmv produced a single row but the output expects several, only
    # the first output row is used.
    if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
        limit = 1
    else:
        limit = output_tensor.shape[0]

    # Zero-pad on the right when the output is wider than the bgmv result.
    pad_cols = output_tensor.shape[1] - outputs.shape[1]
    if pad_cols > 0:
        outputs = F.pad(outputs, (0, pad_cols, 0, 0))

    trimmed = outputs[:limit, :output_tensor.shape[1]]
    return output_tensor + trimmed if add_inputs else trimmed
def bgmv_shrink(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
):
    """
    Args:
        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].
        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
        scaling (float, optional): Scalar multiplier applied to the output.
    """
    projected = torch.ops.xla.bgmv(inputs, lora_b_weights,
                                   lora_indices_tensor)
    return projected * scaling
def bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
):
    """
    Args:
        inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
        lora_b_weights (torch.Tensor): LoRA weights of shape
            [num_loras, lora_rank, hidden_size].
        output_tensor (torch.Tensor): output tensor of shape
            [num_tokens, hidden_size * num_slices].
        lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
            indicating which LoRA matrix to use for each token.
        slice_offset (int): starting column of this slice in the fused output.
        slice_size (int): width of this slice.
        add_inputs (bool): Whether or not to add the input tensor to the output
            tensor.
    """
    outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)

    # Zero-pad left/right so the bgmv result lands in columns
    # [slice_offset, slice_offset + slice_size) of the fused output width.
    right_pad = output_tensor.shape[1] - (slice_offset + slice_size)
    padded = F.pad(outputs, (slice_offset, right_pad, 0, 0))

    if add_inputs:
        return output_tensor + padded
    return padded

136
lora/peft_helper.py Normal file
View File

@@ -0,0 +1,136 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py
import json
import math
import os
from dataclasses import MISSING, dataclass, field, fields
from typing import Literal, Optional, Union
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
logger = init_logger(__name__)
@dataclass
class PEFTHelper:
    """
    A helper class for PEFT configurations, specifically designed for LoRA.
    This class handles configuration validation, compatibility checks for
    various LoRA implementations.
    """
    # Required fields
    r: int
    lora_alpha: int
    target_modules: Union[list[str], str]

    bias: Literal["none", "all", "lora_only"] = field(default="none")
    modules_to_save: Optional[list[str]] = field(default=None)
    # True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
    use_rslora: bool = field(default=False)
    # True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
    use_dora: bool = field(default=False)
    # long context lora field
    context_length: int = field(default=0)
    # Extra vllm field, start with 'vllm_' to avoid conflict
    vllm_lora_scaling_factor: float = field(default=1.0)
    # Fix: the default used to be `False`, which is not a valid Optional[int]
    # value and is never `is None`; with context_length set and this field
    # unset, __post_init__ then divided by False -> ZeroDivisionError.
    vllm_max_position_embeddings: Optional[int] = field(default=None)
    vllm_long_context_scaling_factor: Optional[float] = field(default=None)

    def _validate_features(self) -> list[str]:
        """
        Check if there are any unsupported LoRA features.
        """
        error_msg = []
        if self.modules_to_save:
            error_msg.append("vLLM only supports modules_to_save being None.")
        if self.use_dora:
            error_msg.append("vLLM does not yet support DoRA.")
        return error_msg

    def __post_init__(self):
        # Derive the effective LoRA scaling factor (rsLoRA divides by
        # sqrt(r) instead of r).
        if self.use_rslora:
            logger.info_once("Loading LoRA weights trained with rsLoRA.")
            self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
        else:
            self.vllm_lora_scaling_factor = self.lora_alpha / self.r
        if self.context_length:
            if self.vllm_max_position_embeddings is None:
                self.vllm_max_position_embeddings = self.context_length
            self.vllm_long_context_scaling_factor = float(
                math.ceil(self.context_length /
                          self.vllm_max_position_embeddings))

    @classmethod
    def from_dict(cls, config_dict: dict) -> "PEFTHelper":
        """Build a PEFTHelper from a raw adapter-config dict.

        Raises:
            ValueError: if any required field is missing from config_dict.
        """
        # Get all field information from the class
        class_fields = {f.name: f for f in fields(cls)}
        # Check for required fields
        required_fields = {
            name
            for name, f in class_fields.items()
            if f.default is MISSING and f.default_factory is MISSING
        }
        # Identify any missing required fields
        missing_fields = required_fields - set(config_dict.keys())
        if missing_fields:
            raise ValueError(
                f"Missing required configuration fields: {missing_fields}")
        # Filter out fields that aren't defined in the class
        filtered_dict = {
            k: v
            for k, v in config_dict.items() if k in class_fields
        }
        return cls(**filtered_dict)

    @classmethod
    def from_local_dir(
            cls,
            lora_path: str,
            max_position_embeddings: Optional[int],
            tensorizer_config_dict: Optional[dict] = None) -> "PEFTHelper":
        """Load adapter_config.json from disk (or from a tensorizer stream)."""
        lora_config_path = os.path.join(lora_path, "adapter_config.json")

        if tensorizer_config_dict:
            tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
            tensorizer_args = tensorizer_config._construct_tensorizer_args()
            from tensorizer.stream_io import open_stream
            lora_config_path = os.path.join(tensorizer_config.lora_dir,
                                            "adapter_config.json")
            with open_stream(lora_config_path,
                             mode="rb",
                             **tensorizer_args.stream_params) as f:
                config = json.load(f)
            logger.info("Successfully deserialized LoRA config from %s",
                        tensorizer_config.lora_dir)
        else:
            with open(lora_config_path) as f:
                config = json.load(f)

        config["vllm_max_position_embeddings"] = max_position_embeddings
        return cls.from_dict(config)

    # Annotation quoted (lazy) -- behavior-neutral; LoRAConfig is still
    # imported at module level for runtime use elsewhere.
    def validate_legal(self, lora_config: "LoRAConfig") -> None:
        """
        Validates the LoRA configuration settings against application
        constraints and requirements.
        """
        error_msg = self._validate_features()
        if self.r > lora_config.max_lora_rank:
            error_msg.append(
                f"LoRA rank {self.r} is greater than max_lora_rank"
                f" {lora_config.max_lora_rank}.")
        if self.bias != "none" and not lora_config.bias_enabled:
            error_msg.append(
                "Adapter bias cannot be used without bias_enabled.")
        if error_msg:
            raise ValueError(f"{' '.join(error_msg)}")

View File

@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper
# Public API of the punica_wrapper package.
__all__ = [
    "PunicaWrapperBase",
    "get_punica_wrapper",
]

View File

@@ -0,0 +1,485 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional, Union
import torch
from .utils import compute_meta, convert_mapping
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import LongContextLoRAContext
class PunicaWrapperABC(ABC):
    """
    Abstract interface for Punica wrappers: metadata refresh plus the
    shrink/expand LoRA primitives every backend must implement.
    """

    @abstractmethod
    def update_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: list[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
        **kwargs,
    ) -> None:
        """Refresh the lora-related metadata for the next forward pass."""
        raise NotImplementedError

    @abstractmethod
    def add_shrink(
        self,
        y: Union[tuple[torch.Tensor, ...], torch.Tensor],
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        scale: float,
        **kwargs,
    ) -> Optional[torch.Tensor]:
        """Perform GEMM for multiple slices of lora_a."""
        raise NotImplementedError

    @abstractmethod
    def add_expand(
        self,
        y: torch.Tensor,
        x: Union[tuple[torch.Tensor, ...], torch.Tensor],
        lora_b_stacked: tuple[torch.Tensor, ...],
        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> Optional[torch.Tensor]:
        """Perform GEMM and bias addition for multiple slices of lora_b."""
        raise NotImplementedError

    @abstractmethod
    def add_lora_embedding(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        lora_b_stacked: torch.Tensor,
        add_inputs: bool = True,
        **kwargs,
    ) -> Optional[torch.Tensor]:
        """Apply lora for VocabParallelEmbeddingWithLoRA (expand-only path)."""
        raise NotImplementedError

    @abstractmethod
    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: tuple[torch.Tensor, ...],
                        lora_b_stacked: tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: tuple[int, ...],
                        *,
                        buffer: Optional[tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> Optional[torch.Tensor]:
        """Apply the full shrink+expand path for linear-related lora."""
        raise NotImplementedError

    @abstractmethod
    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> Optional[torch.Tensor]:
        """Apply lora for LogitsProcessorWithLoRA."""
        raise NotImplementedError
class PunicaWrapperBase(PunicaWrapperABC):
"""
PunicaWrapperBase is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica.
"""
def __init__(self, max_num_batched_tokens: int, max_batches: int,
device: Union[torch.device, str], **kwargs):
self._token_lora_indices = torch.empty(max_num_batched_tokens,
dtype=torch.long,
device=device)
self._sampler_indices = torch.empty(max_num_batched_tokens,
dtype=torch.long,
device=device)
self._sampler_indices_padded = torch.empty(max_num_batched_tokens,
dtype=torch.long,
device=device)
self._embeddings_indices = torch.empty(2,
max_num_batched_tokens,
dtype=torch.long,
device=device)
self._long_lora_indices = torch.empty(max_num_batched_tokens,
dtype=torch.long,
device=device)
# 5 is the number of indices tensors.
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices,long_lora_indices
self.indices_len: list[Optional[int]] = [None] * 5
# these attributes are the information required for sgmv kernel
self._seq_start_locs = torch.empty(max_batches,
dtype=torch.long,
device=device)
self._seq_lengths = torch.empty(max_batches,
dtype=torch.long,
device=device)
self._lora_indices_per_batch = torch.empty(max_batches,
dtype=torch.long,
device=device)
self.device: torch.device = device
self.max_length: int = 0
self.token_nums: int = 0
self.batch_size: int = -1
self.is_prefill = False
self.no_lora = False
def _update_base_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
):
(
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
long_lora_offsets_tensor,
indices_len,
) = convert_mapping(
mapping,
lora_index_to_id,
max_loras,
vocab_size,
extra_vocab_size,
self.device,
long_lora_context,
)
self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
sampler_indices_padded)
self._embeddings_indices[:embeddings_indices.
shape[0], :embeddings_indices.shape[1]].copy_(
embeddings_indices)
if long_lora_offsets_tensor is not None:
self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
long_lora_offsets_tensor)
else:
self._long_lora_indices.zero_()
self.indices_len[:] = indices_len
def _update_prefill_metadata(self,
token_lora_tensor: torch.Tensor) -> None:
(b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
batch_size, max_length, token_nums,
no_lora) = compute_meta(token_lora_tensor)
self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
b_seq_start_tensor)
self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor)
self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_(
lora_indices_tensor)
self.batch_size = batch_size
self.max_length = max_length
self.token_nums = token_nums
self.no_lora = no_lora
def _apply_bias(
self,
indices: torch.Tensor,
output: torch.Tensor,
output_slices: tuple[int, ...],
lora_bias_stacked: tuple[Optional[torch.Tensor], ...],
):
"""Applies bias to output
Input shapes:
lora_bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)
offset_left = 0
for slice_idx, slice in enumerate(output_slices):
bias = lora_bias_stacked[slice_idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias[indices == -1] = 0
output[:, offset_left:offset_left + slice] += bias
offset_left += slice
return output.view_as(org_output)
@property
def prefill_metadata(
self
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
"""
This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations.
1. seq_start_locs: Tensor of sequence start positions.
2. seq_lengths: Tensor of sequence lengths.
3. lora_indices_per_batch: Tensor of lora indices, and an index of
-1 means no lora should be applied.
4. batch_size: Batch size after clustering identical lora indices.
5. max_length: The maximum sequence length in the batch.
6. token_nums: The token numbers in the batch.
"""
return (self._seq_start_locs[:self.batch_size],
self._seq_lengths[:self.batch_size],
self._lora_indices_per_batch[:self.batch_size],
self.batch_size, self.max_length, self.token_nums)
@property
def token_lora_indices(self) -> torch.Tensor:
"""
This property provides the lora indices corresponding to each token
in the batch. An index of -1 means no lora should be applied.
"""
token_lora_len = self.indices_len[0]
return self._token_lora_indices[:token_lora_len]
@property
def sampler_indices(self) -> torch.Tensor:
"""
This property is used to access the lora indices specifically for
LogitsProcessorWithLoRA.
"""
sampler_indices_len = self.indices_len[1]
return self._sampler_indices[:sampler_indices_len]
@property
def sampler_indices_padded(self) -> torch.Tensor:
"""
This property provides access to padded sampler indices.
"""
indices_padded_len = self.indices_len[2]
return self._sampler_indices_padded[:indices_padded_len]
@property
def embeddings_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA.
"""
embeddings_indices_len = self.indices_len[3]
return self._embeddings_indices[:, :embeddings_indices_len]
@property
def long_lora_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for long context
lora, specifically for LinearScalingRotaryEmbeddingWithLoRA.
"""
long_lora_len = self.indices_len[4]
return self._long_lora_indices[:long_lora_len]
def update_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[Optional[int]],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
long_lora_context: Optional["LongContextLoRAContext"] = None,
**kwargs):
self._update_base_metadata(mapping, lora_index_to_id, max_loras,
vocab_size, extra_vocab_size,
long_lora_context)
if mapping.is_prefill:
# Update metadata required for prefill-related operators.
self._update_prefill_metadata(self.token_lora_indices)
self.is_prefill = True
else:
self.is_prefill = False
@abstractmethod
def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
scale: float, **kwargs) -> Optional[torch.Tensor]:
"""
Performs GEMM for multiple slices of lora_a.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
# TODO: implement it based on torch ops
raise NotImplementedError
@abstractmethod
def add_expand(self,
y: torch.Tensor,
x: Union[tuple[torch.Tensor, ...], torch.Tensor],
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs) -> Optional[torch.Tensor]:
"""
Performs GEMM and bias addition for multiple slices of lora_b.
Semantics:
offset = offset_start
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
lora_bias_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
bias's weight
output_slices (tuple[int, ...]): Every slice's size
offset_start (int): The starting position of y, defaults to 0
add_inputs (bool): Defaults to True.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
@abstractmethod
def add_lora_embedding(self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs) -> Optional[torch.Tensor]:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
and this layer only requires the expand operation.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
@abstractmethod
def add_lora_linear(self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: Optional[tuple[torch.Tensor, ...]] = None,
**kwargs) -> Optional[torch.Tensor]:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)+lora_bias_stacked[i]
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
    @abstractmethod
    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> Optional[torch.Tensor]:
        """
        Applies LoRA specifically for LogitsProcessorWithLoRA.
        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked
        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Optional pre-allocated
                intermediate buffer for the shrink result. Defaults to None.
        """
        # TODO: implement it based on torch ops
        raise NotImplementedError

View File

@@ -0,0 +1,349 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Callable, Optional, Union
import torch
from vllm.lora.ops.torch_ops import (bgmv_expand, bgmv_expand_slice,
bgmv_shrink, sgmv_expand,
sgmv_expand_slice, sgmv_shrink)
from .punica_base import PunicaWrapperBase
# The platforms that are compatible with the PyTorch-native implementation can
# inherit this class
class PunicaWrapperCPU(PunicaWrapperBase):
    """
    PunicaWrapperCPU is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the pytorch punica ops.

    Dispatching convention: `is_prefill` selects between the sgmv ops
    (grouped, multi-token prefill) and the bgmv ops (one token per request,
    decode). All `_apply_*` helpers mutate `y` in place through flattened
    views.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)

    def _shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """Prefill-stage shrink: y += (x @ w_t_all) * scale via sgmv_shrink."""
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_shrink(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            scale,
        )

    def _shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """Decode-stage shrink: one token per request, so use bgmv_shrink."""
        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)

    def _expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        """Prefill-stage expand: y (+)= x @ w_t_all via sgmv_expand."""
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            add_inputs,
        )

    def _expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_inputs: bool,
    ):
        """Decode-stage expand: one token per request, so use bgmv_expand."""
        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)

    def _expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        """Prefill-stage expand into y[:, y_offset:y_offset+y_slice_size]."""
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand_slice(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            y_offset,
            y_slice_size,
            add_inputs,
        )

    def _expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool,
    ):
        """Decode-stage expand into y[:, y_offset:y_offset+y_slice_size]."""
        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                          y_slice_size, add_inputs)

    def _apply_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: int,
        y_slice_size: int,
        add_inputs: bool = True,
    ):
        """
        Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
        computation, which is suitable for the
        GEMM of lora'b.
        """
        # Dispatch on the current stage (prefill vs. decode).
        expand_slice_fun: Callable = (self._expand_slice_prefill
                                      if self.is_prefill else
                                      self._expand_slice_decode)
        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)

    def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor,
                      w_t_all: torch.Tensor, scale: float):
        """
        Perform the ` y+=x@w_t_all` computation, which is suitable for the
        GEMM of lora'a.
        When `is_prefill is` true, it indicates that it is currently the
        prefill stage, and the `_shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the _shrink_decode function
        should be called.
        """
        y_org = y
        # Flatten to 2D; the selected op writes into this view of y in place.
        y = y.view(-1, y.shape[-1])
        shrink_fun: Callable = (self._shrink_prefill
                                if self.is_prefill else self._shrink_decode)
        shrink_fun(y, x, w_t_all, scale)
        y = y.view_as(y_org)

    def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
                   x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
                   scale: float, **kwargs):
        """
        Performs GEMM for multiple slices of lora_a.
        When `is_prefill is` true, it indicates that it is currently the
        prefill stage, and the `_shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the _shrink_decode function
        should be called.
        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (x @ lora_a_stacked[i]) * scale
        Args:
            y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """
        x = x.view(-1, x.shape[-1])
        # TODO fuse these kernels
        for slice_idx in range(len(lora_a_stacked)):
            self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx],
                               scale)

    def add_expand(self,
                   y: torch.Tensor,
                   x: Union[tuple[torch.Tensor, ...], torch.Tensor],
                   lora_b_stacked: tuple[torch.Tensor, ...],
                   lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                   output_slices: tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs=True,
                   **kwargs) -> None:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.
        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
                    lora_bias_stacked[i]
                offset += slice
        Args:
            y (torch.Tensor): Output tensor.
            x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
                bias's weight
            output_slices (tuple[int, ...]): Every slice's size
            offset_start (int): Starting column offset within y. Defaults to 0.
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        offset_left = offset_start
        if lora_bias_stacked is not None:
            self._apply_bias(self.token_lora_indices, y, output_slices,
                             lora_bias_stacked)

        for slice_idx in range(len(lora_b_stacked)):
            self._apply_expand(
                y,
                x[slice_idx],
                lora_b_stacked[slice_idx],
                offset_left,
                output_slices[slice_idx],
                add_inputs=add_inputs,
            )
            offset_left += output_slices[slice_idx]
        y = y.view_as(y_org)

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.
        Semantics:
            y += x @ lora_b_stacked
        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Default to True.
        """
        # Embedding layer only need expand op
        expand_fun: Callable = (self._expand_prefill
                                if self.is_prefill else self._expand_decode)
        expand_fun(y, x, lora_b_stacked, add_inputs)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: tuple[torch.Tensor, ...],
                        lora_b_stacked: tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: tuple[int, ...],
                        *,
                        buffer: Optional[tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
        """
        Applicable to linear-related lora.
        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                    ).squeeze(0)+lora_bias_stacked[i]
        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (tuple[int, ...]): Every slice's size.
            buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
        """
        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        if lora_bias_stacked is not None:
            assert len(lora_bias_stacked) == len(output_slices)
            # Bias is applied up-front here, so add_expand is called with
            # lora_bias_stacked=None below.
            y = self._apply_bias(self.token_lora_indices, y, output_slices,
                                 lora_bias_stacked)

        if buffer is None:
            r = lora_b_stacked[0].size(-1)
            # We set the buffer to be float32 by default, consistent with the
            # triton op
            buffer = tuple(
                torch.zeros(
                    (x.size(0), r), dtype=torch.float32, device=x.device)
                for _ in range(len(output_slices)))
        self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
        self.add_expand(y,
                        buffer,
                        lora_b_stacked,
                        None,
                        output_slices,
                        add_inputs=True,
                        **kwargs)

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.
        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked
        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Default to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = lora_b_stacked.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, consistent with the
            # triton op
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)
        # LogitsProcessorWithLoRA always using bgmv.
        bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
        bgmv_expand(buffer,
                    lora_b_stacked,
                    y,
                    self.sampler_indices,
                    add_inputs=True)
        y = y.view_as(y_org)

View File

@@ -0,0 +1,290 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import TYPE_CHECKING, Optional, Union, final
import torch
import vllm.envs as envs
from vllm.lora.layers import LoRAMapping
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand,
lora_shrink)
from .punica_base import PunicaWrapperBase
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.models import LongContextLoRAContext
@final
class PunicaWrapperGPU(PunicaWrapperBase):
    """
    PunicaWrapperGPU is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the punica triton kernel.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)

        self.max_loras = kwargs['max_loras']
        self.token_mapping_meta = LoRAKernelMeta.make(self.max_loras,
                                                      max_num_batched_tokens,
                                                      device=device)

        # When cudagraph capture size is greater than max_num_seqs (max_batches,
        # here), V0 captures the graph as if max_num_seqs is set to
        # the capture size.
        # V1 doesn't have this problem and always respects max_num_seqs.
        max_num_prompts = (max_batches
                           if envs.VLLM_USE_V1 else max_num_batched_tokens)
        self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras,
                                                       max_num_prompts,
                                                       device=device)

    def update_metadata(
            self,
            mapping: LoRAMapping,
            lora_index_to_id: list[Optional[int]],
            max_loras: int,
            vocab_size: int,
            extra_vocab_size: int,
            long_lora_context: Optional["LongContextLoRAContext"] = None,
            **kwargs):
        """Refresh per-batch LoRA index state and the triton kernel metadata
        tensors from the given mapping."""
        self.is_prefill = mapping.is_prefill
        self._update_base_metadata(mapping, lora_index_to_id, max_loras,
                                   vocab_size, extra_vocab_size,
                                   long_lora_context)

        # Prepare cuda kernel metadata tensors
        self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
        self.prompt_mapping_meta.prepare_tensors(self.sampler_indices)

    def add_shrink(self, y: torch.Tensor, x: torch.Tensor,
                   lora_a_stacked: tuple[torch.Tensor,
                                         ...], scale: float, **kwargs):
        """
        Performs GEMM for multiple slices of lora_a.
        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (x @ lora_a_stacked[i]) * scale
        Args:
            y (torch.Tensor): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """
        x = x.view(-1, x.shape[-1])
        # Single fused triton launch over all slices.
        lora_shrink(
            x,
            lora_a_stacked,
            y,
            *self.token_mapping_meta.meta_args(x.size(0)),
            scale,
        )

    def add_expand(self,
                   y: torch.Tensor,
                   x: torch.Tensor,
                   lora_b_stacked: tuple[torch.Tensor, ...],
                   lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                   output_slices: tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs=True,
                   **kwargs) -> None:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.
        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
                    lora_bias_stacked[i]
                offset += slice
        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensors
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
                bias's weight
            output_slices (tuple[int, ...]): Every slice's size
            offset_start (int): Starting column offset within y. Defaults to 0.
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        if lora_bias_stacked is not None:
            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0,
                                              y.size(0))
            self._apply_bias(token_lora_indices, y, output_slices,
                             lora_bias_stacked)

        assert x.ndim == 3
        assert x.size(0) == len(output_slices)
        num_tokens = x.size(1)  # first dimension is the num slices
        lora_expand(
            x,
            lora_b_stacked,
            y,
            *self.token_mapping_meta.meta_args(num_tokens),
            offset_start=offset_start,
            # Fix: honor the caller-supplied flag instead of hard-coding True,
            # consistent with the CPU/TPU implementations. Behavior is
            # unchanged for existing callers, which all pass add_inputs=True
            # (also the default).
            add_inputs=add_inputs,
        )
        y = y.view_as(y_org)

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.
        Semantics:
            y += x @ lora_b_stacked
        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Default to True.
        """
        # Embedding needs only the expand op; wrap the single weight in a
        # one-element tuple and add a leading slice dim to x.
        lora_expand(
            x.unsqueeze(dim=0),
            (lora_b_stacked, ),
            y,
            *self.token_mapping_meta.meta_args(x.size(0)),
            offset_start=0,
            add_inputs=add_inputs,
        )

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: tuple[torch.Tensor, ...],
                        lora_b_stacked: tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: tuple[int, ...],
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """
        Applicable to linear-related lora.
        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                    ).squeeze(0)+lora_bias_stacked[i]
        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (tuple[int, ...]): Every slice's size.
            buffer (Optional[torch.Tensor]): Defaults to None.
        """
        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        if lora_bias_stacked is not None:
            assert len(lora_bias_stacked) == len(output_slices)
            # Bias is applied here, so add_expand is called with
            # lora_bias_stacked=None below.
            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0,
                                              y.size(0))
            y = self._apply_bias(token_lora_indices, y, output_slices,
                                 lora_bias_stacked)

        if buffer is None:
            r = lora_b_stacked[0].size(-1)
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros(  # type: ignore
                (len(output_slices), x.size(0), r),
                dtype=torch.float32,
                device=x.device,
            )
        self.add_shrink(
            buffer,  # type: ignore
            x,
            lora_a_stacked,
            scale,
            **kwargs)
        self.add_expand(
            y,
            buffer,  # type: ignore
            lora_b_stacked,
            None,
            output_slices,
            add_inputs=True,
            **kwargs)

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.
        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked
        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Default to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = lora_b_stacked.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)

        # Logits use the prompt (sampler) mapping, not the token mapping.
        lora_shrink(x, [lora_a_stacked], buffer.unsqueeze(dim=0),
                    *self.prompt_mapping_meta.meta_args(x.size(0)), scale)

        lora_expand(buffer.unsqueeze(dim=0), [lora_b_stacked],
                    y,
                    *self.prompt_mapping_meta.meta_args(buffer.size(0)),
                    add_inputs=True)
        y = y.view_as(y_org)

View File

@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional, Union, final
import torch
from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
dispatch_bgmv_linear)
from .punica_base import PunicaWrapperBase
from .utils import convert_mapping
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import LongContextLoRAContext
@final
class PunicaWrapperHPU(PunicaWrapperBase):
    """Punica wrapper for HPU: maintains Multi-LoRA index state and
    dispatches the LoRA matmuls to vllm_hpu_extension's bgmv ops."""

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        # Increasing max_num_batched_tokens by 3x to handle increase in
        # tensor size due to padding.
        PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
                                   max_batches, device)

    def _update_base_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: list[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):
        """Convert the LoRA mapping into index tensors and copy them into the
        pre-allocated buffers. Overrides the base to build long-LoRA offsets
        lazily (see the comment below)."""
        (
            base_indices,
            sampler_indices,
            sampler_indices_padded,
            embeddings_indices,
            long_lora_offsets_tensor,
            indices_len,
        ) = convert_mapping(mapping, lora_index_to_id, max_loras, vocab_size,
                            extra_vocab_size, self.device, None)
        # Updating each element in `long_lora_offsets` with `lora_offset` slows
        # down perf in HPU due to a series of `strided_insert` ops during lazy
        # graph accumulation. Hence HPU appends `lora_offset` to a list and
        # converts it to a tensor only after it is ready.
        if long_lora_context:
            index_mapping_indices: list[int] = list(
                mapping.index_mapping).copy()
            long_lora_offsets: list[int] = []
            for i in range(len(index_mapping_indices)):
                lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                    index_mapping_indices[i], 0)
                long_lora_offsets.append(lora_offset)
            long_lora_offsets_tensor = torch.tensor(long_lora_offsets,
                                                    device=self.device,
                                                    dtype=torch.long)
            indices_len[-1] = long_lora_offsets_tensor.shape[-1]

        # Copy into the persistent buffers; only the leading portion that the
        # converted tensors cover is overwritten.
        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded)
        self._embeddings_indices[:embeddings_indices.
                                 shape[0], :embeddings_indices.shape[1]].copy_(
                                     embeddings_indices)
        if long_lora_offsets_tensor is not None:
            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
                long_lora_offsets_tensor)
        else:
            self._long_lora_indices.zero_()
        self.indices_len[:] = indices_len

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> None:
        """Embedding LoRA: expand-only, dispatched to the HPU bgmv op.
        NOTE(review): add_inputs is not forwarded to the op here — presumably
        the HPU op always accumulates; confirm against vllm_hpu_extension."""
        dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: tuple[torch.Tensor, ...],
                        lora_b_stacked: tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: tuple[int, ...],
                        *,
                        buffer: Optional[tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> None:
        """Linear LoRA: per output slice, apply the fused shrink+expand HPU op
        into the corresponding columns of y.
        `lora_bias_stacked` and `buffer` are accepted for interface
        compatibility but are not used by this backend."""
        y_org = y
        x = x.view(-1, x.shape[-1])
        y = y.view(-1, y.shape[-1])
        offset_left = 0

        for slice_idx in range(len(output_slices)):
            dispatch_bgmv_linear(
                y[:, offset_left:offset_left + output_slices[slice_idx]], x,
                lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale)
            offset_left += output_slices[slice_idx]
        y = y.view_as(y_org)

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> None:
        """Logits LoRA: single fused shrink+expand on the flattened tensors.
        `buffer` is accepted for interface compatibility but is not used."""
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale)
        y = y.view_as(y_org)

    def add_shrink(
        self,
        y: Union[tuple[torch.Tensor, ...], torch.Tensor],
        x: torch.Tensor,
        lora_a_stacked: tuple[torch.Tensor, ...],
        scale: float,
        **kwargs,
    ) -> None:
        # Not implemented for HPU: add_lora_linear dispatches the fused op
        # directly instead of going through separate shrink/expand phases.
        raise NotImplementedError

    def add_expand(
        self,
        y: torch.Tensor,
        x: Union[tuple[torch.Tensor, ...], torch.Tensor],
        lora_b_stacked: tuple[torch.Tensor, ...],
        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
        output_slices: tuple[int, ...],
        offset_start: int = 0,
        add_inputs=True,
        **kwargs,
    ) -> None:
        # Not implemented for HPU: see add_shrink above.
        raise NotImplementedError

View File

@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import resolve_obj_by_qualname
from .punica_base import PunicaWrapperBase
logger = init_logger(__name__)
def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
    """Instantiate the platform-specific Punica wrapper.

    The concrete wrapper class is resolved from the qualified name reported
    by the current platform; all positional and keyword arguments are
    forwarded unchanged to its constructor.
    """
    punica_wrapper_qualname = current_platform.get_punica_wrapper()
    wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname)
    wrapper = wrapper_cls(*args, **kwargs)
    assert wrapper is not None, \
        "the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong."
    logger.info_once("Using %s.", punica_wrapper_qualname.rsplit(".", 1)[1])
    return wrapper

View File

@@ -0,0 +1,405 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import TYPE_CHECKING, Optional, Union
import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from vllm.lora.punica_wrapper.utils import convert_mapping
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import LongContextLoRAContext
from .punica_base import PunicaWrapperBase
class PunicaWrapperTPU(PunicaWrapperBase):
    """
    PunicaWrapperTPU is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the pytorch punica ops.

    Unlike the CPU/GPU wrappers, the per-op helpers here are functional:
    they return result tensors rather than mutating `y` in place, to fit
    XLA's functional execution model.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str], **kwargs):
        PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches,
                                   device)

        # PunicaWrapperBase defines some tensors with dtype=torch.int64, which
        # isn't supported by the TPU. So convert those tensors to int32.
        # Not all of them are used by the TPU so only convert the useful ones.
        self._token_lora_indices = self._token_lora_indices.to(
            dtype=torch.int32)
        self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
        self._sampler_indices_padded = self._sampler_indices_padded.to(
            dtype=torch.int32)

        # NOTE(review): buffer donation — presumably allows XLA to reuse the
        # buffers' memory for in-place style updates; confirm against
        # torch_xla docs.
        torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True)
        torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True)
        torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded,
                                               True)
        torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
        torch.ops.xla.dynamo_set_buffer_donor_(self._long_lora_indices, True)
        torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch,
                                               True)

        # Mark batch-size-dependent dims as dynamic to avoid recompilation
        # for every new batch size.
        torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
        torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
        torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)

    def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
        """Return the first x.size(0) token→LoRA indices (one per token)."""
        return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))

    @property
    def embeddings_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for lora embeddings,
        specifically for VocabParallelEmbeddingWithLoRA.
        """
        return self._embeddings_indices[:]

    @property
    def sampler_indices_padded(self) -> torch.Tensor:
        """
        This property provides access to padded sampler indices.
        """
        return self._sampler_indices_padded[:]

    def shrink(
        self,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """Functional shrink: returns (x @ w_t_all) * scale per token."""
        return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale)

    def expand(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor,
               add_inputs: bool):
        """Functional expand: returns y (+)= x @ w_t_all."""
        return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x),
                           add_inputs)

    def expand_slice(self, y: torch.Tensor, x: torch.Tensor,
                     w_t_all: torch.Tensor, y_offset: int, y_slice_size: int,
                     add_inputs: bool) -> torch.Tensor:
        """Functional expand into columns [y_offset, y_offset+y_slice_size)."""
        return bgmv_expand_slice(x, w_t_all, y,
                                 self._get_token_lora_indices(x), y_offset,
                                 y_slice_size, add_inputs)

    def add_shrink(self, y: Union[tuple[torch.Tensor, ...], torch.Tensor],
                   x: torch.Tensor, lora_a_stacked: tuple[torch.Tensor, ...],
                   scale: float, **kwargs) -> Optional[torch.Tensor]:
        """
        Performs GEMM for multiple slices of lora_a.
        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (x @ lora_a_stacked[i]) * scale
        Args:
            y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
            x (torch.Tensor): Input tensor
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
            scale (float): Scaling factor for the operation
        """
        torch.ops.xla.dynamo_set_buffer_donor_(y, True)

        x = x.view(-1, x.shape[-1])
        for slice_idx in range(len(lora_a_stacked)):
            lora_s = lora_a_stacked[slice_idx]
            y_s = self.shrink(x, lora_s, scale)
            y[slice_idx, :, :] = y_s  # type: ignore[index]
        return y

    def add_expand(self,
                   y: torch.Tensor,
                   x: Union[tuple[torch.Tensor, ...], torch.Tensor],
                   lora_b_stacked: tuple[torch.Tensor, ...],
                   lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                   output_slices: tuple[int, ...],
                   offset_start: int = 0,
                   add_inputs=True,
                   **kwargs) -> torch.Tensor:
        """
        Performs GEMM and bias addition for multiple slices of lora_b.
        Semantics:
            for i in range(len(lora_b_stacked)):
                slice = output_slices[i]
                y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
                    lora_bias_stacked[i]
                offset += slice
        Args:
            y (torch.Tensor): Output tensor.
            x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]):
                bias's weight
            output_slices (tuple[int, ...]): Every slice's size
            add_inputs (bool): Defaults to True.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        # NOTE(review): offset_start is accepted but offset_left starts at 0
        # here — unlike the CPU wrapper, which starts at offset_start; confirm
        # no TPU caller passes a non-zero offset_start.
        offset_left = 0

        if lora_bias_stacked is not None:
            y = self._apply_bias(self._get_token_lora_indices(y), y,
                                 output_slices, lora_bias_stacked)
        for slice_idx in range(len(lora_b_stacked)):
            y = self.expand_slice(y,
                                  x[slice_idx],
                                  lora_b_stacked[slice_idx],
                                  offset_left,
                                  output_slices[slice_idx],
                                  add_inputs=add_inputs)
            offset_left += output_slices[slice_idx]
        return y.view_as(y_org)

    def add_lora_embedding(self,
                           y: torch.Tensor,
                           x: torch.Tensor,
                           lora_b_stacked: torch.Tensor,
                           add_inputs: bool = True,
                           **kwargs) -> torch.Tensor:
        """
        Applies lora specifically for VocabParallelEmbeddingWithLoRA.
        Semantics:
            y += x @ lora_b_stacked
        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            add_inputs (bool): Default to True.
        """
        # Embedding layer only needs the expand op
        return self.expand(y, x, lora_b_stacked, add_inputs)

    def add_lora_linear(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: tuple[torch.Tensor, ...],
                        lora_b_stacked: tuple[torch.Tensor, ...],
                        lora_bias_stacked: Optional[tuple[torch.Tensor, ...]],
                        scale: float,
                        output_slices: tuple[int, ...],
                        *,
                        buffer: Optional[tuple[torch.Tensor, ...]] = None,
                        **kwargs) -> torch.Tensor:
        """
        Applicable to linear-related lora.
        Semantics:
            for i in range(len(lora_a_stacked)):
                y[i] += (
                    x[i].unsqueeze(0)
                    @ lora_a_stacked[indices[i], layer_idx, :, :]
                    @ lora_b_stacked[indices[i], layer_idx, :, :]
                    * scale
                    ).squeeze(0)+lora_bias_stacked[i]
        Args:
            y (torch.Tensor): Output tensor. Will not be changed in-place.
            x (torch.Tensor): Input tensor (T, E)
            lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
            lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
            lora_bias_stacked (Optional[tuple[torch.Tensor, ...]]): lora's bias.
            scale (float): Scaling factor.
            output_slices (tuple[int, ...]): Every slice's size.
            buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
        """
        assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
        if lora_bias_stacked is not None:
            assert len(lora_bias_stacked) == len(output_slices)
            y = self._apply_bias(self._get_token_lora_indices(y), y,
                                 output_slices, lora_bias_stacked)

        if buffer is None:
            r = lora_b_stacked[0].size(-1)
            T = x.size(0)
            # Buffer keeps x's dtype here (no float32 upcast as on CPU/GPU).
            buffer = torch.zeros(
                (len(output_slices), T, r),
                dtype=x.dtype,
                device=x.device,
            )
        buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
        return self.add_expand(y,
                               buffer,
                               lora_b_stacked,
                               None,
                               output_slices,
                               add_inputs=True,
                               **kwargs)

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        lora_a_stacked: torch.Tensor,
                        lora_b_stacked: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None,
                        **kwargs) -> torch.Tensor:
        """
        Applies lora specifically for LogitsProcessorWithLoRA.
        Semantics:
            buffer = (x @ lora_a_stacked) * scale
            y += buffer @ lora_b_stacked
        Args:
            y (torch.Tensor): Output tensor.
            x (torch.Tensor): Input tensor.
            lora_a_stacked (torch.Tensor): lora_a's weights.
            lora_b_stacked (torch.Tensor): lora_b's weights.
            scale (float): Scaling factor.
            buffer (Optional[torch.Tensor]): Default to None (ignored; the
                shrink op allocates its own result).
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])

        # Logits use the sampler (per-request) indices, not per-token ones.
        sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
        buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale)

        y = bgmv_expand(buffer,
                        lora_b_stacked,
                        y,
                        sampler_indices,
                        add_inputs=True)
        return y.view_as(y_org)

    def _apply_bias(
        self,
        indices: torch.Tensor,
        output: torch.Tensor,
        output_slices: tuple[int, ...],
        lora_bias_stacked: tuple[Optional[torch.Tensor], ...],
    ):
        """Applies bias to output

        Input shapes:
            lora_bias_stacked:      3 element tuple of (num_loras, output_dim)
            indices:           (batch_size)
            output:            (batch_size, q_slice_size + 2*kv_slice_size)
            output_slices:     n-1 element tuple of (slice_size...),
                            where n is number of slices
        """
        org_output = output
        output = output.view(-1, output.shape[-1])
        indices = indices.view(-1)

        offset_left = 0
        for slice_idx, slice in enumerate(output_slices):
            bias = lora_bias_stacked[slice_idx]
            if bias is not None:
                bias = bias.view(-1, bias.shape[-1])
                bias = bias[indices]
                # Tokens with no LoRA (index -1) get zero bias.
                bias = torch.where(indices[:, None] == -1, 0, bias)

                # Pad the slice's bias out to full output width so it can be
                # added to the whole output tensor.
                bias = F.pad(bias, (offset_left, output.shape[1] -
                                    (offset_left + slice), 0, 0))

                output += bias
            offset_left += slice

        return output.view_as(org_output)

    # This performs the same tensor ops as the base method, except it does them
    # on the CPU then transfers the results to the TPU
    def _update_base_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: list[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):
        # Make sure we don't accidentally collect outside operations
        xm.mark_step()

        # Pad the prompt mapping to avoid running into recompiles on the TPU
        # TODO: Should this happen inside mapping internally? If so how can we
        # avoid having backend specific LoRAMapping classes?
        mapping.prompt_mapping = self._pad_prompt_mapping(
            mapping.prompt_mapping)

        (
            base_indices,
            sampler_indices,
            sampler_indices_padded,
            embeddings_indices,
            long_lora_offsets_tensor,
            indices_len,
        ) = convert_mapping(
            mapping,
            lora_index_to_id,
            max_loras,
            vocab_size,
            extra_vocab_size,
            "cpu",
            long_lora_context,
        )
        # Pad every index tensor to its buffer's static shape, then move to
        # the TPU, so tensor shapes stay constant across batches.
        self._token_lora_indices = self._pad_to_shape(
            base_indices, self._token_lora_indices.shape,
            dims=1).to(self.device)
        self._sampler_indices = self._pad_to_shape(sampler_indices,
                                                   self._sampler_indices.shape,
                                                   dims=1).to(self.device)
        self._sampler_indices_padded = self._pad_to_shape(
            sampler_indices_padded, self._sampler_indices_padded.shape,
            dims=1).to(self.device)
        self._embeddings_indices = self._pad_to_shape(
            embeddings_indices, self._embeddings_indices.shape,
            dims=2).to(self.device)
        if long_lora_offsets_tensor is not None:
            self._long_lora_indices = self._pad_to_shape(
                long_lora_offsets_tensor,
                self._long_lora_indices.shape,
                dims=1).to(self.device)
        else:
            zeroed = torch.zeros_like(self._long_lora_indices.cpu(),
                                      dtype=torch.int32)
            self._long_lora_indices = zeroed.to(self.device)
        self.indices_len[:] = indices_len

    def _update_prefill_metadata(self,
                                 token_lora_tensor: torch.Tensor) -> None:
        # TPU treats the whole prefill as a single batch.
        self.batch_size = 1
        self._lora_indices_per_batch[:self.
                                     batch_size] = token_lora_tensor[:self.
                                                                     batch_size]

    def _pad_prompt_mapping(
            self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
        """Pad the prompt mapping with -1 up to the next power of two
        (at least MIN_NUM_SEQS) to keep TPU shapes stable."""
        num_reqs = len(prompt_mapping)

        # From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
        # import
        MIN_NUM_SEQS = 8

        # NOTE(review): assumes num_reqs >= 1 — math.log2(0) would raise a
        # ValueError; confirm callers never pass an empty mapping.
        padded_num_reqs = max(2**math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
        pad_len = padded_num_reqs - num_reqs

        padding = [-1] * pad_len
        return tuple(list(prompt_mapping) + padding)

    def _pad_to_shape(self,
                      src: torch.Tensor,
                      target_shape: torch.Size,
                      dims: int = 1) -> torch.Tensor:
        """Zero-pad a 1-D or 2-D tensor up to target_shape and cast to the
        TPU-friendly int32 dtype."""
        if dims == 1:
            pad_len = target_shape[0] - src.shape[0]
            return F.pad(src, (0, pad_len), value=0).to(torch.int32)
        else:
            pad_rows = target_shape[0] - src.shape[0]
            pad_cols = target_shape[1] - src.shape[1]
            return F.pad(src, (0, pad_cols, 0, pad_rows),
                         value=0).to(torch.int32)

View File

@@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Optional, Union
import torch
if TYPE_CHECKING:
# avoid circular import
from vllm.lora.layers import LoRAMapping
from vllm.lora.models import LongContextLoRAContext
def compute_meta(
    token_lora_tensor: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
    """Build the per-segment metadata consumed by the sgmv kernel.

    Consecutive tokens that share a LoRA index are merged into a single
    segment, which improves sgmv kernel performance. This recomputation
    is only needed once at the beginning of each prefill stage.

    Returns ``(seg_starts, seg_lengths, seg_lora_ids, num_segments,
    longest_segment, total_tokens, no_lora)``.
    """
    seg_lora_ids, seg_lengths = torch.unique_consecutive(token_lora_tensor,
                                                         return_counts=True)
    # Segment i starts where the cumulative length of segments 0..i-1 ends.
    seg_starts = torch.zeros_like(seg_lengths)
    seg_starts[1:] = torch.cumsum(seg_lengths, dim=0)[:-1]
    longest_segment = seg_lengths.max().item()
    total_tokens = seg_lengths.sum().item()
    num_segments = seg_lora_ids.size(0)
    # A single segment with index -1 means no request in the batch uses
    # LoRA, so the caller can skip launching the triton kernel entirely.
    no_lora = bool(num_segments == 1 and seg_lora_ids == -1)
    return (seg_starts, seg_lengths, seg_lora_ids, num_segments,
            longest_segment, total_tokens, no_lora)
# TODO see if this can be vectorized
def convert_mapping(
    mapping: "LoRAMapping",
    lora_index_to_id: list[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    device: torch.device,
    long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           Optional[torch.Tensor], list[int]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids.
        lora_index_to_id: List mapping LoRA ids to LoRA indices.
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.
        device: Device to place the produced tensors on.
        long_lora_context: Passed if there are long context lora in a batch.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices.
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for sampler. For generation, this will be the
                same as base_indices. For prefill, this will map requests
                to LoRA indices.
            sampler_indices_padded: Tensor of shape [batch_size] mapping
                requests to LoRA indices for sampler with padding.
                Same as sampler_indices, but -1 is replaced with
                max_loras.
            embeddings_indices: Tensor of shape [2, batch_size] mapping
                requests to embedding indices. First row is for embeddings
                added by the LoRAs, second row is for the LoRA.lora_a
                embeddings.
            long_lora_indices: Tensor of shape [batch_size] mapping
                requests to RoPE offsets and rot dims for long LoRAs.
                None if long context lora doesn't exist.
            indices_len: List of lengths of the above tensors. It contains
                (base_indices, sampler_indices, sampler_indices_padded,
                embeddings_indices, long_lora_indices).
    """
    index_mapping_indices: list[int] = list(mapping.index_mapping).copy()
    embedding_indices = index_mapping_indices.copy()
    lora_indices = index_mapping_indices.copy()
    long_lora_offsets: Optional[torch.Tensor] = None
    if long_lora_context:
        long_lora_offsets = torch.zeros(len(index_mapping_indices),
                                        device=device,
                                        dtype=torch.long)
    # Translate per-request LoRA ids into slot indices; -1 marks "no LoRA"
    # (ids are positive; id <= 0 means no adapter for that request).
    prompt_mapping: list[int] = [
        lora_index_to_id.index(x) if x > 0 else -1
        for x in mapping.prompt_mapping
    ]
    lora_idx = None
    for i in range(len(index_mapping_indices)):
        # TODO index can be slow. optimize
        lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
                    if index_mapping_indices[i] > 0 else -1)
        # Tokens without a LoRA use embedding offset slot 0.
        embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
        lora_indices[i] = lora_idx
        if long_lora_context:
            assert long_lora_offsets is not None
            lora_offset: int = long_lora_context.offsets_by_lora_id.get(
                index_mapping_indices[i], 0)
            long_lora_offsets[i] = lora_offset

    indices_list: list[Union[list[int], torch.Tensor]] = [
        index_mapping_indices,
        lora_indices,
        embedding_indices,
    ]
    if long_lora_context:
        assert long_lora_offsets is not None
        indices_list.append(long_lora_offsets)
    indices = torch.tensor(indices_list, dtype=torch.long, device=device)
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         dtype=torch.long,
                                         device=device)
    # Row 0: per-token offsets into the LoRA-added (extra) vocab.
    # Row 1: per-token offsets into the lora_a embedding table.
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size),
    ])
    embeddings_indices = torch.where(embeddings_indices == -1, max_loras - 1,
                                     embeddings_indices)
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    # Replace -1 with a valid (dummy) slot, then convert to flattened
    # row-major positions so the sampler can gather with a single index.
    sampler_indices_padded = torch.where(sampler_indices_padded == -1,
                                         max_loras - 1, sampler_indices_padded)
    sampler_indices_padded = torch.arange(
        0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
            sampler_indices_padded * len(sampler_indices_padded))
    long_lora_indices = None
    long_lora_indices_len: Optional[int] = None
    if long_lora_context:
        long_lora_indices = indices[3]
        long_lora_indices_len = long_lora_indices.shape[-1]
    # Contains the length of each indices tensor. Used to index into each
    # tensor.
    indices_len = [
        base_indices.shape[-1],
        sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1],
        embeddings_indices.shape[-1],
    ]
    if long_lora_indices_len is not None:
        indices_len.append(long_lora_indices_len)
    else:
        # If long_lora doesn't exist, append None.
        indices_len.append(None)

    return (
        base_indices,
        sampler_indices,
        sampler_indices_padded,
        embeddings_indices,
        long_lora_indices,
        indices_len,
    )

99
lora/request.py Normal file
View File

@@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
from typing import Optional
import msgspec
from vllm.adapter_commons.request import AdapterRequest
class LoRARequest(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """
    Request for a LoRA adapter.

    Note that this class should be used internally. For online
    serving, it is recommended to not allow users to use this class but
    instead provide another layer of abstraction to prevent users from
    accessing unauthorized LoRA adapters.

    lora_int_id must be globally unique for a given adapter.
    This is currently not enforced in vLLM.
    """
    # NOTE(review): `__metaclass__` is a Python 2 construct and has no
    # effect on Python 3; it appears intended only to mark the
    # AdapterRequest interface — confirm whether it can be removed.
    __metaclass__ = AdapterRequest

    # Human-readable adapter name; used for equality and hashing below.
    lora_name: str
    # Globally-unique integer id of the adapter.
    lora_int_id: int
    # Local path or Hugging Face repo id of the adapter weights.
    lora_path: str = ""
    # Deprecated alias of lora_path; migrated in __post_init__.
    lora_local_path: Optional[str] = msgspec.field(default=None)
    long_lora_max_len: Optional[int] = None
    base_model_name: Optional[str] = msgspec.field(default=None)
    tensorizer_config_dict: Optional[dict] = None

    def __post_init__(self):
        # Migrate the deprecated `lora_local_path` into `lora_path`.
        if self.lora_local_path:
            warnings.warn(
                "The 'lora_local_path' attribute is deprecated "
                "and will be removed in a future version. "
                "Please use 'lora_path' instead.",
                DeprecationWarning,
                stacklevel=2)
            if not self.lora_path:
                self.lora_path = self.lora_local_path or ""

        # Ensure lora_path is not empty
        assert self.lora_path, "lora_path cannot be empty"

    @property
    def adapter_id(self):
        return self.lora_int_id

    @property
    def name(self):
        return self.lora_name

    @property
    def path(self):
        return self.lora_path

    @property
    def local_path(self):
        # Deprecated accessor kept for backward compatibility.
        warnings.warn(
            "The 'local_path' attribute is deprecated "
            "and will be removed in a future version. "
            "Please use 'path' instead.",
            DeprecationWarning,
            stacklevel=2)
        return self.lora_path

    @local_path.setter
    def local_path(self, value):
        # Deprecated mutator kept for backward compatibility.
        warnings.warn(
            "The 'local_path' attribute is deprecated "
            "and will be removed in a future version. "
            "Please use 'path' instead.",
            DeprecationWarning,
            stacklevel=2)
        self.lora_path = value

    def __eq__(self, value: object) -> bool:
        """
        Overrides the equality method to compare LoRARequest
        instances based on lora_name. This allows for identification
        and comparison of lora adapters across engines.
        """
        return isinstance(value,
                          self.__class__) and self.lora_name == value.lora_name

    def __hash__(self) -> int:
        """
        Overrides the hash method to hash LoRARequest instances
        based on lora_name. This ensures that LoRARequest instances
        can be used in hash-based collections such as sets and dictionaries,
        identified by their names across engines.
        """
        return hash(self.lora_name)

85
lora/resolver.py Normal file
View File

@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from dataclasses import dataclass, field
from typing import Optional
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
logger = init_logger(__name__)
class LoRAResolver(ABC):
    """Abstract interface for locating and fetching LoRA adapters.

    Concrete resolvers encapsulate the logic for finding and downloading
    an adapter from a particular source (e.g. S3 or other cloud/blob
    storage).
    """

    @abstractmethod
    async def resolve_lora(self, base_model_name: str,
                           lora_name: str) -> Optional[LoRARequest]:
        """Resolve and fetch the LoRA adapter named ``lora_name``.

        Implementations locate (and possibly download) the adapter based
        on its name, e.g. from blob storage.

        Args:
            base_model_name: The name/identifier of the base model.
            lora_name: The name/identifier of the LoRA model to resolve.

        Returns:
            The resolved LoRA model information, or ``None`` when the
            LoRA model cannot be found.
        """
        ...
@dataclass
class _LoRAResolverRegistry:
resolvers: dict[str, LoRAResolver] = field(default_factory=dict)
def get_supported_resolvers(self) -> Set[str]:
"""Get all registered resolver names."""
return self.resolvers.keys()
def register_resolver(
self,
resolver_name: str,
resolver: LoRAResolver,
) -> None:
"""Register a LoRA resolver.
Args:
resolver_name: Name to register the resolver under.
resolver: The LoRA resolver instance to register.
"""
if resolver_name in self.resolvers:
logger.warning(
"LoRA resolver %s is already registered, and will be "
"overwritten by the new resolver instance %s.", resolver_name,
resolver)
self.resolvers[resolver_name] = resolver
def get_resolver(self, resolver_name: str) -> LoRAResolver:
"""Get a registered resolver instance by name.
Args:
resolver_name: Name of the resolver to get.
Returns:
The resolver instance.
Raises:
KeyError: If the resolver is not found in the registry.
"""
if resolver_name not in self.resolvers:
raise KeyError(
f"LoRA resolver '{resolver_name}' not found. "
f"Available resolvers: {list(self.resolvers.keys())}")
return self.resolvers[resolver_name]
LoRAResolverRegistry = _LoRAResolverRegistry()

240
lora/utils.py Normal file
View File

@@ -0,0 +1,240 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import Optional, Union
import huggingface_hub
import regex as re
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
HFValidationError, RepositoryNotFoundError)
from torch import nn
from transformers import PretrainedConfig
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.fully_sharded_layers import (
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLoRA, QKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
LinearScalingRotaryEmbeddingWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA,
VocabParallelEmbeddingWithLoRA)
from vllm.model_executor.layers.linear import LinearBase
# yapf: enable
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__)
# All LoRA layer wrappers known to vLLM. `from_layer` probes these in
# turn and replaces a base layer with the first wrapper whose
# `can_replace_layer` accepts it.
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA,
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    LogitsProcessorWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    RowParallelLinearWithShardedLoRA,
    LinearScalingRotaryEmbeddingWithLoRA,
}
def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: list,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    """Wrap ``layer`` with its matching LoRA layer class, if any.

    Probes every registered LoRA wrapper class and returns an initialized
    wrapper for the first one that can replace ``layer``; returns
    ``layer`` unchanged when no class matches.
    """
    # kwargs are passed by name so the fully-sharded decorator can
    # inspect them.
    candidates = (cls for cls in _all_lora_classes
                  if cls.can_replace_layer(
                      source_layer=layer,
                      lora_config=lora_config,
                      packed_modules_list=packed_modules_list,
                      model_config=model_config))
    wrapper_cls = next(candidates, None)
    if wrapper_cls is None:
        # No LoRA support for this layer type: keep the base layer.
        return layer
    wrapped = wrapper_cls(layer)
    wrapped.create_lora_weights(max_loras, lora_config, model_config)
    return wrapped
def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Wrap a logits processor with LoRA support.

    Builds a ``LogitsProcessorWithLoRA`` around ``layer`` using the
    embedding geometry of ``lm_head`` and initializes its LoRA weights.
    """
    wrapped = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
                                      lm_head.weight.dtype,
                                      lm_head.weight.device,
                                      lm_head.get_sharded_to_full_mapping())
    wrapped.create_lora_weights(max_loras, lora_config, model_config)
    return wrapped
def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Replace the submodule at dotted path ``module_name`` and return
    the new module."""
    parent_path, _, child_name = module_name.rpartition(".")
    # get_submodule("") returns `model` itself for top-level names.
    parent = model.get_submodule(parent_path)
    setattr(parent, child_name, new_module)
    return new_module
def parse_fine_tuned_lora_name(
    name: str,
    weights_mapper: Optional[WeightsMapper] = None
) -> tuple[str, bool, bool]:
    """Parse the qualified name of a fine-tuned LoRA weight.

    Args:
        name: Name of the fine-tuned LoRA weight, e.g.
            ``base_model.model.dense1.lora_A.weight``.
        weights_mapper: Optional mapper renaming weights, e.g.
            ``model.`` -> ``language_model.model.``.

    Returns:
        A tuple ``(module_name, is_lora_a, is_bias)`` where
        ``module_name`` is the module path (e.g. ``model.dense1``),
        ``is_lora_a`` is True for a lora_A tensor and False for lora_B,
        and ``is_bias`` is True for a LoRA bias tensor.

    Raises:
        ValueError: If ``name`` does not match any known LoRA weight
            naming scheme.
    """
    # LoRA weight qualified names usually start with `base_model.model.`;
    # strip that prefix so the weights mapper sees the bare module path,
    # then restore it. Use removeprefix (not str.replace) so only the
    # LEADING occurrence is stripped: replace() would also corrupt any
    # later occurrence of the substring inside the module path.
    if name.startswith("base_model.model."):
        name = name.removeprefix("base_model.model.")
        name = weights_mapper._map_name(name) if weights_mapper else name
        # recover the prefix `base_model.model.`
        name = "base_model.model." + name
    else:
        name = weights_mapper._map_name(name) if weights_mapper else name

    # In some situations, we may not start with `base_model.model.`
    # (e.g., ibm-granite/granite-speech-3.3-8b); keep the prefix intact.
    start_index = 2 if name.startswith("base_model.model.") else 0

    parts = name.split(".")
    if parts[-1] == "weight" and parts[-2] in ("lora_A", "lora_B"):
        new_name = ".".join(parts[start_index:-2])
        return new_name, parts[-2] == "lora_A", False

    if parts[-1] in ("lora_embedding_A", "lora_embedding_B"):
        new_name = ".".join(parts[start_index:-1])
        return new_name, parts[-1] == "lora_embedding_A", False

    if parts[-1] == "bias":
        new_name = ".".join(parts[start_index:-2])
        return new_name, False, True

    raise ValueError(f"{name} is unsupported LoRA weight")
def is_regex_target_modules(load_modules: Union[str, list[str]],
                            expected_lora_modules: list[str]) -> bool:
    """
    PEFT allows `target_modules` to be a regular expression such as
    `model.*(q_proj|k_proj|v_proj)$`. Decide whether every suffix named
    in the expression's trailing group is one of `expected_lora_modules`.
    """
    # Mirroring PEFT's own logic, regex handling applies only when
    # load_modules is a string.
    if not isinstance(load_modules, str):
        return False

    try:
        re.compile(load_modules)
    except re.error:
        # Not a valid regular expression: nothing to match against.
        return False

    group = re.search(r"\((.*?)\)\$?$", load_modules)
    if group is None:
        return False
    suffixes = group.group(1).split("|")
    return set(suffixes) <= set(expected_lora_modules)
def get_supported_lora_modules(model: nn.Module) -> list[str]:
    """
    In vLLM, all linear layers support LoRA.
    """
    supported: set[str] = set()
    # Step 1: traverse the model and collect the trailing name component
    # of every linear layer.
    for module_name, module in model.named_modules():
        if isinstance(module, LinearBase):
            supported.add(module_name.rsplit(".", 1)[-1])
    # Step 2: include the model's embedding modules when declared.
    for embedding_name in model.embedding_modules or ():
        supported.add(embedding_name)
    return list(supported)
def get_adapter_absolute_path(lora_path: str) -> str:
    """
    Resolve ``lora_path`` to an absolute local path.

    A path that is absolute, ``~``-prefixed, or exists locally is
    resolved directly. Anything else is treated as a Hugging Face model
    identifier and downloaded; if the download fails, the original
    string is returned unchanged (best effort).

    Parameters:
        lora_path (str): Absolute path, relative path, or Hugging Face
            model identifier.

    Returns:
        str: The resolved absolute local path to the lora model.
    """
    # Absolute paths are returned as-is, whether or not they exist.
    if os.path.isabs(lora_path):
        return lora_path

    # Expand a leading ~ to the user's home directory.
    if lora_path.startswith('~'):
        return os.path.expanduser(lora_path)

    # A relative path that exists locally wins over a HF repo id.
    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    # Otherwise assume a Hugging Face repo id and try to download it.
    try:
        return huggingface_hub.snapshot_download(repo_id=lora_path)
    except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
            HFValidationError):
        # Handle errors that may occur during the download:
        # return the original path instead of raising here.
        logger.exception("Error downloading the HuggingFace model")
        return lora_path

259
lora/worker_manager.py Normal file
View File

@@ -0,0 +1,259 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import Any, Literal, Optional, Union
import torch
from vllm.adapter_commons.utils import (add_adapter_worker,
apply_adapters_worker,
list_adapters_worker,
set_active_adapters_worker)
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.models import (LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
logger = init_logger(__name__)
class WorkerLoRAManager(AbstractWorkerManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are
    already loaded), and every other LoRA will be unloaded."""

    # Manager class instantiated by create_lora_manager; subclasses may
    # override (see LRUCacheWorkerLoRAManager below).
    _manager_cls: type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: dict[str, str],
        embedding_padding_modules: list[str],
        lora_model_cls: type[LoRAModel] = LoRAModel,
        max_position_embeddings: Optional[int] = None,
    ):
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        # Tri-state: False = dummy-LoRA caching disabled; None = caching
        # enabled but nothing cached yet; LoRAModel = the cached dummy
        # (see dummy_lora_cache / add_dummy_lora).
        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.lora_config = lora_config
        self.max_position_embeddings = max_position_embeddings
        super().__init__(device)
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Use this context manager to reuse the dummy lora model
        to avoid creating it repeatedly."""
        self._cached_dummy_lora = None
        yield
        self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Create the adapter manager for ``model`` and return the model
        it manages."""
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        """Load the adapter described by ``lora_request`` from a local
        path or Hugging Face.

        Raises:
            ValueError: If no adapter files can be found, or the loaded
                adapter exceeds the configured extra vocab size.
        """
        try:
            # Expand packed modules (e.g. qkv_proj) into the underlying
            # module names the checkpoint is expected to contain.
            supported_lora_modules = (
                self._adapter_manager.supported_lora_modules)
            packed_modules_mapping = (
                self._adapter_manager.packed_modules_mapping)
            expected_lora_modules: list[str] = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)

            expected_lora_modules = list(set(expected_lora_modules))

            lora_path = get_adapter_absolute_path(lora_request.lora_path)

            peft_helper = PEFTHelper.from_local_dir(
                lora_path, self.max_position_embeddings,
                lora_request.tensorizer_config_dict)

            # Validates the LoRA configuration against requirements before
            # loading weights, throwing an exception if validation fails.
            peft_helper.validate_legal(self.lora_config)

            # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
            # to ensure correct loading of lora weights.
            model = self._adapter_manager.model
            hf_to_vllm_mapper = None
            if (hasattr(model, "hf_to_vllm_mapper")
                    and model.hf_to_vllm_mapper is not None):
                hf_to_vllm_mapper = model.hf_to_vllm_mapper

            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                peft_helper=peft_helper,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
                tensorizer_config_dict=lora_request.tensorizer_config_dict,
                weights_mapper=hf_to_vllm_mapper)

        except FileNotFoundError as e:
            # FileNotFoundError should be raised if both
            #  - No adapter found to download from huggingface (or in
            #    offline mode)
            #  - No local adapter files found at `lora_request.lora_path`
            # For NotFoundError
            raise ValueError(
                f"Loading lora {lora_request.lora_name} failed: No adapter "
                f"found for {lora_request.lora_path}") from e
        except Exception as e:
            # For BadRequestError
            raise e

        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "
                             f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Register a rank-``rank`` dummy adapter (used for profiling);
        reuses the cached dummy while dummy_lora_cache() is active."""
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(
                lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, 1, self.embedding_modules)
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: set[Any],
                            mapping: Optional[Any]) -> None:
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   self._adapter_manager.set_adapter_mapping)

    def _apply_adapters(self, adapter_requests: set[Any]) -> None:
        # Load the requested adapters and unload everything else
        # (slot-limited), via the shared worker helper.
        apply_adapters_worker(adapter_requests, self.list_adapters,
                              self._adapter_manager.adapter_slots,
                              self.remove_adapter, self.add_adapter)

    def add_adapter(self, adapter_request: Any) -> bool:
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter,
                                  self._adapter_manager.add_adapter,
                                  self._adapter_manager.activate_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> set[int]:
        return list_adapters_worker(self._adapter_manager.list_adapters)
class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU Cache. Every request, the requested LoRAs will be loaded
    (unless they are already loaded) and least recently used LoRAs will
    be unloaded if the cache is above capacity."""

    _manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Create the LRU-cache adapter manager for ``model`` and return
        the model it manages."""
        lora_manager = create_lora_manager(
            model,
            lora_manager_cls=self._manager_cls,
            max_num_seqs=self.max_num_seqs,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            max_num_batched_tokens=self.max_num_batched_tokens,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None:
        # Unlike the base class, adapters absent from `lora_requests`
        # are not unloaded eagerly; they are only evicted by LRU pressure.
        loras_map = {
            lora_request.lora_int_id: lora_request
            for lora_request in lora_requests if lora_request
        }
        if len(loras_map) > self._adapter_manager.lora_slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(loras_map)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._adapter_manager.lora_slots}).")
        for lora in loras_map.values():
            self.add_adapter(lora)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        # Note that this method is not thread-safe. It may be invoked multiple
        # times for the same adapter when using multiple API servers.
        # This is ok because it's currently only called from
        # the single-threaded core engine loop.
        if lora_request.lora_int_id not in self.list_adapters():
            # Load the new adapter first to ensure it is actually valid, before
            # evicting any existing adapters.
            # This may cause the # of loaded lora adapters to very temporarily
            # exceed `--max-cpu-loras`.
            lora = self._load_adapter(lora_request)

            # Loading succeeded, now check if we will exceed cache capacity
            # and evict the oldest adapter if so
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                assert isinstance(self._adapter_manager,
                                  LRUCacheLoRAModelManager)
                self._adapter_manager.remove_oldest_adapter()
            # Then add the new adapter to the cache
            loaded = self._adapter_manager.add_adapter(lora)
        else:
            # If the lora is already loaded, just touch it to
            # update its position in the caches
            loaded = self._adapter_manager.get_adapter(
                lora_request.lora_int_id) is not None
        self._adapter_manager.activate_adapter(lora_request.lora_int_id)
        return loaded