This commit is contained in:
root
2026-03-05 18:06:10 +08:00
commit 809cecae09
2569 changed files with 478204 additions and 0 deletions

0
lora/__init__.py Normal file
View File

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

41
lora/layers/__init__.py Normal file
View File

@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.layers.column_parallel_linear import (
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithLoRA,
MergedQKVParallelLinearWithShardedLoRA,
QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import (
RowParallelLinearWithLoRA,
RowParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.utils import LoRAMapping
from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA
__all__ = [
"BaseLayerWithLoRA",
"VocabParallelEmbeddingWithLoRA",
"LogitsProcessorWithLoRA",
"ColumnParallelLinearWithLoRA",
"ColumnParallelLinearWithShardedLoRA",
"MergedColumnParallelLinearWithLoRA",
"MergedColumnParallelLinearWithShardedLoRA",
"MergedQKVParallelLinearWithLoRA",
"MergedQKVParallelLinearWithShardedLoRA",
"QKVParallelLinearWithLoRA",
"QKVParallelLinearWithShardedLoRA",
"RowParallelLinearWithLoRA",
"RowParallelLinearWithShardedLoRA",
"ReplicatedLinearWithLoRA",
"LoRAMapping",
"FusedMoEWithLoRA",
]

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

67
lora/layers/base.py Normal file
View File

@@ -0,0 +1,67 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
if TYPE_CHECKING:
from vllm.lora.punica_wrapper import PunicaWrapperBase
class BaseLayerWithLoRA(nn.Module):
def slice_lora_a(
self, lora_a: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
"""Slice lora a if splitting for tensor parallelism."""
...
def slice_lora_b(
self, lora_b: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
"""Slice lora b if splitting with tensor parallelism."""
...
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
...
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
...
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
):
"""Overwrites lora tensors at index."""
...
def set_mapping(
self,
punica_wrapper,
):
self.punica_wrapper: PunicaWrapperBase = punica_wrapper
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
raise NotImplementedError

164
lora/layers/base_linear.py Normal file
View File

@@ -0,0 +1,164 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed.utils import divide
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearBase,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.platforms import current_platform
from .base import BaseLayerWithLoRA
from .utils import _get_lora_device
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: LinearBase):
super().__init__()
self.base_layer = base_layer
self.input_size = self.base_layer.input_size
# Ensure tp_size and tp_rank consistency with the base_layer.
self.tp_size = self.base_layer.tp_size
self.tp_rank = self.base_layer.tp_rank
self.device = _get_lora_device(self.base_layer)
self.output_slices: tuple[int, ...]
self.output_size: int
self.n_slices: int
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
self.lora_config = lora_config
#
if isinstance(self.base_layer, ReplicatedLinear):
lora_a_out_size = lora_config.max_lora_rank
lora_b_out_size = self.output_size
elif isinstance(self.base_layer, ColumnParallelLinear):
lora_a_out_size = (
lora_config.max_lora_rank
if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size)
)
lora_b_out_size = self.output_size
elif isinstance(self.base_layer, RowParallelLinear):
lora_a_out_size = lora_config.max_lora_rank
lora_b_out_size = (
self.output_size
if not lora_config.fully_sharded_loras
else divide(self.output_size, self.tp_size)
)
else:
raise NotImplementedError
self.lora_a_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_a_out_size,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self.n_slices)
)
self.lora_b_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_b_out_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self.n_slices)
)
self.output_slices = (self.lora_b_stacked[0].shape[2],)
def reset_lora(self, index: int):
for s_index in range(self.n_slices):
self.lora_a_stacked[s_index][index] = 0
self.lora_b_stacked[s_index][index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
):
# Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
# store weights in a tuple of size 1. These two layers will
# override this function.
assert (
len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
)
self.reset_lora(index)
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
lora_a, non_blocking=True
)
self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
lora_b, non_blocking=True
)
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
# In Transformers modeling backend, x and output have extra batch dimension like
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
# therefore we need to flatten the batch dimensions.
if x.ndim == 3 and output.ndim == 3:
output = output.flatten(0, 1)
x = x.flatten(0, 1)
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear(
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
)
if not current_platform.can_update_inplace():
output = lora_output
return output
@property
def weight(self) -> torch.Tensor:
# unquantizedLinear
if hasattr(self.base_layer, "weight"):
return self.base_layer.weight
# Compressed Tensor
elif hasattr(self.base_layer, "weight_packed"):
return self.base_layer.weight_packed
# GPTQ/AWQ
elif hasattr(self.base_layer, "qweight"):
return self.base_layer.qweight
# marlin
elif hasattr(self.base_layer, "B"):
return self.base_layer.B
# HQQ marlin
elif hasattr(self.base_layer, "W_q"):
return self.base_layer.W_q
else:
raise ValueError(f"Unsupported base layer: {self.base_layer}")
@property
def bias(self) -> torch.Tensor | None:
if hasattr(self.base_layer, "bias"):
return self.base_layer.bias
else:
return None

View File

@@ -0,0 +1,578 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.utils import divide
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
)
from vllm.platforms import current_platform
from .base_linear import BaseLinearLayerWithLoRA
from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace
def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
"""
For `ColumnParallelLinearWithLoRA` or classes that inherit from
`ColumnParallelLinearWithLoRA`, they share the same `apply` logic.
"""
assert (
layer.n_slices
== len(layer.lora_a_stacked)
== len(layer.lora_b_stacked)
== len(layer.output_slices)
)
output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
# Since communication is needed, the buffer is directly initialized as a
# tensor rather than a tuple of tensor.
buffers = torch.zeros(
(layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
shrunk_buffers: torch.Tensor | None = layer.punica_wrapper.add_shrink(
buffers, x, layer.lora_a_stacked, 1.0
)
if not current_platform.can_update_inplace():
buffers = shrunk_buffers
buffers = tensor_model_parallel_all_gather(buffers)
lora_output: torch.Tensor | None = layer.punica_wrapper.add_expand(
output,
buffers,
layer.lora_b_stacked,
layer.output_slices,
offset_start=0,
add_input=True,
)
if not current_platform.can_update_inplace():
output = lora_output
output = output.view(*out_orig_shape)
# now have column partitioned and packed output
return output
class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
"""
LoRA on top of ColumnParallelLinear layer.
LoRA B is sliced for tensor parallelism.
There are two types for the `base_layer`:
1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`.
2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`.
"""
def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__(base_layer)
# The base_layer type is ColumnParallelLinear or
# MergedColumnParallelLinear, their weight sharding logic is
# inconsistent when TP is greater than 1.
self.is_merged_col_linear = type(base_layer) is MergedColumnParallelLinear
self.output_size = self.base_layer.output_size_per_partition
# There is only one LoRA layer
self.n_slices = 1
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
# Applicable to cases where the base_layer is
# MergedColumnParallelLinear.
if self.is_merged_col_linear:
shard_size = self.output_size // 2
offset = lora_b.shape[0] // 2
left_weight = lora_b[
self.tp_rank * shard_size : (self.tp_rank + 1) * shard_size, :
]
right_weight = lora_b[
offset + self.tp_rank * shard_size : offset
+ (self.tp_rank + 1) * shard_size,
:,
]
lora_b = torch.cat([left_weight, right_weight], dim=0)
# Applicable to cases where the base_layer is
# ColumnParallelLinear.
else:
shard_size = self.output_size
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[start_idx:end_idx, :]
return lora_b
def forward(
self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
"""Forward of ColumnParallelLinear
Args:
input_: Tensor whose last dimension is `input_size`.
Returns:
- output
- bias
"""
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
# Matrix multiply.
output_parallel = self.apply(input_, bias)
if self.base_layer.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
if not self.base_layer.return_bias:
return output
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
return output, output_bias
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is ColumnParallelLinear or (
type(source_layer) is MergedColumnParallelLinear
and len(packed_modules_list) == 1
)
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
packed together (e.g. gate_proj + up_proj -> gate_up_proj).
This means we have 2 LoRAs, each applied to one half of the layer.
Both slices must have the same size.
"""
def __init__(
self, base_layer: MergedColumnParallelLinear | QKVParallelLinear
) -> None:
super().__init__(base_layer)
# There are two LoRA layers
# the output_sizes in MergedColumnParallelLinear is not sharded by tp
# we need to divide it by the tp_size to get correct slices size
output_sizes = self.base_layer.output_sizes
self.output_slices = tuple(
divide(output_size, self.tp_size) for output_size in output_sizes
)
self.n_slices = len(self.output_slices)
self.output_ids = (self.tp_rank,) * self.n_slices
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""
The main reason for overriding this function is to enhance code
maintainability.
"""
self.lora_config = lora_config
lora_a_output_size_per_partition = (
lora_config.max_lora_rank
if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size)
)
self.lora_a_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self.n_slices)
)
self.lora_b_stacked = tuple(
torch.zeros(
max_loras,
1,
output_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
)
for output_size in self.output_slices
)
def slice_lora_a(
self, lora_a: list[torch.Tensor | None]
) -> list[torch.Tensor | None]:
return lora_a
def slice_lora_b(
self, lora_b: list[torch.Tensor | None]
) -> list[torch.Tensor | None]:
sliced_lora_b = [None] * self.n_slices
for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)
):
if (lora_b_i := lora_b[i]) is not None:
sliced_lora_b[i] = lora_b_i[
shard_size * shard_id : shard_size * (shard_id + 1), :
]
return sliced_lora_b
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
):
self.reset_lora(index)
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
for i in range(self.n_slices):
if (lora_a_i := lora_a[i]) is not None:
self.lora_a_stacked[i][
index, 0, : lora_a_i.shape[0], : lora_a_i.shape[1]
].copy_(lora_a_i, non_blocking=True)
if (lora_b_i := lora_b[i]) is not None:
self.lora_b_stacked[i][
index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1]
].copy_(lora_b_i, non_blocking=True)
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return (
type(source_layer) is MergedColumnParallelLinear
and len(packed_modules_list) == 2
)
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""
ColumnParallelLinear layer that is specifically designed for
qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
only contains a single LoRA within their qkv_proj layer.
During inference with Tensor Parallel, the weights of lora_b
must be accurately partitioned according to the respective ranks.
Q slice may have different shape than K and V slices (which both have
the same shape).
"""
def __init__(self, base_layer: QKVParallelLinear) -> None:
super().__init__(base_layer)
self.q_proj_total_size = (
self.base_layer.total_num_heads * self.base_layer.head_size
)
self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
self.kv_proj_shard_size = (
self.base_layer.num_kv_heads * self.base_layer.head_size
)
self.kv_proj_total_size = (
self.base_layer.total_num_kv_heads * self.base_layer.head_size
)
# There is only one LoRA layer
self.n_slices = 1
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
self.q_shard_id = self.tp_rank
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[
self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size
* (self.q_shard_id + 1),
:,
]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[
k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset
+ self.kv_proj_shard_size * (self.kv_shard_id + 1),
:,
]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[
v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset
+ self.kv_proj_shard_size * (self.kv_shard_id + 1),
:,
]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
return lora_b
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1
class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
"""MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj).
This means we have 3 LoRAs, each applied to one slice of the layer.
Q slice may have different shape than K and V slices (which both have
the same shape).
"""
def __init__(self, base_layer: QKVParallelLinear) -> None:
super().__init__(base_layer)
# There are three LoRA layer.
self.n_slices = len(self.base_layer.output_sizes)
self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
self.kv_proj_shard_size = (
self.base_layer.num_kv_heads * self.base_layer.head_size
)
self.q_shard_id = self.tp_rank
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
self.output_slices = (
self.q_proj_shard_size,
self.kv_proj_shard_size,
self.kv_proj_shard_size,
)
self.output_ids = (
self.q_shard_id,
self.kv_shard_id,
self.kv_shard_id,
)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""
The main reason for overloading this function is to handle inconsistent
weight dimensions in qkv lora.
"""
super().create_lora_weights(max_loras, lora_config, model_config)
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3
# These following layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
"""
Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.
Based on S-LoRA, slicing happens along the rank dim.
"""
# For all LoRA layers where the `base_layer` is `ColumnParallelLinear`,
# their `lora_a` and `lora_b` have different sharding patterns. After
# completing the `lora_a` GEMM , a gather operation is performed.
# Therefore, the sharding of `lora_a` only needs to correspond with the
# gather operation.
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
lora_a = lora_a[start_idx : start_idx + shard_size, :]
return lora_a
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLoRA):
"""
Differs from MergedColumnParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(
self, lora_a: list[torch.Tensor | None]
) -> list[torch.Tensor | None]:
# NOTE: lora_a contains 2 subloras, and each sublora could be None.
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
lora_a = [
lora_a[0][output_start_idx : output_start_idx + output_shard_size, :]
if lora_a[0] is not None
else None,
lora_a[1][output_start_idx : output_start_idx + output_shard_size, :]
if lora_a[1] is not None
else None,
]
return lora_a
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
"""
Differs from QKVParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
lora_a = lora_a[start_idx : start_idx + shard_size, :]
return lora_a
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
"""
Differs from MergedQKVParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(
self, lora_a: list[torch.Tensor | None]
) -> list[torch.Tensor | None]:
# NOTE: lora_a contains 3 subloras, and each sublora could be None.
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [
lora_a[0][start_idx[0] : start_idx[0] + shard_size[0], :]
if lora_a[0] is not None
else None,
lora_a[1][start_idx[1] : start_idx[1] + shard_size[1], :]
if lora_a[1] is not None
else None,
lora_a[2][start_idx[2] : start_idx[2] + shard_size[2], :]
if lora_a[2] is not None
else None,
]
return lora_a
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)

472
lora/layers/fused_moe.py Normal file
View File

@@ -0,0 +1,472 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm import envs
from vllm.config.lora import LoRAConfig
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
modular_marlin_fused_moe,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
modular_triton_fused_moe,
try_get_optimal_moe_config,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
)
class FusedMoEWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: FusedMoE) -> None:
super().__init__()
self.base_layer = base_layer
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.device = base_layer.w2_weight.device
self._inject_lora_into_fused_moe()
def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
normalized_config = {}
for key, value in config.items():
if key.islower():
if key.startswith("block_"):
normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
else:
normalized_key = key.upper()
else:
normalized_key = key
normalized_config[normalized_key] = value
return normalized_config
def _get_lora_moe_configs(
self,
op_prefix: str,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
num_slices: int,
M: int,
layer: FusedMoE,
top_k: int,
config_dtype: str,
):
if envs.VLLM_TUNED_CONFIG_FOLDER:
shrink_config = get_lora_op_configs(
op_type=f"fused_moe_lora_{op_prefix}_shrink",
max_loras=lora_a_stacked.shape[0],
batch=M,
hidden_size=lora_a_stacked.shape[-1],
rank=lora_a_stacked.shape[-2],
num_slices=num_slices,
moe_intermediate_size=lora_b_stacked.shape[-2],
)
expand_config = get_lora_op_configs(
op_type=f"fused_moe_lora_{op_prefix}_expand",
max_loras=lora_a_stacked.shape[0],
batch=M,
hidden_size=lora_a_stacked.shape[-1],
rank=lora_a_stacked.shape[-2],
num_slices=num_slices,
moe_intermediate_size=lora_b_stacked.shape[-2],
)
else: # fall back to the default config
get_config_func = functools.partial(
try_get_optimal_moe_config,
layer.w13_weight.size(),
layer.w2_weight.size(),
top_k,
config_dtype,
block_shape=layer.quant_method.moe_quant_config.block_shape,
)
shrink_config = get_config_func(M)
expand_config = get_config_func(M)
shrink_config = self._normalize_keys(shrink_config)
expand_config = self._normalize_keys(expand_config)
return shrink_config, expand_config
def _inject_lora_into_fused_moe(self):
moe_state_dict = {}
top_k = self.base_layer.top_k
self.base_layer.ensure_moe_quant_config_init()
quant_config = self.base_layer.quant_method.moe_quant_config
m_fused_moe_fn = (
modular_triton_fused_moe(
quant_config, shared_experts=self.base_layer.shared_experts
)
if not quant_config.use_mxfp4_w4a16
else modular_marlin_fused_moe(
quant_config, shared_experts=self.base_layer.shared_experts
)
)
def fwd_decorator(layer, func):
def wrapper(*args, **kwargs):
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
moe_state_dict["expert_map"] = kwargs["expert_map"]
moe_state_dict["apply_router_weight_on_input"] = kwargs[
"apply_router_weight_on_input"
]
result = func(*args, **kwargs)
return result
return wrapper
def act_decorator(layer, func):
def wrapper(*args, **kwargs):
_, output, input = args
hidden_states = moe_state_dict["hidden_states"]
topk_weights = moe_state_dict["topk_weights"]
curr_topk_ids = moe_state_dict["topk_ids"]
expert_map = moe_state_dict["expert_map"]
config_dtype = _get_config_dtype_str(
dtype=hidden_states.dtype,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
)
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)
shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w13",
lora_a_stacked=self.w1_lora_a_stacked,
lora_b_stacked=self.w1_lora_b_stacked,
num_slices=2,
M=M,
layer=layer,
top_k=top_k,
config_dtype=config_dtype,
)
# get the block size of m from customized config or default config
max_loras = self.w1_lora_a_stacked.shape[0]
(
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
) = self.punica_wrapper.moe_lora_align_block_size(
curr_topk_ids,
num_tokens,
shrink_config["BLOCK_SIZE_M"],
self.base_layer.local_num_experts,
max_loras,
self.adapter_enabled,
expert_map,
)
moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
moe_state_dict["expert_ids_lora"] = expert_ids_lora
moe_state_dict["num_tokens_post_padded_lora"] = (
num_tokens_post_padded_lora
)
w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
max_lora_rank = self.w1_lora_a_stacked.shape[-2]
expert_ids_lora = expert_ids_lora.view(max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
self.punica_wrapper.add_lora_fused_moe(
input.view(-1, top_k, input.shape[-1]),
hidden_states,
w13_lora_a_stacked,
w13_lora_b_stacked,
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
shrink_config, ## pass the shrink config
expand_config, ## pass the expand config
self.adapter_enabled,
)
result = func(*args, **kwargs)
moe_state_dict["intermediate_cache2"] = output
return result
return wrapper
def moe_sum_decorator(layer, func):
def wrapper(*args, **kwargs):
hidden_states = moe_state_dict["hidden_states"]
topk_weights = moe_state_dict["topk_weights"]
config_dtype = _get_config_dtype_str(
dtype=hidden_states.dtype,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
)
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)
shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w2",
lora_a_stacked=self.w2_lora_a_stacked,
lora_b_stacked=self.w2_lora_b_stacked,
num_slices=1,
M=M,
layer=layer,
top_k=top_k,
config_dtype=config_dtype,
)
sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
expert_ids_lora = moe_state_dict["expert_ids_lora"]
num_tokens_post_padded_lora = moe_state_dict[
"num_tokens_post_padded_lora"
]
max_loras = self.w1_lora_a_stacked.shape[0]
expert_ids_lora = expert_ids_lora.view(max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
intermediate_cache3 = args[0]
max_lora_rank = self.w1_lora_a_stacked.shape[-2]
self.punica_wrapper.add_lora_fused_moe(
intermediate_cache3,
intermediate_cache2,
[self.w2_lora_a_stacked],
[self.w2_lora_b_stacked],
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
shrink_config, ## pass the shrink config
expand_config, ## pass the expand config
self.adapter_enabled,
True,
)
result = func(*args, **kwargs)
return result
return wrapper
fused_experts = m_fused_moe_fn.fused_experts
m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
fused_experts.activation = act_decorator(
self.base_layer, fused_experts.activation
)
fused_experts.moe_sum = moe_sum_decorator(
self.base_layer, fused_experts.moe_sum
)
self.base_layer.quant_method = FusedMoEModularMethod(
self.base_layer.quant_method, m_fused_moe_fn
)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
)
self.w1_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
self.base_layer.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w1_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w2_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
self.base_layer.intermediate_size_per_partition,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w2_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.hidden_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w3_lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
self.base_layer.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.w3_lora_b_stacked = torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
# to create a dummy LoRA weights.
self.lora_a_stacked = []
self.lora_b_stacked = []
for lora_id in range(max_loras):
for experts_id in range(self.base_layer.local_num_experts):
# gate_proj,down_proj,up_proj
self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id])
self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id])
self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id])
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
self.w1_lora_a_stacked[index] = 0
self.w1_lora_b_stacked[index] = 0
self.w3_lora_a_stacked[index] = 0
self.w3_lora_b_stacked[index] = 0
self.w2_lora_a_stacked[index] = 0
self.w2_lora_b_stacked[index] = 0
self.adapter_enabled[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
bias: torch.Tensor | None = None,
):
"""Overwrites lora tensors at index."""
self.reset_lora(index)
self.adapter_enabled[index] = 1
for eid in range(len(lora_a) // 3):
w1_lora_a = lora_a[eid * 3]
w2_lora_a = lora_a[eid * 3 + 1]
w3_lora_a = lora_a[eid * 3 + 2]
w1_lora_b = lora_b[eid * 3]
w2_lora_b = lora_b[eid * 3 + 1]
w3_lora_b = lora_b[eid * 3 + 2]
# Handle the case of adding LoRA to only a subset of experts
if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
continue
if self.tp_size > 1:
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
w1_lora_b = w1_lora_b[start_idx:end_idx, :]
w3_lora_b = w3_lora_b[start_idx:end_idx, :]
w2_lora_a = w2_lora_a[:, start_idx:end_idx]
self.w1_lora_a_stacked[
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
].copy_(w1_lora_a, non_blocking=True)
self.w3_lora_a_stacked[
index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
].copy_(w3_lora_a, non_blocking=True)
self.w2_lora_b_stacked[
index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
].copy_(w2_lora_b, non_blocking=True)
self.w1_lora_b_stacked[
index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
].copy_(w1_lora_b, non_blocking=True)
self.w3_lora_b_stacked[
index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
].copy_(w3_lora_b, non_blocking=True)
self.w2_lora_a_stacked[
index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
].copy_(w2_lora_a, non_blocking=True)
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
# return type(source_layer) is FusedMoE
return isinstance(source_layer, FusedMoE)
def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs)
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)
@property
def _shared_experts(self):
return self.base_layer._shared_experts
@property
def quant_method(self):
return self.base_layer.quant_method
@property
def is_internal_router(self) -> bool:
return self.base_layer.is_internal_router

View File

@@ -0,0 +1,252 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform
from .base import BaseLayerWithLoRA
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
"""
LoRA wrapper for LogitsProcessor, with extra logic to handle the
application of the LoRA adapter and added LoRA vocabulary.
Args:
base_layer: LogitsProcessor layer
hidden_size: hidden size of the model
dtype: data type of the model
device: device of the model
sharded_to_full_mapping: index mapping from sharded vocab to full vocab
received from base_layer.get_sharded_to_full_mapping(). If None,
no reindexing will be done.
"""
def __init__(
self,
base_layer: LogitsProcessor,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
sharded_to_full_mapping: list[int] | None,
) -> None:
super().__init__()
self.base_layer = base_layer
self.hidden_size = hidden_size
self.dtype = dtype
self.device = device
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.sharded_to_full_mapping = sharded_to_full_mapping
@property
def logits_as_input(self):
return self.base_layer.logits_as_input
@property
def vocab_size(self):
return self.base_layer.vocab_size
@property
def scale(self):
return self.base_layer.scale
@property
def soft_cap(self):
return self.base_layer.soft_cap
@property
def use_all_gather(self):
return self.base_layer.use_all_gather
@property
def org_vocab_size(self):
return self.base_layer.org_vocab_size
@property
def include_gpu_probs_tensor(self):
return self.base_layer.include_gpu_probs_tensor
@property
def should_modify_greedy_probs_inplace(self):
return self.base_layer.should_modify_greedy_probs_inplace
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
# TODO: Verify if this condition can be further relaxed
if 32000 < self.base_layer.vocab_size > 257024:
raise ValueError(
"When using LoRA, vocab size must be 32000 >= vocab_size <= 257024"
)
self.lora_a_stacked = torch.zeros(
(
max_loras,
1,
lora_config.max_lora_rank,
self.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
# Pad for kernel compatibility
math.ceil(
self.base_layer.vocab_size / lora_config.lora_vocab_padding_size
)
* lora_config.lora_vocab_padding_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.embeddings_tensors = torch.full(
(max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
fill_value=float("-inf"),
dtype=self.dtype,
device=self.device,
)
if self.sharded_to_full_mapping is not None:
self.sharded_to_full_mapping_gpu = torch.tensor(
self.sharded_to_full_mapping, device=self.device, dtype=torch.long
)
else:
self.sharded_to_full_mapping_gpu = None
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = float("-inf")
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
):
self.reset_lora(index)
self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
lora_a, non_blocking=True
)
self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
lora_b, non_blocking=True
)
if embeddings_tensor is not None:
self.embeddings_tensors[
index,
: embeddings_tensor.shape[0],
: embeddings_tensor.shape[1],
] = embeddings_tensor
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: torch.Tensor | None = None,
) -> torch.Tensor | None:
# Get the logits for the next tokens.
logits = lm_head.quant_method.apply(lm_head, hidden_states)
if embedding_bias is not None:
logits += embedding_bias
# Gather logits for TP
logits = self.base_layer._gather_logits(logits)
if logits is None:
return None
if self.sharded_to_full_mapping_gpu is not None:
# Reindex full logits tensor to ensure 1:1 mapping between
# index and token_id
# Example for:
# org_vocab_size = 4
# added_vocab_size = 2
# pad_to_size = 8
# tp_size = 2
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 4, -1, 2, 3, 5, -1]
# Therefore, the mapping is expected to be:
# [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
# we get:
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 2, 3, 4, 5, -1, -1]
logits = logits[:, self.sharded_to_full_mapping_gpu]
lora_logits = torch.empty(
self.embeddings_tensors.shape[0] + 1,
self.embeddings_tensors.shape[1],
hidden_states.shape[0],
dtype=self.embeddings_tensors.dtype,
device=self.embeddings_tensors.device,
)
torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1])
neg_inf, pos_inf = current_platform.get_infinity_values(lora_logits.dtype)
lora_logits[-1] = neg_inf
lora_logits = lora_logits.mT
indices_padded = self.punica_wrapper.sampler_indices_padded
if current_platform.is_tpu() or current_platform.is_xpu():
indices_padded = indices_padded[: logits.size(0)]
lora_logits = (
lora_logits.reshape(
lora_logits.shape[0] * lora_logits.shape[1],
lora_logits.shape[2],
)
.index_select(0, indices_padded)
.nan_to_num_(nan=neg_inf, posinf=pos_inf, neginf=neg_inf)
)
logits[
:,
self.base_layer.org_vocab_size : self.base_layer.org_vocab_size
+ lora_logits.shape[1],
] = lora_logits
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits(
logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0
)
if not current_platform.can_update_inplace():
logits = lora_output
# Remove paddings in vocab (if any).
logits = logits[:, : self.base_layer.vocab_size]
return logits
def forward(self, *args, **kwargs):
return type(self.base_layer).forward(self, *args, **kwargs)
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
# Special handling for the LogitsProcessor.
return False

View File

@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from .base_linear import BaseLinearLayerWithLoRA
class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: ReplicatedLinear) -> None:
super().__init__(
base_layer,
)
# To ensure interface compatibility, set to 1 always.
self.output_size = self.base_layer.output_size
self.n_slices = 1
def forward(
self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
"""Forward of ReplicatedLinearWithLoRA
Args:
input_: Tensor whose last dimension is `input_size`.
Returns:
- output
- bias
"""
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
# Matrix multiply.
output = self.apply(input_, bias)
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
if not self.base_layer.return_bias:
return output
return output, output_bias
# ReplicatedLinear should always be replaced, regardless of the fully
# sharded LoRAs setting, because it is, by definition, copied per GPU.
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is ReplicatedLinear
def slice_lora_a(
self, lora_a: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
"""Slice lora a if splitting for tensor parallelism."""
return lora_a
def slice_lora_b(
self, lora_b: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
"""Slice lora b if splitting with tensor parallelism."""
return lora_b

View File

@@ -0,0 +1,181 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed import (
split_tensor_along_last_dim,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform
from .base_linear import BaseLinearLayerWithLoRA
from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None:
super().__init__(base_layer)
# reset input_size
self.input_size = self.base_layer.input_size_per_partition
self.output_size = self.base_layer.output_size
# There is only one LoRA layer.
self.n_slices = 1
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
shard_size = self.input_size
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_a = lora_a[:, start_idx:end_idx]
return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
return lora_b
def forward(
self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
"""Forward of RowParallelLinear
Args:
input_: tensor whose last dimension is `input_size`. If
`input_is_parallel` is set, then the last dimension
is `input_size // tp_size`.
Returns:
- output
- bias
"""
# set up backprop all-reduce.
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
# TODO: simplify code below
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size
)
input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.apply(input_parallel)
if self.base_layer.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.base_layer.skip_bias_add:
output = (
output_ + self.base_layer.bias
if self.base_layer.bias is not None
else output_
)
output_bias = None
else:
output = output_
output_bias = self.base_layer.bias
if not self.base_layer.return_bias:
return output
return output, output_bias
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is RowParallelLinear
# The following layer is based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
"""
Differs from RowParallelLinearWithLoRA by slicing the
LoRA B's also.
Based on S-LoRA, slicing happens along the output dim.
This yields a combined partial sum from the row parallel base
layer and column partitioned output from the LoRA.
"""
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
shard_size = self.lora_b_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[start_idx:end_idx, :]
return lora_b
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
buffer = torch.zeros(
(self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
shrunk_buffer: torch.Tensor | None = self.punica_wrapper.add_shrink(
buffer, x, self.lora_a_stacked, 1.0
)
if not current_platform.can_update_inplace():
buffer = shrunk_buffer
if self.tp_size > 1:
buffer = tensor_model_parallel_all_reduce(buffer)
# following S-LoRA, allows the fusing of all_gather and all_reduce
# by adding the column partitioned lora output to a slice of output
# tensor, which is a partial sum due to row parallel. All that
# remains is a standard all_reduce. User should be aware though that
# the output is not the same as a normal row_parallel, it should be
# reduced before being used
# NOTE offset are based on the rank.
shard_size = self.lora_b_stacked[0].shape[2]
offset_start = self.tp_rank * shard_size
lora_output: torch.Tensor | None = self.punica_wrapper.add_expand(
output,
buffer,
self.lora_b_stacked,
self.output_slices,
offset_start=offset_start,
add_input=True,
)
if not current_platform.can_update_inplace():
output = lora_output
output = output.view(*out_orig_shape)
return output
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)

65
lora/layers/utils.py Normal file
View File

@@ -0,0 +1,65 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
import torch
import torch.nn as nn
@dataclass
class LoRAMapping:
index_mapping: tuple[int, ...]
prompt_mapping: tuple[int, ...]
is_prefill: bool = False
def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
def _get_lora_device(base_layer: nn.Module) -> torch.device:
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
"""Returns the device for where to place the LoRA tensors."""
# unquantizedLinear
if hasattr(base_layer, "weight"):
return base_layer.weight.device
# Compressed Tensor
elif hasattr(base_layer, "weight_packed"):
return base_layer.weight_packed.device
# GPTQ/AWQ
elif hasattr(base_layer, "qweight"):
return base_layer.qweight.device
# HQQ marlin
elif hasattr(base_layer, "W_q"):
return base_layer.W_q.device
else:
raise ValueError(f"Unsupported base layer: {base_layer}")
def _not_fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of not using fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
condition = not kwargs["lora_config"].fully_sharded_loras if decorate else True
return can_replace(*args, **kwargs) and condition
return dec
def _fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
return (
can_replace(*args, **kwargs) and kwargs["lora_config"].fully_sharded_loras
)
return dec

View File

@@ -0,0 +1,166 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform
from .base import BaseLayerWithLoRA
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
self.embeddings_slice: tuple[int, int] | None
self.embeddings_weights: torch.Tensor | None
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
if self.base_layer.num_added_embeddings_per_partition > 0:
# We can start adding lora weights
self.embeddings_weights = self.base_layer.weight.data[
self.base_layer.num_org_embeddings_per_partition : self.base_layer.num_org_embeddings_per_partition # noqa: E501
+ self.base_layer.num_added_embeddings_per_partition
]
self.embeddings_slice = (
self.base_layer.shard_indices.added_vocab_start_index
- self.base_layer.org_vocab_size,
self.base_layer.shard_indices.added_vocab_end_index
- self.base_layer.org_vocab_size,
)
self.base_layer.weight.data[
self.base_layer.num_org_embeddings_per_partition :
].fill_(0)
else:
self.embeddings_slice = None
self.embeddings_weights = None
self.embeddings_tensors = torch.zeros(
(
max_loras,
lora_config.lora_extra_vocab_size,
self.base_layer.embedding_dim,
),
dtype=self.base_layer.weight.dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.org_vocab_size + lora_config.lora_extra_vocab_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
self.base_layer.embedding_dim,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked_2d = self.lora_a_stacked.view(
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2],
)
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
self.embeddings_tensors[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None,
):
self.reset_lora(index)
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
# so we need transpose here
self.lora_a_stacked[index, : lora_a.shape[1], : lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True
)
self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
lora_b, non_blocking=True
)
if embeddings_tensor is not None:
self.embeddings_tensors[
index,
: embeddings_tensor.shape[0],
: embeddings_tensor.shape[1],
].copy_(embeddings_tensor, non_blocking=True)
if self.embeddings_slice is not None:
# TODO(yard1): Optimize this copy, we don't need to copy
# everything, just the modified part
embeddings = self.embeddings_tensors.view(
self.embeddings_tensors.shape[0] * self.embeddings_tensors.shape[1],
self.embeddings_tensors.shape[2],
)[self.embeddings_slice[0] : self.embeddings_slice[1]]
assert self.embeddings_weights is not None
self.embeddings_weights[: embeddings.shape[0]].copy_(embeddings)
def forward(self, x: torch.Tensor) -> torch.Tensor:
added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0)
# NB: Don't use torch.narrow here. torch.narrow triggers some
# Dynamic Shape specialization in torch.compile
num_tokens = x.shape[0]
indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens]
full_lora_a_embeddings = F.embedding(
x + indices_1,
self.lora_a_stacked_2d,
)
full_output = self.base_layer.forward(x + (indices_0 * added_tokens_mask))
full_output_org = full_output
if full_output.ndim == 3:
full_output = full_output.view(
full_output.shape[0] * full_output.shape[1], -1
)
if full_lora_a_embeddings.ndim == 3:
full_lora_a_embeddings = full_lora_a_embeddings.view(
full_lora_a_embeddings.shape[0] * full_lora_a_embeddings.shape[1],
-1,
)
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_embedding(
full_output, full_lora_a_embeddings, self.lora_b_stacked, add_input=True
)
if not current_platform.can_update_inplace():
full_output = lora_output
return full_output.view_as(full_output_org)
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None,
) -> bool:
return type(source_layer) is VocabParallelEmbedding
@property
def weight(self):
return self.base_layer.weight

198
lora/lora_weights.py Normal file
View File

@@ -0,0 +1,198 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence as GenericSequence
from typing import Optional
import torch
import torch.types
from vllm.lora.peft_helper import PEFTHelper
from vllm.utils.platform_utils import is_pin_memory_available
class LoRALayerWeights:
"""LoRA weights for a layer composed of two low rank matrixes."""
def __init__(
self,
module_name: str,
rank: int,
lora_alpha: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
embeddings_tensor: torch.Tensor | None = None,
scaling: float | None = None,
) -> None:
self.module_name = module_name
self.rank = rank
self.lora_alpha = lora_alpha
self.lora_a = lora_a
self.lora_b = lora_b
self.embeddings_tensor = embeddings_tensor
if scaling is None:
self.scaling = self.lora_alpha / self.rank
else:
self.scaling = scaling
def optimize(self) -> "LoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
if self.scaling == 1:
return self
self.lora_b *= self.scaling
self.scaling = 1
return self
@property
def input_dim(self) -> int:
return self.lora_a.shape[1]
@property
def output_dim(self) -> int:
return self.lora_b.shape[0]
@property
def is_packed(self) -> bool:
return False
@property
def extra_vocab_size(self) -> int:
return (
self.embeddings_tensor.shape[0] if self.embeddings_tensor is not None else 0
)
@classmethod
def from_config(
cls,
module_name: str,
peft_helper: PEFTHelper,
embeddings_tensor: torch.Tensor | None = None,
) -> "LoRALayerWeights":
# lora_a and lora_b are set to None for config-based construction
return cls(
module_name,
peft_helper.r,
peft_helper.lora_alpha,
None,
None,
embeddings_tensor,
peft_helper.vllm_lora_scaling_factor,
)
@classmethod
def create_dummy_lora_weights(
cls,
module_name: str,
input_dim: int,
output_dim: int,
rank: int,
dtype: torch.dtype,
device: torch.types.Device,
embeddings_tensor_dim: int | None = None,
) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros(
[rank, input_dim], dtype=dtype, device=device, pin_memory=pin_memory
)
lora_b = torch.zeros(
[output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory
)
embeddings_tensor = (
torch.rand(
10,
embeddings_tensor_dim,
dtype=dtype,
device=device,
pin_memory=pin_memory,
)
if embeddings_tensor_dim
else None
)
return cls(
module_name,
rank=rank,
lora_alpha=1,
lora_a=lora_a,
lora_b=lora_b,
embeddings_tensor=embeddings_tensor,
)
class PackedLoRALayerWeights(LoRALayerWeights):
"""LoRA used for packed layers (eg. qkv_proj)."""
def __init__(
self,
module_name: str,
rank: int,
lora_alphas: list[int | None],
lora_a: list[torch.Tensor | None],
lora_b: list[torch.Tensor | None],
scaling: list[float] | None = None,
) -> None:
super().__init__(
module_name=module_name,
rank=rank,
lora_alpha=0,
lora_a=lora_a,
lora_b=lora_b,
scaling=scaling, # type: ignore
embeddings_tensor=None,
)
self.lora_alphas = lora_alphas
if scaling is None:
self.scaling = [ # type: ignore
lora_alpha / self.rank # type: ignore # noqa
for lora_alpha in self.lora_alphas
]
@classmethod
def pack(
cls, loras: GenericSequence[Optional["LoRALayerWeights"]]
) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
"""
first_lora = next(lora for lora in loras if lora is not None)
for lora in loras:
if lora is None:
continue
lora.optimize()
rank = first_lora.rank
module_name = first_lora.module_name
obj = cls(
module_name,
rank,
[lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras],
scaling=[
1 if lora is not None else None # type: ignore
for lora in loras
],
)
return obj
def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)):
if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore
continue
self.lora_b[i] *= self.scaling[i] # type: ignore
self.scaling[i] = 1 # type: ignore
return self
@property
def input_dim(self) -> int:
raise NotImplementedError()
@property
def output_dim(self) -> int:
raise NotImplementedError()
@property
def is_packed(self) -> bool:
return True

890
lora/models.py Normal file
View File

@@ -0,0 +1,890 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import os
from collections.abc import Callable
from typing import TypeVar
import regex as re
import safetensors.torch
import torch
from torch import nn
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, FusedMoEWithLoRA, LoRAMapping
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (
from_layer,
from_layer_logits_processor,
get_supported_lora_modules,
is_regex_target_modules,
parse_fine_tuned_lora_name,
process_packed_modules_mapping,
replace_submodule,
)
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.utils.cache import LRUCache
from vllm.utils.platform_utils import is_pin_memory_available
logger = init_logger(__name__)
T = TypeVar("T")
class AdapterLRUCache(LRUCache[int, T]):
def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
super().__init__(capacity)
self.deactivate_fn = deactivate_fn
def _on_remove(self, key: int, value: T | None):
logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key)
return super()._on_remove(key, value)
_GLOBAL_LORA_ID = 0
def get_lora_id():
global _GLOBAL_LORA_ID
_GLOBAL_LORA_ID += 1
return _GLOBAL_LORA_ID
class LoRAModel:
"""A LoRA fine-tuned model."""
def __init__(
self,
lora_model_id: int,
rank: int,
loras: dict[str, LoRALayerWeights],
) -> None:
"""
Args:
lora_model_id: The integer id for the lora model.
rank: lora rank.
loras: module name -> weights for lora-replaced layers.
"""
self.id = lora_model_id
assert lora_model_id > 0, (
f"a valid lora id should be greater than 0, got {self.id}"
)
self.rank = rank
self.loras: dict[str, LoRALayerWeights] = loras
def clone(self, lora_model_id: int) -> "LoRAModel":
"""Return a copy of the object with different ids.
Will share the underlying tensors."""
return self.__class__(
lora_model_id,
rank=self.rank,
loras=self.loras.copy(),
)
@property
def extra_vocab_size(self) -> int:
return (
max(lora.extra_vocab_size for lora in self.loras.values())
if self.loras
else 0
)
def get_lora(self, module_name: str) -> LoRALayerWeights | None:
"""Get LoRA for a given module by name"""
return self.loras.get(module_name, None)
def check_lora_name(self, lora_name: str) -> bool:
return lora_name in self.loras
# (yard1): TODO see if we can derive target_embedding_padding automatically
@classmethod
def from_lora_tensors(
cls,
lora_model_id: int,
tensors: dict[str, torch.Tensor],
peft_helper: PEFTHelper,
device: str = "cuda",
dtype: torch.dtype | None = None,
embeddings: dict[str, torch.Tensor] | None = None,
target_embedding_padding: int | None = None,
embedding_modules: dict[str, str] | None = None,
embedding_padding_modules: list[str] | None = None,
weights_mapper: WeightsMapper | None = None,
) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
module_name, is_lora_a = parse_fine_tuned_lora_name(
tensor_name, weights_mapper
)
if module_name not in loras:
lora_embeddings_tensor = None
if embeddings:
assert embedding_modules is not None
embeddings_module = next(
(k for k in embedding_modules if k in module_name), None
)
if embeddings_module:
lora_embeddings_tensor = embeddings[
embedding_modules[embeddings_module]
].to(device=device, dtype=dtype)
if pin_memory:
lora_embeddings_tensor = lora_embeddings_tensor.pin_memory()
loras[module_name] = LoRALayerWeights.from_config(
module_name, peft_helper, lora_embeddings_tensor
)
if is_lora_a:
loras[module_name].lora_a = tensor.to(device=device, dtype=dtype)
if pin_memory:
loras[module_name].lora_a = loras[module_name].lora_a.pin_memory()
else:
loras[module_name].lora_b = tensor.to(device=device, dtype=dtype)
assert embedding_padding_modules is not None
if (
any(name in module_name for name in embedding_padding_modules)
and target_embedding_padding is not None
):
lora_b = loras[module_name].lora_b
assert target_embedding_padding >= lora_b.shape[0]
addition = target_embedding_padding - lora_b.shape[0]
loras[module_name].lora_b = torch.nn.functional.pad(
lora_b, (0, 0, 0, addition)
)
if pin_memory:
loras[module_name].lora_b = loras[module_name].lora_b.pin_memory()
for lora in loras.values():
lora.optimize()
return cls(lora_model_id, peft_helper.r, loras)
@classmethod
def from_local_checkpoint(
cls,
lora_dir: str,
expected_lora_modules: list[str],
peft_helper: PEFTHelper,
*,
lora_model_id: int | None = None,
device: str = "cuda",
dtype: torch.dtype | None = None,
target_embedding_padding: int | None = None,
embedding_modules: dict[str, str] | None = None,
embedding_padding_modules: list[str] | None = None,
weights_mapper: WeightsMapper | None = None,
tensorizer_config_dict: dict | None = None,
) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint.
Args:
lora_dir: The local path that has lora data.
expected_lora_modules: Name of modules that are expected to be
replaced by lora.
peft_helper: Loaded lora configuration information.
lora_model_id: LoRA model id. If not given, automatically set by
a global counter.
device: Device where the lora model is loaded.
dtype: dtype of the lora model weights.
Returns:
Loaded LoRA Model.
"""
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
new_embeddings_tensor_path = os.path.join(
lora_dir, "new_embeddings.safetensors"
)
new_embeddings_bin_file_path = os.path.join(lora_dir, "new_embeddings.bin")
tensors: dict[str, torch.Tensor] = {}
unexpected_modules: list[list[str] | str] = []
def check_unexpected_modules(modules: dict):
for lora_module in modules.keys(): # noqa
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
# Handle FSDP file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
if "base_layer" in lora_module:
continue
# Case for expert lora weights
if ".experts" in module_name:
if not any(
module_name.endswith(ele) for ele in expected_lora_modules
):
unexpected_modules.append(module_name)
elif module_name.split(".")[-1] not in expected_lora_modules:
unexpected_modules.append(module_name)
if unexpected_modules:
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct"
)
if tensorizer_config_dict:
from tensorizer import TensorDeserializer
tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
lora_tensor_path = os.path.join(
tensorizer_config.tensorizer_dir, "adapter_model.tensors"
)
tensorizer_args = tensorizer_config._construct_tensorizer_args()
tensors = TensorDeserializer(
lora_tensor_path,
dtype=tensorizer_config.dtype,
**tensorizer_args.deserialization_kwargs,
)
check_unexpected_modules(tensors)
elif os.path.isfile(lora_tensor_path):
# Find unexpected modules.
# Use safetensor key as a source of truth to find expected modules.
# in peft if you have target_modules A, B, C and C does not exist
# in the model it wont error and model will be trained with A, B
# loraified. C wont exist in the safetensor but it will exist in
# the target_modules of the adapter_config.json.
unexpected_modules = []
with safetensors.safe_open(lora_tensor_path, framework="pt") as f: # type: ignore
# Load tensors if there are only expected modules.
check_unexpected_modules(f)
for module in f.keys(): # noqa
tensors[module] = f.get_tensor(module)
elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
# When a bin/pt file is provided, we rely on config to find
# unexpected modules.
unexpected_modules = []
target_modules = peft_helper.target_modules
if not isinstance(target_modules, list):
target_modules = [target_modules]
for module in target_modules:
# Compatible with more modules,
# such as:layers.11.self_attn.k_proj
part_name = module.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module)
# loaded lora's target modules must be a subset of
# expected_lora_modules. It is not reliable. See
# https://github.com/vllm-project/vllm/pull/5909. But there's no
# other better mechanism.
if unexpected_modules and not is_regex_target_modules(
peft_helper.target_modules, expected_lora_modules
):
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct"
)
lora_file_path = (
lora_bin_file_path
if os.path.isfile(lora_bin_file_path)
else lora_pt_file_path
)
tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
else:
raise ValueError(f"{lora_dir} doesn't contain tensors")
embeddings = None
if os.path.isfile(new_embeddings_tensor_path):
embeddings = safetensors.torch.load_file(new_embeddings_tensor_path)
elif os.path.isfile(new_embeddings_bin_file_path):
embeddings = torch.load(
new_embeddings_bin_file_path, map_location=device, weights_only=True
)
return cls.from_lora_tensors(
lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
tensors=tensors,
peft_helper=peft_helper,
device=device,
dtype=dtype,
embeddings=embeddings,
target_embedding_padding=target_embedding_padding,
embedding_modules=embedding_modules,
embedding_padding_modules=embedding_padding_modules,
weights_mapper=weights_mapper,
)
class LoRAModelManager:
"""A manager that manages multiple LoRA-fine-tuned models."""
def __init__(
self,
model: SupportsLoRA,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
):
"""Create a LoRAModelManager and adapter for a given model.
Args:
model: the model to be adapted.
max_num_seqs: the maximum number of sequences model can run in a
single batch.
max_num_batched_tokens: the maximum number of tokens model can run
in a single batch.
vocab_size: the vocab size of the model.
lora_config: the LoRA configuration.
"""
self.model: SupportsLoRA = model
self._registered_adapters: dict[int, LoRAModel] = {}
# Dict instead of a set for compatibility with LRUCache.
self._active_adapters: dict[int, None] = {}
self.adapter_type = "LoRA"
self.lora_config = lora_config
self.device = device
self.max_num_seqs = max_num_seqs
assert self.capacity >= self.lora_slots
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
max_loras=self.lora_config.max_loras,
)
self.supported_lora_modules = get_supported_lora_modules(self.model)
assert self.supported_lora_modules, "No supported LoRA modules found in"
f" {self.model.__class__.__name__}."
self.packed_modules_mapping = process_packed_modules_mapping(self.model)
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = (
supports_multimodal(self.model)
# In case the model only supports LoRA for
# text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping")
)
self.is_pooling_model = is_pooling_model(self.model)
self.packed_modules: dict[str, list[str]] = {}
self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache.
self._last_mapping: LoRAMapping | None = None
self._create_lora_modules()
self.model.lora_manager = self
def __len__(self) -> int:
return len(self._registered_adapters)
@property
def capacity(self) -> int:
return self.lora_config.max_cpu_loras
@property
def lora_slots(self) -> int:
return self.lora_config.max_loras
@property
def adapter_slots(self) -> int:
return self.lora_slots
def activate_adapter(
self,
lora_id: int,
) -> bool:
"""Move LoRA into a GPU buffer to be used in the forward pass."""
if lora_id in self._active_adapters:
return False
first_free_slot = next(
(
(i, lora_id)
for i, lora_id in enumerate(self.lora_index_to_id)
if lora_id is None
),
None,
)
if first_free_slot is None:
raise ValueError("No free lora slots")
index, _ = first_free_slot
self._active_adapters[lora_id] = None
lora_model = self._registered_adapters[lora_id]
logger.debug(
"Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
)
self.lora_index_to_id[index] = lora_model.id
for module_name, module in self.modules.items():
module_lora = self._get_lora_layer_weights(lora_model, module_name)
if module_lora:
# Note (gnovack) - If MOE lora weights are not split into
# num_experts chunks, we split them here
if isinstance(module, FusedMoEWithLoRA) and torch.is_tensor(
module_lora.lora_a
):
# Handle FSDP file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
gate_up_proj_lora = self._get_lora_layer_weights(
lora_model, module_name + ".base_layer"
)
assert gate_up_proj_lora is not None
assert module_lora is not None
down_proj_lora = module_lora
num_experts = module_lora.lora_a.shape[0] // module_lora.rank
gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(
num_experts, dim=-1
)
up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(
num_experts, dim=-1
)
down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)
lora_a = []
lora_b = []
for i in range(num_experts):
lora_a.append(gate_proj_a[i])
lora_a.append(down_proj_a[i])
lora_a.append(up_proj_a[i])
lora_b.append(gate_proj_b[i])
lora_b.append(down_proj_b[i])
lora_b.append(up_proj_b[i])
module_lora.lora_a = lora_a
module_lora.lora_b = lora_b
module.set_lora(
index,
module_lora.lora_a,
module_lora.lora_b,
module_lora.embeddings_tensor,
)
else:
module.reset_lora(index)
return True
def _deactivate_adapter(self, lora_id: int):
try:
index = self.lora_index_to_id.index(lora_id)
self.lora_index_to_id[index] = None
except ValueError:
pass
def _add_adapter(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora)
self._registered_adapters[lora.id] = lora
def pin_adapter(self, lora_id: int) -> bool:
"""Pin a LoRAModel in the manager cache."""
raise NotImplementedError(
"Pinning is not supported in LoRAModelManager. "
"Use LRUCacheLoRAModelManager for pinning"
) # type: ignore
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
# update lora states
self.punica_wrapper.update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
self.lora_config.lora_extra_vocab_size,
)
def remove_all_adapters(self):
"""Remove all LoRAModels from the manager."""
self._registered_adapters.clear()
self.lora_index_to_id = [None] * self.lora_slots
self._active_adapters.clear()
def _create_lora_modules(self):
def _parent_module(module_name: str) -> str:
# module name is a dot separated name.
# for example:
# - given an input 'x.y.z' return 'x.y'
# - given an input 'x' return ''
return module_name.rpartition(".")[0]
for module_name, module in self.model.named_modules(remove_duplicate=False):
if isinstance(module, PPMissingLayer):
continue
if not self._match_target_modules(module_name):
continue
# A temporary approach for multimodal models to support LoRA
# TODO: Remove this restriction
if self._filter_unsupported_mm_module(module_name):
logger.warning(
"Regarding multimodal models, vLLM currently only supports "
"adding LoRA to language model, %s will be ignored.",
module_name,
)
continue
parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
new_module = replace_submodule(
self.model,
module_name,
from_layer(
module,
self.lora_slots,
self.lora_config,
packed_moduled_lst,
self.model.config,
),
)
# (yard1): TODO make this more robust
if "lm_head" in module_name:
logits_processor_module_name = "logits_processor"
parent_module = _parent_module(module_name)
if parent_module:
logits_processor_module_name = (
f"{parent_module}.{logits_processor_module_name}"
)
logits_processor_module = self.model.get_submodule(
logits_processor_module_name
)
new_module = replace_submodule(
self.model,
logits_processor_module_name,
from_layer_logits_processor(
logits_processor_module,
module,
self.lora_slots,
self.lora_config,
self.model.config,
),
)
# In some models, especially multimodal ones, layers with the same
# name may have different types, such as nn.Linear and
# ReplicatedLinear. The nn.Linear layers cannot be replaced with
# LoRA layers, leading to assertion error. The following check
# aims to prevent this error
if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
continue
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
new_module.set_mapping(self.punica_wrapper)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA), (
f"Module {module_name} must be a BaseLayerWithLoRA instance,"
)
f" got {type(module)}"
self.modules[module_name] = module
def create_dummy_lora(
self,
lora_id: int,
rank: int,
embedding_modules: dict[str, str] | None = None,
) -> LoRAModel:
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {})
for module_name, module in self.model.named_modules():
if (
not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA)
or self._filter_unsupported_mm_module(module_name)
):
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
assert embedding_modules is not None
if parts[-1] in embedding_modules:
input_dim = (
module.base_layer.org_vocab_size
+ self.lora_config.lora_extra_vocab_size
if hasattr(module.base_layer, "org_vocab_size")
else module.base_layer.weight.shape[1]
)
output_dim = (
module.base_layer.embedding_dim
if hasattr(module.base_layer, "embedding_dim")
else module.base_layer.weight.shape[0]
)
embeddings_tensor_dim = (
module.base_layer.embedding_dim
if hasattr(module.base_layer, "embedding_dim")
else module.base_layer.weight.shape[1]
)
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
input_dim,
output_dim,
rank,
module.lora_a_stacked[0].dtype,
"cpu",
embeddings_tensor_dim=embeddings_tensor_dim,
)
else:
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
module.lora_a_stacked[0].shape[-1],
module.lora_b_stacked[0].shape[-2],
rank,
module.lora_a_stacked[0].dtype,
"cpu",
)
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
subloras: list[LoRALayerWeights | None] = []
for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r,
module.lora_a_stacked[i].shape[-1],
module.lora_b_stacked[i].shape[-2],
rank,
module.lora_a_stacked[i].dtype,
"cpu",
)
subloras.append(lora)
lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora
return model
def _match_target_modules(self, module_name: str):
return any(
re.match(
r".*\.{target_module}$".format(target_module=target_module), module_name
)
or target_module == module_name
for target_module in self.supported_lora_modules
)
def _filter_unsupported_mm_module(self, module_name: str) -> bool:
"""
Regarding multimodal models, vLLM currently only supports adding LoRA to
language model. LoRA for other modules, such as the vision tower, will
be filtered out.
"""
if self.supports_mm:
module_mapping: MultiModelKeys = self.model.get_mm_mapping()
prefix_lst = module_mapping.connector + module_mapping.tower_model
return any([module_name.startswith(prefix) for prefix in prefix_lst])
return False
def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".")
module_name = parts[-1]
replacements = self.packed_modules_mapping.get(module_name, [])
# When replacements is less than or equal to 1, it indicates that this
# module is not a packed module.
if len(replacements) <= 1:
return
prefix = ".".join(parts[:-1])
self.packed_modules[module_full_name] = [
prefix + "." + r if prefix else r for r in replacements
]
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items():
replacement_loras: list[LoRALayerWeights | None] = []
replaced_module: set[str] = set()
has_replacement = False
for r in new_module_names:
lora = self._get_lora_layer_weights(lora_model, r)
replacement_loras.append(lora)
if lora:
has_replacement = True
replaced_module.add(r)
if not has_replacement:
continue
for i in range(len(replacement_loras)):
if replacement_loras[i]:
continue
replacement_loras[i] = None
# HACK Temporary solution for the pool model.
if self.is_pooling_model and not lora_model.check_lora_name(module_name):
replaced_module_name = module_name.replace("model.", "")
if lora_model.check_lora_name(module_name):
module_name = replaced_module_name
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras
)
# Remove the modules that have been replaced.
for module in replaced_module:
lora_model.loras.pop(module, None)
def _get_lora_layer_weights(
self, lora_model: LoRAModel, module_name: str
) -> LoRALayerWeights | None:
org_module_name = module_name
if self.is_pooling_model and not lora_model.check_lora_name(module_name):
# If it's a pool model, and the layer name is not found,
# remove the prefix 'model.' and search again.
module_name = module_name.replace("model.", "")
if lora_model.check_lora_name(module_name):
org_module_name = module_name
logger.info_once(
"For the pool model, successfully loaded the LoRA weights "
"after removing the prefix 'model.'."
)
return lora_model.get_lora(org_module_name)
def deactivate_adapter(self, adapter_id: int) -> bool:
if adapter_id not in self._active_adapters:
return False
self._deactivate_adapter(adapter_id)
self._active_adapters.pop(adapter_id, None)
return True
def add_adapter(self, adapter: LoRAModel) -> bool:
logger.debug("Adding lora. Model id: %d, int id: %d", adapter.id, adapter.id)
if adapter.id in self._registered_adapters:
return False
if len(self._registered_adapters) >= self.capacity:
raise RuntimeError("No free adapter slots.")
self._add_adapter(adapter)
return True
def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
if self._last_mapping != mapping:
self._set_adapter_mapping(mapping)
self._last_mapping = mapping
def remove_adapter(self, adapter_id: int) -> bool:
self.deactivate_adapter(adapter_id)
if adapter_id not in self._registered_adapters:
return False
self._registered_adapters.pop(adapter_id, None)
return True
def list_adapters(self) -> dict[int, LoRAModel]:
return dict(self._registered_adapters)
def get_adapter(self, adapter_id: int) -> LoRAModel | None:
return self._registered_adapters.get(adapter_id)
class LoRALRUCache(AdapterLRUCache[LoRAModel]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
super().__init__(capacity, deactivate_lora_fn)
class LRUCacheLoRAModelManager(LoRAModelManager):
"""A model manager that manages multiple LoRAs with LRU cache."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
):
super().__init__(
model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config, device
)
self._registered_adapters: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_adapter
)
self._active_adapters: LoRALRUCache = LoRALRUCache(
self.lora_slots, self._deactivate_adapter
)
def list_adapters(self) -> dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_adapters.cache)
def add_adapter(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager."""
logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
if lora.id not in self._registered_adapters:
self._add_adapter(lora)
was_added = True
else:
# We always touch to update the LRU cache order
self._registered_adapters.touch(lora.id)
was_added = False
return was_added
def activate_adapter(
self,
lora_id: int,
) -> bool:
if (
lora_id not in self._active_adapters
and len(self._active_adapters) >= self.lora_slots
):
self._active_adapters.remove_oldest()
result = super().activate_adapter(lora_id)
# We always touch to update the LRU cache order
self._active_adapters.touch(lora_id)
return result
def remove_oldest_adapter(self) -> bool:
if len(self._registered_adapters) > 0:
self._registered_adapters.remove_oldest()
return True
return False
def pin_adapter(self, lora_id: int) -> bool:
"""Pin a LoRAModel in the manager cache."""
self._pin_lora_in_cpu_cache(lora_id)
self._pin_lora_in_gpu_cache(lora_id)
return True
def _pin_lora_in_cpu_cache(self, lora_id: int):
try:
self._registered_adapters.pin(lora_id)
except ValueError as err:
raise ValueError(
f"Pinning failed. LoRA {lora_id} is not registered."
) from err
def _pin_lora_in_gpu_cache(self, lora_id: int):
if lora_id not in self._active_adapters:
# move lora to gpu if not already active
self.activate_adapter(lora_id)
self._active_adapters.pin(lora_id)
def create_lora_manager(
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
**kwargs,
) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
if not isinstance(model, SupportsLoRA):
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
lora_manager = lora_manager_cls(
model=model,
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
vocab_size=vocab_size,
lora_config=lora_config,
device=device,
**kwargs,
)
return lora_manager

0
lora/ops/__init__.py Normal file
View File

Binary file not shown.

View File

@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.ops.ipex_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
try:
import intel_extension_for_pytorch as ipex
except ImportError as e:
raise e
def bgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
) -> None:
ipex.llm.functional.bgmv_shrink(
inputs, lora_a_weights, output_tensor, lora_indices_tensor, scaling
)
def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
) -> None:
ipex.llm.functional.bgmv_expand(
inputs, lora_b_weights, output_tensor, lora_indices_tensor, add_inputs
)
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
) -> None:
ipex.llm.functional.bgmv_expand_slice(
inputs,
lora_b_weights,
output_tensor,
lora_indices_tensor,
slice_offset,
slice_size,
add_inputs,
)

View File

@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.ops.torch_ops.lora_ops import (
bgmv_expand, # noqa: F401
bgmv_expand_slice,
bgmv_shrink,
sgmv_expand,
sgmv_expand_slice,
sgmv_shrink,
)
__all__ = [
"bgmv_expand",
"bgmv_expand_slice",
"bgmv_shrink",
"sgmv_expand",
"sgmv_expand_slice",
"sgmv_shrink",
]

View File

@@ -0,0 +1,128 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
def sgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False,
):
exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, add_inputs)
def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
):
selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)
inputs = inputs.to(dtype=output_tensor.dtype)
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
limit = output_tensor.shape[0]
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
limit = 1
# LoRA adapter and model may add different amounts of padding to output
common_len = min(outputs.shape[1], output_tensor.shape[1])
if add_inputs:
output_tensor[:, :common_len] += outputs[:limit, :common_len]
else:
output_tensor[:, :common_len] = outputs[:limit, :common_len]
def sgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
):
exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, scaling)
def bgmv_shrink(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
):
selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)
inputs = inputs.to(dtype=output_tensor.dtype)
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
output_tensor[:, : outputs.shape[1]] = scaling * outputs[:]
def sgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
):
exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
bgmv_expand_slice(
inputs,
lora_b_weights,
output_tensor,
exploded_indices,
slice_offset,
slice_size,
add_inputs,
)
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
):
selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
inputs = inputs.to(dtype=output_tensor.dtype)
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
if add_inputs:
output_tensor[:, slice_offset : slice_offset + slice_size] += outputs[:]
else:
output_tensor[:, slice_offset : slice_offset + slice_size] = outputs[:]

View File

@@ -0,0 +1,60 @@
# Multi-LoRA Tuning
**Note**: The LoRA configuration folder should be specified by exporting `VLLM_TUNED_CONFIG_FOLDER=/path/to/configs`.
Without this, the shrink/expand kernels will use default configurations.
## Tuning Process
Multi-lora shrink/expand Triton kernel tuning follows a similar methodology from
[Triton MoE tuning](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py).
1. Define the searching space. Here is an example of searching space:
```python
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [32, 64, 128, 256]
block_k_range = [32, 64, 128, 256]
num_warps_range = [4, 8]
num_stage_range = [2, 3, 4, 5]
num_ctas_range = [1]
split_k_range = [4, 8, 16, 32, 64]
```
2. Get all hidden_state sizes and num_slices that the target model uses for a specific TP size.
For example, you can acquire the info by simply checking
[add_lora_linear](https://github.com/vllm-project/vllm/blob/main/vllm/lora/punica_wrapper/punica_gpu.py#L181):
```python
print(f"x_shape: {x.view(-1, x.shape[-1]).shape}")
print(f"num_slices: {len(output_slices)}")
for i in range(len(output_slices)):
print(f"a{i} shape: {lora_a_stacked[i].shape}")
print(f"b{i} shape: {lora_b_stacked[i].shape}")
print("y_shape", y.shape)
```
3. Benchmark the shrink/expand kernel runtime with different kernel configurations generated from the pre-defined search space
by performing a grid search to find the optimal kernel configuration.
vLLM's [benchmark_lora.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_lora.py)
can be used to search for configurations for different shapes.
## Config Files
### File Naming
| Kernel Type | File Name Template | Example |
|---------------------------|--------------------------------------------|---------------------------------------------|
| shrink | `{gpu_name}_SHRINK.json` | `NVIDIA_H200_SHRINK.json` |
| expand | `{gpu_name}_EXPAND_{add_input}.json` | `NVIDIA_H200_EXPAND_TRUE.json` |
| fused_moe_lora_w13_shrink | `{gpu_name}_FUSED_MOE_LORA_W13_SHRINK.json` | `NVIDIA_H200_FUSED_MOE_LORA_W13_SHRINK.json` |
| fused_moe_lora_w13_expand | `{gpu_name}_FUSED_MOE_LORA_W13_EXPAND.json` | `NVIDIA_H200_FUSED_MOE_LORA_W13_EXPAND.json` |
| fused_moe_lora_w2_shrink | `{gpu_name}_FUSED_MOE_LORA_W2_SHRINK.json` | `NVIDIA_H200_FUSED_MOE_LORA_W2_SHRINK.json` |
| fused_moe_lora_w2_expand | `{gpu_name}_FUSED_MOE_LORA_W2_EXPAND.json` | `NVIDIA_H200_FUSED_MOE_LORA_W2_EXPAND.json` |
The `gpu_name` can be automatically detected by calling `torch.cuda.get_device_name()`.
### JSON Structure
Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n][i]`,
where `i` is an optional dimension in the `fused_moe_lora` configuration, representing the intermediate size of the MoE layer.

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
fused_moe_lora,
fused_moe_lora_expand,
fused_moe_lora_shrink,
)
from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand
from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink
__all__ = [
"lora_expand",
"lora_shrink",
"LoRAKernelMeta",
"fused_moe_lora",
"fused_moe_lora_shrink",
"fused_moe_lora_expand",
]

Binary file not shown.

View File

@@ -0,0 +1,640 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
from .utils import supports_pdl
_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
"""
`_LORA_PTR_DICT` collects the required information during `profile_run`,
After this, it remains constant and subsequent usage is through LUT.
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
if (ptr_tensor := _LORA_PTR_DICT.get(key)) is not None:
return ptr_tensor
tensor_ptrs = []
for lora_weight in lora_weights:
tensor_ptrs.append(lora_weight.data_ptr())
ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
_LORA_PTR_DICT[key] = ptr_tensor
return _LORA_PTR_DICT.get(key)
@triton.jit(
do_not_specialize=[
"num_valid_tokens",
"EM",
"stride_tl",
"stride_el",
"slice_a_size",
"slice_c_size",
]
)
def _fused_moe_lora_kernel(
a_ptr,
b_ptr,
c_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
num_experts,
lora_ids,
adapter_enabled,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_bl,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_tl,
stride_el,
slice_a_size,
slice_c_size,
# Meta-parameters
num_slice_a: tl.constexpr,
num_slice_c: tl.constexpr,
top_k: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SPLIT_K: tl.constexpr,
# USE_GDC: tl.constexpr,
# launch_pdl: tl.constexpr,
# IS_PRIMARY: tl.constexpr,
):
pid = tl.program_id(axis=0)
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
lora_id = tl.load(lora_ids + lora_idx)
moe_enabled = tl.load(adapter_enabled + lora_id)
if lora_id == -1 or moe_enabled == 0:
# Early exit for the no-lora case.
return
max_loras = tl.num_programs(axis=2)
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
# calculate pid_m,pid_n
pid_sk = pid % SPLIT_K
pid_m_n = pid // SPLIT_K
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid_m_n // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
pid_n = (pid_m_n % num_pid_in_group) // group_size_m
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
# get the expert_id to process curr shard
ind = lora_id * stride_el + pid_m
expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
if expert_id == -1:
return
# get a_ptr,b_ptr,c_ptr
cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
token_ind = stride_tl * lora_id + offs_token_id
offs_token = tl.load(
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
)
token_mask = offs_token < num_valid_tokens
# get a_ptrs,b_ptrs
a_ptrs = cur_a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
b_ptrs = (
cur_b_ptr
+ lora_id * stride_bl
+ expert_id * stride_be
+ offs_k[:, None] * stride_bk
+ offs_bn[None, :] * stride_bn
)
# accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, grid_k):
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
# pre-fetch lora weight
b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
# GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
# if USE_GDC and not IS_PRIMARY:
# tl.extra.cuda.gdc_wait()
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
other=0.0,
)
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
# if USE_GDC and IS_PRIMARY:
# # GDC launch dependents hints the runtime system to launch dependent kernels.
# tl.extra.cuda.gdc_launch_dependents()
accumulator = accumulator.to(c_ptr.dtype.element_ty)
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
if SPLIT_K == 1:
tl.store(c_ptrs, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed")
@torch.inference_mode()
def _fused_moe_lora_shrink(
a_intermediate_cache1: torch.Tensor,
# (num_slices, num_tokens, top_k_num, max_lora_rank)
qcurr_hidden_states: torch.Tensor, # (num_tokens, K,)
lora_a_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, max_lora_rank, K,),...]
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
sorted_token_ids: torch.Tensor, # (max_loras, _)
expert_ids: torch.Tensor, # (max_loras, _ ,)
num_tokens_post_padded: torch.Tensor, # (max_loras, )
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
## adding for kernel
device: torch.device,
N: int,
M: int,
EM: int,
K: int,
num_tokens: int,
num_experts: int,
num_slices: int,
block_size_m: int,
block_size_n: int,
block_size_k: int,
group_size_m: int,
num_warps: int,
num_stages: int,
split_k: int,
mul_routed_weight: bool = False,
) -> None:
w1_lora_a_stacked = lora_a_stacked[0]
# use_gdc = supports_pdl(qcurr_hidden_states.device)
shrink_config = {
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": num_stages,
"SPLIT_K": split_k,
# "USE_GDC": use_gdc,
# "launch_pdl": use_gdc, # triton kernel metadata
}
b_ptr = _get_ptr(lora_a_stacked, device)
grid = lambda META: (
split_k
* triton.cdiv(EM, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_a_stacked),
lora_a_stacked[0].shape[0],
)
_fused_moe_lora_kernel[grid](
qcurr_hidden_states,
b_ptr,
a_intermediate_cache1,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
N,
K,
EM,
num_tokens,
num_experts,
lora_ids,
adapter_enabled,
qcurr_hidden_states.stride(0),
qcurr_hidden_states.stride(1),
w1_lora_a_stacked.stride(0),
w1_lora_a_stacked.stride(1),
w1_lora_a_stacked.stride(3),
w1_lora_a_stacked.stride(2),
a_intermediate_cache1.stride(2),
a_intermediate_cache1.stride(3),
sorted_token_ids.stride(0),
expert_ids.stride(0),
slice_a_size=qcurr_hidden_states.numel(),
slice_c_size=a_intermediate_cache1.numel() // num_slices,
num_slice_a=1,
num_slice_c=num_slices,
top_k=1 if mul_routed_weight else top_k_num,
MUL_ROUTED_WEIGHT=False,
# IS_PRIMARY=True,
**shrink_config,
)
@torch.inference_mode()
def _fused_moe_lora_expand(
output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),)
a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank)
lora_b_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, max_lora_rank, K,),...]
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
sorted_token_ids: torch.Tensor, # (max_loras, _)
expert_ids: torch.Tensor, # (max_loras, _ ,)
num_tokens_post_padded: torch.Tensor, # (max_loras, )
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
## adding for kernel
device: torch.device,
N: int,
M: int,
EM: int,
K: int,
num_tokens: int,
num_experts: int,
num_slices: int,
max_lora_rank: int,
w1_output_dim_size: int,
block_size_m: int,
block_size_n: int,
block_size_k: int,
group_size_m: int,
num_warps: int,
num_stages: int,
split_k: int,
mul_routed_weight: bool = False,
) -> None:
b_ptr = _get_ptr(lora_b_stacked, device)
K = max_lora_rank
N = w1_output_dim_size
w1_lora_b_stacked = lora_b_stacked[0]
a_intermediate_cache1 = a_intermediate_cache1.view(
-1, a_intermediate_cache1.shape[3]
)
b_intermediate_cache1 = torch.zeros(
(num_slices, M, top_k_num, w1_output_dim_size),
dtype=output.dtype,
device=device,
)
# use_gdc = supports_pdl(a_intermediate_cache1.device)
expand_config = {
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": num_stages,
"SPLIT_K": split_k, # Set split_k = 1 for expand calls
# "USE_GDC": use_gdc,
# "launch_pdl": use_gdc, # triton kernel metadata
}
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_b_stacked),
lora_b_stacked[0].shape[0],
)
_fused_moe_lora_kernel[grid](
a_intermediate_cache1,
b_ptr,
b_intermediate_cache1,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
N,
K,
EM,
num_tokens,
num_experts,
lora_ids,
adapter_enabled,
a_intermediate_cache1.stride(0),
a_intermediate_cache1.stride(1),
w1_lora_b_stacked.stride(0),
w1_lora_b_stacked.stride(1),
w1_lora_b_stacked.stride(3),
w1_lora_b_stacked.stride(2),
b_intermediate_cache1.stride(2),
b_intermediate_cache1.stride(3),
sorted_token_ids.stride(0),
expert_ids.stride(0),
slice_a_size=a_intermediate_cache1.numel() // num_slices,
slice_c_size=b_intermediate_cache1.numel() // num_slices,
num_slice_a=num_slices,
num_slice_c=num_slices,
top_k=1,
MUL_ROUTED_WEIGHT=mul_routed_weight,
# IS_PRIMARY=False,
**expand_config,
)
for i in range(num_slices):
output[:, :, i * N : (i + 1) * N] += b_intermediate_cache1[i]
@torch.inference_mode()
def _fused_moe_lora(
output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),)
qcurr_hidden_states: torch.Tensor, # (num_tokens, K,)
lora_a_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, max_lora_rank, K,),...]
lora_b_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, N, max_lora_rank,),...]
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
sorted_token_ids: torch.Tensor, # (max_loras, _)
expert_ids: torch.Tensor, # (max_loras, _ ,)
num_tokens_post_padded: torch.Tensor, # (max_loras, )
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
shrink_block_size_k: int,
shrink_group_size_m: int,
shrink_num_warps: int,
shrink_num_stages: int,
shrink_split_k: int,
expand_block_size_m: int,
expand_block_size_n: int,
expand_block_size_k: int,
expand_group_size_m: int,
expand_num_warps: int,
expand_num_stages: int,
expand_split_k: int,
mul_routed_weight: bool = False,
) -> None:
assert len(lora_a_stacked) == len(lora_b_stacked) > 0
assert (
sorted_token_ids.dim()
== expert_ids.dim()
== topk_weights.dim()
== qcurr_hidden_states.dim()
== 2
)
assert (
sorted_token_ids.shape[0]
== expert_ids.shape[0]
== num_tokens_post_padded.shape[0]
)
assert len(lora_b_stacked) * lora_b_stacked[0].shape[-2] == output.shape[-1]
assert output.shape[0] == topk_weights.shape[0]
assert top_k_num == topk_weights.shape[1]
device = qcurr_hidden_states.device
num_slices = len(lora_a_stacked)
w1_lora_b_stacked = lora_b_stacked[0]
num_experts = lora_a_stacked[0].shape[1]
N = max_lora_rank
M = topk_weights.shape[0]
EM = sorted_token_ids.shape[1]
K = qcurr_hidden_states.shape[1]
num_tokens = M * top_k_num
w1_output_dim_size = w1_lora_b_stacked.shape[2]
a_intermediate_cache1 = torch.zeros(
(num_slices, M, top_k_num, max_lora_rank),
dtype=output.dtype,
device=device,
)
_fused_moe_lora_shrink(
a_intermediate_cache1,
qcurr_hidden_states,
lora_a_stacked,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
top_k_num,
lora_ids,
adapter_enabled,
## adding for kernel
device,
N,
M,
EM,
K,
num_tokens,
num_experts,
num_slices,
shrink_block_size_m,
shrink_block_size_n,
shrink_block_size_k,
shrink_group_size_m,
shrink_num_warps,
shrink_num_stages,
shrink_split_k,
mul_routed_weight,
)
_fused_moe_lora_expand(
output,
a_intermediate_cache1,
lora_b_stacked,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
top_k_num,
lora_ids,
adapter_enabled,
## adding for kernel
device,
N,
M,
EM,
K,
num_tokens,
num_experts,
num_slices,
max_lora_rank,
w1_output_dim_size,
expand_block_size_m,
expand_block_size_n,
expand_block_size_k,
expand_group_size_m,
expand_num_warps,
expand_num_stages,
expand_split_k,
mul_routed_weight,
)
def _fused_moe_lora_fake(
output: torch.Tensor,
qcurr_hidden_states: torch.Tensor,
lora_a_stacked: list[torch.Tensor],
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
shrink_block_size_k: int,
shrink_group_size_m: int,
shrink_num_warps: int,
shrink_num_stages: int,
shrink_split_k: int,
expand_block_size_m: int,
expand_block_size_n: int,
expand_block_size_k: int,
expand_group_size_m: int,
expand_num_warps: int,
expand_num_stages: int,
expand_split_k: int,
mul_routed_weight: bool = False,
) -> None:
return
def _fused_moe_lora_shrink_fake(
a_intermediate_cache1: torch.Tensor,
qcurr_hidden_states: torch.Tensor,
lora_a_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
device: torch.device,
N: int,
M: int,
EM: int,
K: int,
num_tokens: int,
num_experts: int,
num_slices: int,
block_size_m: int,
block_size_n: int,
block_size_k: int,
group_size_m: int,
num_warps: int,
num_stages: int,
split_k: int,
mul_routed_weight: bool = False,
) -> None:
return
def _fused_moe_lora_expand_fake(
output: torch.Tensor,
a_intermediate_cache1: torch.Tensor,
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
device: torch.device,
N: int,
M: int,
EM: int,
K: int,
num_tokens: int,
num_experts: int,
num_slices: int,
max_lora_rank: int,
w1_output_dim_size: int,
block_size_m: int,
block_size_n: int,
block_size_k: int,
group_size_m: int,
num_warps: int,
num_stages: int,
split_k: int,
mul_routed_weight: bool = False,
) -> None:
return
try:
direct_register_custom_op(
op_name="fused_moe_lora",
op_func=_fused_moe_lora,
mutates_args=["output"],
fake_impl=_fused_moe_lora_fake,
)
direct_register_custom_op(
op_name="fused_moe_lora_shrink",
op_func=_fused_moe_lora_shrink,
mutates_args=["a_intermediate_cache1"],
fake_impl=_fused_moe_lora_shrink_fake,
)
direct_register_custom_op(
op_name="fused_moe_lora_expand",
op_func=_fused_moe_lora_expand,
mutates_args=["output"],
fake_impl=_fused_moe_lora_expand_fake,
)
fused_moe_lora = torch.ops.vllm.fused_moe_lora
fused_moe_lora_shrink = torch.ops.vllm.fused_moe_lora_shrink
fused_moe_lora_expand = torch.ops.vllm.fused_moe_lora_expand
except AttributeError:
fused_moe_lora = _fused_moe_lora
fused_moe_lora_shrink = _fused_moe_lora_shrink
fused_moe_lora_expand = _fused_moe_lora_expand

View File

@@ -0,0 +1,364 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Utilities for Punica kernel construction.
"""
from vllm.triton_utils import tl, triton
@triton.jit
def mm_k(
a_ptr,
b_ptr,
ak_stride,
bn_stride,
bk_stride,
offset_k,
K: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
CAST_TYPE: tl.constexpr,
b_dtype: tl.constexpr,
# USE_GDC: tl.constexpr,
base_k,
USE_STRIDE_LOAD: tl.constexpr,
):
"""
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
B (k x n), iterate, through the K dimension to compute the partial/complete
matrix block product.
If SPLIT_K == 1, the output m x n product is complete.
If SPLIT_K > 1, the thread block computes partial outputs. The partial
outputs are then atomically summed in the caller code.
Args:
a_ptr: Array of pointers, identifying rows of A
b_ptr: Array of pointers, identifying columns of B
ak_stride: K dimension stride of the A matrix
bn_stride: N dimension stride of the B matrix
bk_stride: K dimension stride of the B matrix
K: Length of the K dimension
BLOCK_M: M dimension of the output block m x n
BLOCK_N: N dimension of the output block m x n
BLOCK_K: K dimension atom
EVEN_K: True if the blocks of A and B can be loaded without any
masking.
SPLIT_K: Parameter signifying parallelism in the K dimension.
CAST_TYPE: if True, cast the values from the A matrix to the B
matrix dtype.
b_dtype: datatype of the B matrix
USE_GDC: Whether to use PDL. True indicates use.
USE_STRIDE_LOAD: Whether to use stride load for the B matrix.
base_k: Base offset along K dimension for current SPLIT_K group
"""
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# Step size along K for each iteration
STEP_K = BLOCK_K * SPLIT_K
# Total number of iterations (compile-time constant)
num_iters = tl.cdiv(K, STEP_K)
for k in range(num_iters):
# Current iteration's global K offset
iter_k = k * STEP_K + base_k
# Check if this iteration is completely valid (no masking needed)
block_end = iter_k + BLOCK_K
if EVEN_K:
# K is divisible by BLOCK_K, no masking ever needed
# pre-fetch lora weight
# tiled_b = tl.load(b_ptr)
# if USE_GDC:
# tl.extra.cuda.gdc_wait()
if USE_STRIDE_LOAD:
tiled_b = tl.load(b_ptr, stride=bn_stride)
else:
tiled_b = tl.load(b_ptr)
tiled_a = tl.load(a_ptr)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)
else:
# Check if we need element-wise masking
if iter_k >= K:
# Entire block out of range, skip
pass
elif block_end <= K:
# Entire block in range, no masking needed (fast path)
# tiled_b = tl.load(b_ptr)
# if USE_GDC:
# tl.extra.cuda.gdc_wait()
if USE_STRIDE_LOAD:
tiled_b = tl.load(b_ptr, stride=bn_stride)
else:
tiled_b = tl.load(b_ptr)
tiled_a = tl.load(a_ptr)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)
else:
# Partial block, need masking (only last iteration)
k_offsets = tl.arange(0, BLOCK_K)
mask = iter_k + k_offsets < K
# tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
# if USE_GDC:
# tl.extra.cuda.gdc_wait()
if USE_STRIDE_LOAD:
tiled_b = tl.load(
b_ptr, stride=bn_stride, mask=mask[:, None], other=0.0
)
else:
tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)
a_ptr += STEP_K * ak_stride
b_ptr += STEP_K * bk_stride
return accumulator
@triton.jit
def do_expand_kernel(
pid_n,
lora_index,
slice_id,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
M_LEN,
ram, # array identifying the rows of Input ptr to operate on
slice_start_loc,
# input ptr strides
input_d0_stride,
input_d1_stride,
input_d2_stride,
# lora ptr strides
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr,
# out ptr strides
output_d0_stride,
output_d1_stride,
# constants
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SAME_STRIDE: tl.constexpr,
SLICE_NUM: tl.constexpr,
EVEN_K: tl.constexpr,
CAST_TYPE: tl.constexpr,
ADD_INPUTS: tl.constexpr,
# USE_GDC: tl.constexpr,
USE_STRIDE_LOAD: tl.constexpr,
):
"""
Given an array of integers that identifies the rows of A, ram,
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
a slice_id that identifies the input/output slice,
compute the matrix product and store in the appropriate output location.
Given that this is an expand kernel, we don't perform any split-K reduction
as the K dimension is assumed to be small.
"""
# ls_d*_ptr can be either an integer or a pointer
if SAME_STRIDE:
# integer
cur_lora_d0_stride = ls_d0_ptr
cur_lora_d1_stride = ls_d1_ptr
cur_lora_d2_stride = ls_d2_ptr
else:
# pointer
cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)
# Identify the input_ptr and lora_ptr from slice_id.
if SLICE_NUM == 1:
cur_input_ptr = input_ptr
cur_lora_ptr = lora_ptr
else:
cur_input_ptr = input_ptr + slice_id * input_d0_stride
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
tl.pointer_type(out_ptr.dtype.element_ty)
)
# Identify the column indices of B to process.
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
# Identify A and B block pointers
offset_k = tl.arange(0, BLOCK_K)
a_ptr = (
cur_input_ptr
+ ram[:, None] * input_d1_stride
+ offset_k[None, :] * input_d2_stride
)
b_ptr = (
cur_lora_ptr
+ cur_lora_d0_stride * lora_index
+ offset_k[:, None] * cur_lora_d2_stride
+ rbn[None, :] * cur_lora_d1_stride
)
# Compute the block matrix product.
SPLIT_K = 1
accumulator = mm_k(
a_ptr,
b_ptr,
input_d2_stride,
cur_lora_d1_stride,
cur_lora_d2_stride,
offset_k,
K,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
CAST_TYPE,
cur_lora_ptr.dtype.element_ty,
# USE_GDC,
base_k=0,
USE_STRIDE_LOAD = USE_STRIDE_LOAD
)
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
if SLICE_NUM == 1:
cur_slice_start = slice_start_loc
else:
cur_slice_start = tl.load(slice_start_loc + slice_id)
# Identify the C output pointers to store the results of the accumulator.
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
offset_cm = tl.arange(0, BLOCK_M)
c_ptr = (
out_ptr
+ ram[:, None] * output_d0_stride
+ offset_cn[None, :] * output_d1_stride
)
c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < (cur_slice_start + N))
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@triton.jit
def do_shrink_kernel(
pid_n,
pid_sk,
slice_id,
lora_index,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
M_LEN,
ram,
# input strides
input_d0_stride,
input_d1_stride,
# lora strides
lora_d0_stride,
lora_d1_stride,
lora_d2_stride,
# output strides
output_d0_stride,
output_d1_stride,
output_d2_stride,
scaling,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
SLICE_NUM: tl.constexpr,
# USE_GDC: tl.constexpr,
USE_STRIDE_LOAD: tl.constexpr,
):
"""
Given an array of integers that identifies the rows of A, ram,
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
a slice_id that identifies the input/output slice, compute the
matrix product and store in the appropriate output location.
"""
# Identify the lora_ptr from slice_id.
if SLICE_NUM == 1:
# current lora ptr
cur_lora_ptr = lora_ptr
else:
# current lora ptr
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
tl.pointer_type(input_ptr.dtype.element_ty)
)
# Identify the column indices of B to process.
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
# Identify A and B block pointers
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
a_ptr = (
input_ptr + ram[:, None] * input_d0_stride + offset_k[None, :] * input_d1_stride
)
b_ptr = (
cur_lora_ptr
+ lora_d0_stride * lora_index
+ rbn[None, :] * lora_d1_stride
+ offset_k[:, None] * lora_d2_stride
)
# Compute partial/complete block matrix product.
accumulator = mm_k(
a_ptr,
b_ptr,
input_d1_stride,
lora_d1_stride,
lora_d2_stride,
offset_k,
K,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
False,
cur_lora_ptr.dtype.element_ty,
# False, # USE_GDC is always False in shrink kernel
base_k=pid_sk * BLOCK_K,
USE_STRIDE_LOAD = USE_STRIDE_LOAD
)
# GDC launch dependents hints the runtime system to launch dependent kernels.
# if USE_GDC:
# tl.extra.cuda.gdc_launch_dependents()
# Identify the C output pointers to store the results of the accumulator.
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_cm = tl.arange(0, BLOCK_M)
cur_out_ptr = out_ptr if SLICE_NUM == 1 else out_ptr + slice_id * output_d0_stride
c_ptr = (
cur_out_ptr
+ ram[:, None] * output_d1_stride
+ offset_cn[None, :] * output_d2_stride
)
c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)
accumulator *= scaling
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(c_ptr, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptr, accumulator, mask=c_mask, sem="relaxed")

View File

@@ -0,0 +1,336 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import os
from contextlib import contextmanager
import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
@contextmanager
def _temporary_env(var_name: str, value: str | None):
prev_value = os.environ.get(var_name)
if value is None:
os.environ.pop(var_name, None)
else:
os.environ[var_name] = value
try:
yield
finally:
if prev_value is None:
os.environ.pop(var_name, None)
else:
os.environ[var_name] = prev_value
@triton.jit
def _lora_expand_kernel(
input_ptr,
lora_ptr,
out_ptr,
M,
N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
slice_start_loc,
input_d0_stride,
input_d1_stride,
input_d2_stride, # 1
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr, # 1
output_d0_stride,
output_d1_stride, # 1
output_hs_ptr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
SLICE_NUM: tl.constexpr,
SAME_STRIDE: tl.constexpr,
# USE_GDC: tl.constexpr,
# launch_pdl: tl.constexpr,
USE_STRIDE_LOAD: tl.constexpr,
):
cta_n_num = tl.cdiv(N, BLOCK_N)
cta_m_num = tl.cdiv(M, BLOCK_M)
pid_mn = tl.program_id(axis=0)
pid_m = pid_mn % cta_m_num
pid_n = (pid_mn // cta_m_num) % cta_n_num
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
lora_id = tl.load(lora_ids + lora_idx)
if lora_id == -1:
# Early exit for the no-lora case.
return
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
cta_m_offset = pid_m * BLOCK_M
if cta_m_offset >= lora_m_size:
# Early exit CTA.
return
# When the output dimensions of each slice are the same,cur_n=N, otherwise
# cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's
# qkv linear.
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
if pid_n * BLOCK_N >= curr_N:
# Early exit CTA.
return
# num rows this CTA should process.
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
# Identify all rows that this CTA should process.
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
cta_lora_seq_indices = (
token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
)
# Load all relevant row indices.
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
ram = tl.load(cta_lora_seq_indices + offset_m)
do_expand_kernel(
pid_n,
lora_id,
slice_id,
input_ptr,
lora_ptr,
out_ptr,
curr_N,
K,
cta_m_len,
ram, # array identifying the rows of Input ptr to operate on
slice_start_loc,
# input ptr strides
input_d0_stride,
input_d1_stride,
input_d2_stride,
# lora ptr strides
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr,
# out ptr strides
output_d0_stride,
output_d1_stride,
# constants
BLOCK_M,
BLOCK_N,
BLOCK_K,
SAME_STRIDE,
SLICE_NUM,
EVEN_K,
CAST_TYPE,
ADD_INPUTS,
# USE_GDC,
USE_STRIDE_LOAD,
)
@torch.inference_mode()
def _lora_expand(
inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
lora_b_weights: list[torch.Tensor], # shape [num_lora, hidden_size, lora_rank]
output_tensor: torch.Tensor, # shape [num_tokens, hidden_size * num_slices]
token_lora_mapping: torch.Tensor, # shape [num_tokens]
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (list[torch.Tensor]): lora'b weight
output_tensor (torch.Tensor): output tensor
token_lora_mapping (torch.Tensor): A tensor mapping each input token
to the lora-id related to that token. A value of -1 indicates that
LoRA doesn't apply to that token.
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
the A matrix grouped by LoRA IDs.
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
of tokens that are to be processed by LoRA ID lora_ids[i]
lora_token_start_loc (torch.Tensor): A cumulative sum of
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
lora_token_start_loc[i], along with num_tokens_per_lora[i]
identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
offset_start (int, optional): Offset start for output_tensor.
Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False.
"""
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
for weight in lora_b_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(0) == len(lora_b_weights)
assert output_tensor.is_contiguous()
# metadata sanity check.
M = inputs.size(1)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
(
slice_start_tensor,
lora_ptr_tensor,
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
hidden_sizes_tensor,
same_stride,
MAX_N,
) = _get_lora_b_ptr(lora_b_weights, offset_start, inputs.device)
K = lora_b_weights[0].shape[-1] # K= rank
ADD_INPUTS = add_inputs
MAX_LORAS = lora_ids.size(0)
CAST_TYPE = False
NUM_SLICES = len(lora_b_weights)
# Triton kernel configs.
kernel_config = get_lora_op_configs(
op_type="expand",
max_loras=MAX_LORAS,
batch=M,
hidden_size=MAX_N,
rank=K,
num_slices=NUM_SLICES,
add_inputs=add_inputs,
)
BLOCK_M = kernel_config["block_m"]
BLOCK_N = kernel_config["block_n"]
BLOCK_K = kernel_config["block_k"]
NUM_WARPS = kernel_config["num_warps"]
NUM_CTAS = kernel_config["num_ctas"]
NUM_STAGES = kernel_config["num_stages"]
EVEN_K = K % BLOCK_K == 0 # type: ignore
if same_stride:
use_stride_load = False
else:
elem_size = lora_b_weights[0].element_size()
use_stride_load = all(
(weight.stride(1) * elem_size) % 64 == 0 for weight in lora_b_weights
)
if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
# TODO (varun): This grid formulation maximizes parallelization at the
# cost of wasteful thread block launch when only a few input tokens require
# LoRA. This might not be the best in all cases.
grid = (
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
NUM_SLICES,
# Each LoRA receives its own set of thread blocks for output
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks simply exit.
MAX_LORAS,
)
disable_store_stp = os.getenv("VLLM_LORA_DISABLE_STORE_STP", "1") == "1"
with _temporary_env("TRITON_DISABLE_STORE_STP",
"1" if disable_store_stp else None):
_lora_expand_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
M,
MAX_N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
slice_start_tensor,
inputs.stride(0),
inputs.stride(1),
inputs.stride(2),
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
hidden_sizes_tensor,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
NUM_SLICES,
same_stride,
use_stride_load,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
num_stages=NUM_STAGES,
)
return
def _lora_expand_fake(
inputs: torch.Tensor,
lora_b_weights: list[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
return
try:
direct_register_custom_op(
op_name="lora_expand",
op_func=_lora_expand,
mutates_args=["output_tensor"],
fake_impl=_lora_expand_fake,
)
lora_expand = torch.ops.vllm.lora_expand
except AttributeError:
lora_expand = _lora_expand

View File

@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
LoRA kernels metadata preparation utilities.
"""
from dataclasses import dataclass
import torch
@dataclass
class LoRAKernelMeta:
token_lora_mapping: torch.Tensor
token_indices_sorted_by_lora_ids: torch.Tensor
active_lora_ids: torch.Tensor
num_tokens_per_lora: torch.Tensor
lora_token_start_loc: torch.Tensor
# The V1 architecture uses the traced torch.compile graphs to execute
# a forward pass. Things to note about this process,
# 1. The tracing infers all python scalar datatype objects into a constant
# value.
# 2. The tracing cannot handle dynamic control flow. (dynamic control flow
# is an experimental feature in pytorch)
# 3. The internals of torch.ops functions are not traced.
# We disguise the "no_lora" flag as a cpu tensor and leverage point number 3
# to early exit from inside the lora_expand / lora_shrink torch operation.
no_lora_flag_cpu: torch.Tensor
@staticmethod
def make(
max_loras: int, max_num_tokens: int, device: torch.device | str
) -> "LoRAKernelMeta":
token_lora_mapping = torch.empty(
max_num_tokens, dtype=torch.int32, device=device
)
token_indices_sorted_by_lora_ids = torch.empty(
max_num_tokens, dtype=torch.int32, device=device
)
# +1 because "no-lora" is also a possibility
# example: let max_loras be 3, active_lora_ids of [-1, 0, 2, 1]
# is a possibility.
active_lora_ids = torch.empty(max_loras + 1, dtype=torch.int32, device=device)
# using running example, [3, 10, 5, 2] is a possibility.
num_tokens_per_lora = torch.zeros(
max_loras + 1, dtype=torch.int32, device=device
)
# +2 for this because, the first index is always 0.
# using running example, lora_token_start_loc
# is [0, 3, 13, 18, 20].
lora_token_start_loc = torch.zeros(
max_loras + 2, dtype=torch.int32, device=device
)
no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu")
return LoRAKernelMeta(
token_lora_mapping=token_lora_mapping,
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
active_lora_ids=active_lora_ids,
num_tokens_per_lora=num_tokens_per_lora,
lora_token_start_loc=lora_token_start_loc,
no_lora_flag_cpu=no_lora_flag_cpu,
)
def _reset(self):
self.active_lora_ids.fill_(-1)
self.num_tokens_per_lora.fill_(0)
self.lora_token_start_loc.fill_(0)
self.no_lora_flag_cpu.fill_(False)
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
"""
Prepare kernel metadata tensors for the current forward pass.
Args:
token_lora_mapping (torch.Tensor): Tensor containing lora indices
for each input token.
"""
self._reset()
# Check and record no-lora case.
no_lora = torch.all(token_lora_mapping == -1)
self.no_lora_flag_cpu[0] = no_lora
if no_lora:
# Early exit. LoRA kernels will not be run.
return
num_tokens = token_lora_mapping.size(0)
# copy token lora mapping
self.token_lora_mapping[:num_tokens].copy_(
token_lora_mapping, non_blocking=True
)
# token_indices_sorted_by_lora_ids
_, token_indices_sorted_by_lora_ids = torch.sort(
token_lora_mapping, stable=True
)
# start gpu transfer
self.token_indices_sorted_by_lora_ids[:num_tokens].copy_(
token_indices_sorted_by_lora_ids, non_blocking=True
)
# active_lora_ids, num_tokens_per_lora
lora_ids, num_tokens_per_lora = torch.unique(
token_lora_mapping, sorted=True, return_counts=True
)
self.active_lora_ids[: lora_ids.size(0)].copy_(lora_ids, non_blocking=True)
self.num_tokens_per_lora[: num_tokens_per_lora.size(0)].copy_(
num_tokens_per_lora, non_blocking=True
)
# lora_token_start_loc
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
self.lora_token_start_loc[1 : 1 + lora_token_start_loc.size(0)].copy_(
lora_token_start_loc, non_blocking=True
)
def meta_args(
self, token_nums: int
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""
This function returns the kernel metadata required for the current
forward pass execution of the kernel. The function returns all the
metadata required by the kernel, in order, as a tuple, so it can be
unpacked directly during the lora_shrink/lora_expand function call.
Args:
token_nums (int): Number of input tokens in the current forward
pass of the kernel.
"""
return (
self.token_lora_mapping[:token_nums],
self.token_indices_sorted_by_lora_ids[:token_nums],
self.num_tokens_per_lora,
self.lora_token_start_loc,
self.active_lora_ids,
self.no_lora_flag_cpu,
)

View File

@@ -0,0 +1,290 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import os
import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
from .utils import supports_pdl
@triton.jit
def _lora_shrink_kernel(
input_ptr,
lora_ptr,
out_ptr,
M,
N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
scaling,
input_d0_stride,
input_d1_stride,
lora_d0_stride,
lora_d1_stride,
lora_d2_stride,
output_d0_stride,
output_d1_stride,
output_d2_stride,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SLICE_NUM: tl.constexpr,
# USE_GDC: tl.constexpr,
# launch_pdl: tl.constexpr,
USE_STRIDE_LOAD: tl.constexpr,
):
cta_n_num = tl.cdiv(N, BLOCK_N)
cta_m_num = tl.cdiv(M, BLOCK_M)
pid_sk_m_n = tl.program_id(axis=0)
pid_sk = pid_sk_m_n % SPLIT_K
pid_m_n = pid_sk_m_n // SPLIT_K
num_pid_in_group = GROUP_SIZE_M * cta_n_num
group_id = pid_m_n // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(cta_m_num - first_pid_m, GROUP_SIZE_M)
# Column-major ordering within groups for better cache reuse
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
pid_n = (pid_m_n % num_pid_in_group) // group_size_m
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
lora_id = tl.load(lora_ids + lora_idx)
if lora_id == -1:
# Early exit for the no-lora case.
return
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
cta_m_offset = pid_m * BLOCK_M
if cta_m_offset >= lora_m_size:
# Early exit CTA.
return
# num rows this CTA should process.
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
# Identify all rows that this CTA should process.
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
cta_lora_seq_indices = (
token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
)
# Load all relevant row indices.
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
ram = tl.load(cta_lora_seq_indices + offset_m)
do_shrink_kernel(
pid_n,
pid_sk,
slice_id,
lora_id,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
cta_m_len,
ram, # array identifying the rows of Input ptr to operate on
# input strides
input_d0_stride,
input_d1_stride,
# lora strides
lora_d0_stride,
lora_d1_stride,
lora_d2_stride,
# output strides
output_d0_stride,
output_d1_stride,
output_d2_stride,
scaling,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
SLICE_NUM,
# USE_GDC,
USE_STRIDE_LOAD,
)
@torch.inference_mode()
def _lora_shrink(
inputs: torch.Tensor, # shape [num_tokens, hidden_size]
lora_a_weights: list[torch.Tensor], # shape [num_loras, lora_rank, hidden_size]
output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
token_lora_mapping: torch.Tensor, # shape [num_tokens]
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
scaling: float,
) -> None:
"""
Args:
inputs (torch.Tensor): Input tensor
lora_a_weights (list[torch.Tensor]): LoRA weights
output_tensor (torch.Tensor): output tensor
token_lora_mapping (torch.Tensor): A tensor mapping each input token
to the lora-id related to that token. A value of -1 indicates that
LoRA doesn't apply to that token.
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
the A matrix grouped by LoRA IDs.
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
of tokens that are to be processed by LoRA ID lora_ids[i]
lora_token_start_loc (torch.Tensor): A cumulative sum of
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
lora_token_start_loc[i], along with num_tokens_per_lora[i]
identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
scaling (float): Scaling factor.
"""
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype == lora_a_weights[0].dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
for weight in lora_a_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(1) == lora_a_weights[0].size(-1)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
# metadata sanity check
M = inputs.size(0)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
output_tensor.zero_()
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = (
_get_lora_a_ptr(lora_a_weights, inputs.device)
)
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
NUM_SLICES = len(lora_a_weights)
MAX_LORAS = lora_ids.size(0)
# Triton kernel configs
kernel_config = get_lora_op_configs(
"shrink",
max_loras=MAX_LORAS,
batch=M,
hidden_size=K,
rank=N,
num_slices=NUM_SLICES,
)
BLOCK_M = kernel_config["block_m"]
BLOCK_N = kernel_config["block_n"]
BLOCK_K = kernel_config["block_k"]
SPLIT_K = kernel_config["split_k"]
NUM_WARPS = kernel_config["num_warps"]
NUM_STAGES = kernel_config["num_stages"]
NUM_CTAS = kernel_config["num_ctas"]
GROUP_SIZE_M = kernel_config.get("group_size_m", 8)
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore
use_stride_load = False
# TODO (varun): This grid formulation maximizes parallelization at the
# cost of wasteful thread block launch when only few of the input tokens
# require LoRA. This might not be the best in all cases.
grid = (
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
NUM_SLICES,
# Each LoRA receives its own set of thread blocks for output
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks exit early.
MAX_LORAS,
)
# use_gdc = supports_pdl(inputs.device)
_lora_shrink_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
M,
N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_strides_d0,
lora_strides_d1,
lora_strides_d2,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor.stride(2),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
GROUP_SIZE_M,
NUM_SLICES,
#use_gdc,
use_stride_load,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
num_stages=NUM_STAGES,
# launch_pdl=use_gdc,
)
return
def _lora_shrink_fake(
inputs: torch.Tensor,
lora_a_weights: list[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
scaling: float,
) -> None:
return
try:
direct_register_custom_op(
op_name="lora_shrink",
op_func=_lora_shrink,
mutates_args=["output_tensor"],
fake_impl=_lora_shrink_fake,
)
lora_shrink = torch.ops.vllm.lora_shrink
except AttributeError:
lora_shrink = _lora_shrink

View File

@@ -0,0 +1,362 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import json
from functools import lru_cache
from pathlib import Path
from typing import Any
import torch
from vllm import envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
logger = init_logger(__name__)
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.tensor, ...]] = {}
def _is_corex_backend() -> bool:
return getattr(torch, "corex", False) is True
def _get_corex_lora_op_config(
op_type: str,
batch: int,
hidden_size: int,
num_slices: int,
) -> dict[str, int | None] | None:
if op_type == "expand":
if batch <= 16:
block_m = 16
num_warps = 4
elif batch <= 32:
block_m = 32
num_warps = 4
else:
block_m = 64
num_warps = 8
return {
"block_m": block_m,
"block_n": 128,
"block_k": 16,
"num_warps": num_warps,
"num_ctas": 1,
"num_stages": 1,
"max_nreg": None,
}
if op_type == "shrink":
if batch <= 1024:
block_m, block_n, block_k = 32, 16, 64
split_k, num_warps, num_stages = 8, 2, 2
else:
if num_slices == 1:
block_m, block_n, block_k = 64, 16, 64
split_k, num_warps, num_stages = 4, 4, 1
else:
block_m, block_n, block_k = 128, 16, 128
num_warps, num_stages = 8, 1
split_k = 4 if hidden_size >= 5120 else 1
return {
"block_m": block_m,
"block_n": block_n,
"block_k": block_k,
"split_k": split_k,
"num_warps": num_warps,
"num_ctas": 1,
"num_stages": num_stages,
"group_size_m": 8,
"max_nreg": None,
}
return None
def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
"""
`_LORA_A_PTR_DICT` collects the required information during `profile_run`,
After this, it remains constant and subsequent usage is through LUT.
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights)
if values := _LORA_A_PTR_DICT.get(key):
return values
lora_strides_d0 = []
lora_strides_d1 = []
lora_strides_d2 = []
tensor_ptrs = []
for lora_a_weight in lora_a_weights:
if lora_a_weight.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_a_weight.size(1) == 1
lora_a_weight = lora_a_weight.squeeze(dim=1)
else:
assert lora_a_weight.ndim == 3 # shape:(lora_num,size,rank)
assert lora_a_weight.is_contiguous()
tensor_ptrs.append(lora_a_weight.data_ptr())
lora_strides_d0.append(lora_a_weight.stride(0))
lora_strides_d1.append(lora_a_weight.stride(1))
lora_strides_d2.append(lora_a_weight.stride(2))
if len(lora_a_weights) > 1:
lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
else:
lora_ptr_tensor = lora_a_weights[0]
if (
len(set(lora_strides_d0)) > 1
or len(set(lora_strides_d1)) > 1
or len(set(lora_strides_d2)) > 1
):
raise ValueError("All LoRA weights must have the same stride.")
_LORA_A_PTR_DICT[key] = (
lora_ptr_tensor,
lora_strides_d0[0],
lora_strides_d1[0],
lora_strides_d2[0],
)
return _LORA_A_PTR_DICT.get(key)
def _get_lora_b_ptr(
lora_weights: list[torch.Tensor], offset_start: int, device: torch.device
):
"""
`_LORA_B_PTR_DICT` collects the required information during `profile_run`,
After this, it remains constant and subsequent usage is through LUT.
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
if values := _LORA_B_PTR_DICT.get(key):
return values
slice_offset_lst = []
tensor_ptrs = []
lora_strides_d0 = []
lora_strides_d1 = []
lora_strides_d2 = []
hidden_sizes = []
slice_offset = offset_start
for lora_b_weight in lora_weights:
if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weight.size(1) == 1
lora_b_weight = lora_b_weight.squeeze(dim=1)
else:
assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weight.is_contiguous()
tensor_ptrs.append(lora_b_weight.data_ptr())
lora_strides_d0.append(lora_b_weight.stride(0))
lora_strides_d1.append(lora_b_weight.stride(1))
lora_strides_d2.append(lora_b_weight.stride(2))
slice_offset_lst.append(slice_offset)
slice_offset += lora_b_weight.size(1)
hidden_sizes.append(lora_b_weight.size(1))
if len(lora_weights) > 1:
# note these are device tensors
lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
slice_start_tensor = torch.tensor(
slice_offset_lst, device=device, dtype=torch.uint64
)
else:
slice_start_tensor = slice_offset_lst[0]
lora_ptr_tensor = lora_b_weight[0]
# If each lora has the same stride, there's no need to use a
# tensor for storage.
if (
len(set(lora_strides_d0)) == 1
and len(set(lora_strides_d1)) == 1
and len(set(lora_strides_d2)) == 1
) and len(set(hidden_sizes)) == 1:
lora_strides_d0_tensor = lora_strides_d0[0]
lora_strides_d1_tensor = lora_strides_d1[0]
lora_strides_d2_tensor = lora_strides_d2[0]
hidden_sizes_tensor = hidden_sizes[0]
same_stride = True
else:
lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device)
lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device)
lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device)
hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device)
same_stride = False
# MAX_N is the maximum hidden size among all the lora_b weights
MAX_N = max(hidden_sizes)
_LORA_B_PTR_DICT[key] = (
slice_start_tensor,
lora_ptr_tensor,
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
hidden_sizes_tensor,
same_stride,
MAX_N,
)
return _LORA_B_PTR_DICT.get(key)
@functools.lru_cache
def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
if user_defined_config_folder is not None:
gpu_name = torch.cuda.get_device_name()
gpu_name = gpu_name.replace(" ", "_")
gpu_name = gpu_name.replace("-", "_")
config_fname = None
# only expand op needs to consider add_inputs
if op_type == "expand":
config_fname = (
f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json"
)
else:
config_fname = f"{gpu_name}_{op_type.upper()}.json"
config_path = Path(f"{user_defined_config_folder}/{config_fname}")
if not config_path.exists():
logger.warning_once(f"No LoRA kernel configs founded in {config_path}")
return None
# Load json
logger.info_once(f"Using tuned LoRA kernel configs from {config_path}.")
with open(str(config_path)) as f:
config_data = json.load(f)
else:
config_data = None
return config_data
@functools.lru_cache
def get_lora_op_configs(
op_type: str,
max_loras: int,
batch: int,
hidden_size: int,
rank: int,
num_slices: int,
add_inputs: bool | None = None,
moe_intermediate_size: int | None = None,
) -> dict[str, int | None]:
# Add support for fused_moe_lora ops
assert op_type in [
"shrink",
"expand",
"fused_moe_lora_w13_shrink",
"fused_moe_lora_w13_expand",
"fused_moe_lora_w2_shrink",
"fused_moe_lora_w2_expand",
]
# default config
default = {}
if op_type == "shrink":
default = {
"block_m": 32,
"block_n": 16,
"block_k": 256 if batch < 128 else 32,
"split_k": 64 if batch < 128 else 8,
"num_warps": 4,
"num_ctas": 1,
"group_size_m": 8,
"num_stages": 2,
"max_nreg": None,
}
# The default config for fused_moe_lora ops
elif op_type in [
"fused_moe_lora_w13_shrink",
"fused_moe_lora_w13_expand",
"fused_moe_lora_w2_shrink",
"fused_moe_lora_w2_expand",
]:
default = {
"block_m": 64,
"block_n": 64,
"block_k": 32,
"num_warps": 4,
"num_stages": 3,
"group_size_m": 8,
"split_k": 1,
}
else:
default = {
"block_m": 64,
"block_n": 128,
"block_k": 16,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 2,
"max_nreg": None,
}
if _is_corex_backend():
corex_default = _get_corex_lora_op_config(
op_type=op_type,
batch=batch,
hidden_size=hidden_size,
num_slices=num_slices,
)
if corex_default is not None:
default = corex_default
m = batch
k, n = (hidden_size, rank) if op_type == "shrink" else (rank, hidden_size)
config_data: Any
config_data = load_lora_op_config(op_type, add_inputs)
if not config_data:
logger.warning_once("Using default LoRA kernel configs")
return default
# config is structured as config_data[max_loras][num_slices][m][k][n] = {}
# slice by max_loras
config_data = (
config_data.get(str(max_loras))
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - max_loras))]
)
# slice by num_slices
config_data = config_data[str(num_slices)]
# slice by m
config_data = (
config_data.get(str(m))
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - m))]
)
# slice by k
config_data = (
config_data.get(str(k))
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - k))]
)
# slice by n
config_data = (
config_data.get(str(n))
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n))]
)
# slice by moe-intermediate-size if applicable
if moe_intermediate_size is not None:
i = moe_intermediate_size
config_data = (
config_data.get(str(i))
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - i))]
)
assert config_data is not None
return config_data
@lru_cache
def supports_pdl(device: torch.device | None = None) -> bool:
"""
Refer to: https://github.com/triton-lang/triton/blob/v3.5.0/python/tutorials/11-programmatic-dependent-launch.py
"""
# PDL requires compute capability SM90 or above
return current_platform.is_cuda() and current_platform.has_device_capability(90)

View File

@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]

Binary file not shown.

Binary file not shown.

View File

@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F
import torch_xla.core.xla_builder as xb
from torch.library import impl
from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard
@jax.jit
def bgmv_jax(inputs, loras, idxs):
return jnp.einsum(
"td,tX,Xld->tl",
inputs,
jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
loras,
)
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")
@impl(XLA_LIB, "bgmv", "XLA")
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
jax_import_guard()
return xb.call_jax(bgmv_jax, (inputs, loras, idxs))
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
T, _ = inputs.shape
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
_, L, _ = loras.shape
return torch.empty((T, L), device=inputs.device)
def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): output tensor of shape
[num_tokens, hidden_size * num_slices].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
add_inputs (bool): Whether or not to add the input tensor to the output
tensor.
"""
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
limit = output_tensor.shape[0]
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
limit = 1
if output_tensor.shape[1] > outputs.shape[1]:
outputs = F.pad(outputs, (0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
if add_inputs:
return output_tensor + outputs[:limit, : output_tensor.shape[1]]
else:
return outputs[:limit, : output_tensor.shape[1]]
def bgmv_shrink(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
scaling (float, optional): Scalar multiplier applied to the output.
"""
return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): output tensor of shape
[num_tokens, hidden_size * num_slices].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
add_inputs (bool): Whether or not to add the input tensor to the output
tensor.
"""
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
outputs = F.pad(
outputs,
(
slice_offset,
output_tensor.shape[1] - (slice_offset + slice_size),
0,
0,
),
)
if add_inputs:
return output_tensor + outputs
else:
return outputs

128
lora/peft_helper.py Normal file
View File

@@ -0,0 +1,128 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py
import json
import math
import os
from dataclasses import MISSING, dataclass, field, fields
from typing import Literal
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
logger = init_logger(__name__)
@dataclass
class PEFTHelper:
"""
A helper class for PEFT configurations, specifically designed for LoRA.
This class handles configuration validation, compatibility checks for
various LoRA implementations.
"""
# Required fields
r: int
lora_alpha: int
target_modules: list[str] | str
bias: Literal["none"] = field(default="none")
modules_to_save: list[str] | None = field(default=None)
# True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
use_rslora: bool = field(default=False)
# True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
use_dora: bool = field(default=False)
# Extra vllm field, start with 'vllm_' to avoid conflict
vllm_lora_scaling_factor: float = field(default=1.0)
vllm_max_position_embeddings: int | None = field(default=False)
def _validate_features(self) -> list[str]:
"""
Check if there are any unsupported LoRA features.
"""
error_msg = []
if self.modules_to_save:
error_msg.append("vLLM only supports modules_to_save being None.")
if self.use_dora:
error_msg.append("vLLM does not yet support DoRA.")
return error_msg
def __post_init__(self):
if self.use_rslora:
logger.info_once("Loading LoRA weights trained with rsLoRA.")
self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
else:
self.vllm_lora_scaling_factor = self.lora_alpha / self.r
@classmethod
def from_dict(cls, config_dict: dict) -> "PEFTHelper":
# Get all field information from the class
class_fields = {f.name: f for f in fields(cls)}
# Check for required fields
required_fields = {
name
for name, f in class_fields.items()
if f.default is MISSING and f.default_factory is MISSING
}
# Identify any missing required fields
missing_fields = required_fields - set(config_dict.keys())
if missing_fields:
raise ValueError(f"Missing required configuration fields: {missing_fields}")
# Filter out fields that aren't defined in the class
filtered_dict = {k: v for k, v in config_dict.items() if k in class_fields}
return cls(**filtered_dict)
@classmethod
def from_local_dir(
cls,
lora_path: str,
max_position_embeddings: int | None,
tensorizer_config_dict: dict | None = None,
) -> "PEFTHelper":
lora_config_path = os.path.join(lora_path, "adapter_config.json")
if tensorizer_config_dict:
tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
tensorizer_args = tensorizer_config._construct_tensorizer_args()
from tensorizer.stream_io import open_stream
lora_config_path = os.path.join(
tensorizer_config.tensorizer_dir, "adapter_config.json"
)
with open_stream(
lora_config_path, mode="rb", **tensorizer_args.stream_kwargs
) as f:
config = json.load(f)
logger.info(
"Successfully deserialized LoRA config from %s",
tensorizer_config.tensorizer_dir,
)
else:
with open(lora_config_path) as f:
config = json.load(f)
config["vllm_max_position_embeddings"] = max_position_embeddings
return cls.from_dict(config)
def validate_legal(self, lora_config: LoRAConfig) -> None:
"""
Validates the LoRA configuration settings against application
constraints and requirements.
"""
error_msg = self._validate_features()
if self.r > lora_config.max_lora_rank:
error_msg.append(
f"LoRA rank {self.r} is greater than max_lora_rank"
f" {lora_config.max_lora_rank}."
)
if self.bias != "none":
error_msg.append("Adapter bias is not supported.")
if error_msg:
raise ValueError(f"{' '.join(error_msg)}")

View File

@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper
__all__ = [
"PunicaWrapperBase",
"get_punica_wrapper",
]

Binary file not shown.

View File

@@ -0,0 +1,492 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
import torch
from .utils import compute_meta, convert_mapping
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
class PunicaWrapperABC(ABC):
"""
PunicaWrapper ABC.
"""
@abstractmethod
def update_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
**kwargs,
) -> None:
"""
Update the lora-related metadata
"""
raise NotImplementedError
@abstractmethod
def add_shrink(
self,
y: tuple[torch.Tensor, ...] | torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
) -> torch.Tensor | None:
"""
Performs GEMM for multiple slices of lora_a.
"""
raise NotImplementedError
@abstractmethod
def add_expand(
self,
y: torch.Tensor,
x: tuple[torch.Tensor, ...] | torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> torch.Tensor | None:
"""
Performs GEMM for multiple slices of lora_b.
"""
raise NotImplementedError
@abstractmethod
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> torch.Tensor | None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA,
and this layer only requires the expand operation.
"""
raise NotImplementedError
@abstractmethod
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: tuple[torch.Tensor, ...] | None = None,
**kwargs,
) -> torch.Tensor | None:
"""
Applicable to linear-related lora.
"""
raise NotImplementedError
@abstractmethod
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
"""
raise NotImplementedError
class PunicaWrapperBase(PunicaWrapperABC):
"""
PunicaWrapperBase is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica.
"""
def __init__(
self,
max_num_batched_tokens: int,
max_batches: int,
device: torch.device | str,
**kwargs,
):
self._token_lora_indices = torch.empty(
max_num_batched_tokens, dtype=torch.long, device=device
)
self._sampler_indices = torch.empty(
max_num_batched_tokens, dtype=torch.long, device=device
)
self._sampler_indices_padded = torch.empty(
max_num_batched_tokens, dtype=torch.long, device=device
)
self._embeddings_indices = torch.empty(
2, max_num_batched_tokens, dtype=torch.long, device=device
)
# 4 is the number of indices tensors.
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
self.indices_len: list[int | None] = [None] * 4
# these attributes are the information required for sgmv kernel
self._seq_start_locs = torch.empty(max_batches, dtype=torch.long, device=device)
self._seq_lengths = torch.empty(max_batches, dtype=torch.long, device=device)
self._lora_indices_per_batch = torch.empty(
max_batches, dtype=torch.long, device=device
)
self.device: torch.device = device
self.max_length: int = 0
self.token_nums: int = 0
self.batch_size: int = -1
self.is_prefill = False
self.no_lora = False
def _update_base_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
):
(
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
indices_len,
) = convert_mapping(
mapping,
lora_index_to_id,
max_loras,
vocab_size,
extra_vocab_size,
self.device,
)
self._token_lora_indices[: base_indices.shape[0]].copy_(base_indices)
self._sampler_indices[: sampler_indices.shape[0]].copy_(sampler_indices)
self._sampler_indices_padded[: sampler_indices_padded.shape[0]].copy_(
sampler_indices_padded
)
self._embeddings_indices[
: embeddings_indices.shape[0], : embeddings_indices.shape[1]
].copy_(embeddings_indices)
self.indices_len[:] = indices_len
def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None:
(
b_seq_start_tensor,
seq_length_tensor,
lora_indices_tensor,
batch_size,
max_length,
token_nums,
no_lora,
) = compute_meta(token_lora_tensor)
self._seq_start_locs[: b_seq_start_tensor.shape[0]].copy_(b_seq_start_tensor)
self._seq_lengths[: seq_length_tensor.shape[0]].copy_(seq_length_tensor)
self._lora_indices_per_batch[: lora_indices_tensor.shape[0]].copy_(
lora_indices_tensor
)
self.batch_size = batch_size
self.max_length = max_length
self.token_nums = token_nums
self.no_lora = no_lora
@property
def prefill_metadata(
self,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
"""
This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations.
1. seq_start_locs: Tensor of sequence start positions.
2. seq_lengths: Tensor of sequence lengths.
3. lora_indices_per_batch: Tensor of lora indices, and an index of
-1 means no lora should be applied.
4. batch_size: Batch size after clustering identical lora indices.
5. max_length: The maximum sequence length in the batch.
6. token_nums: The token numbers in the batch.
"""
return (
self._seq_start_locs[: self.batch_size],
self._seq_lengths[: self.batch_size],
self._lora_indices_per_batch[: self.batch_size],
self.batch_size,
self.max_length,
self.token_nums,
)
@property
def token_lora_indices(self) -> torch.Tensor:
"""
This property provides the lora indices corresponding to each token
in the batch. An index of -1 means no lora should be applied.
"""
token_lora_len = self.indices_len[0]
return self._token_lora_indices[:token_lora_len]
@property
def sampler_indices(self) -> torch.Tensor:
"""
This property is used to access the lora indices specifically for
LogitsProcessorWithLoRA.
"""
sampler_indices_len = self.indices_len[1]
return self._sampler_indices[:sampler_indices_len]
@property
def sampler_indices_padded(self) -> torch.Tensor:
"""
This property provides access to padded sampler indices.
"""
indices_padded_len = self.indices_len[2]
return self._sampler_indices_padded[:indices_padded_len]
@property
def embeddings_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA.
"""
embeddings_indices_len = self.indices_len[3]
return self._embeddings_indices[:, :embeddings_indices_len]
def update_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
**kwargs,
):
self._update_base_metadata(
mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
)
if mapping.is_prefill:
# Update metadata required for prefill-related operators.
self._update_prefill_metadata(self.token_lora_indices)
self.is_prefill = True
else:
self.is_prefill = False
@abstractmethod
def add_shrink(
self,
y: tuple[torch.Tensor, ...] | torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
) -> torch.Tensor | None:
"""
Performs GEMM for multiple slices of lora_a.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
# TODO: implement it based on torch ops
raise NotImplementedError
@abstractmethod
def add_expand(
self,
y: torch.Tensor,
x: tuple[torch.Tensor, ...] | torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> torch.Tensor | None:
"""
Performs GEMM for multiple slices of lora_b.
Semantics:
offset = offset_start
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
output_slices (tuple[int, ...]): Every slice's size
offset_start (int): The starting position of y, defaults to 0
add_inputs (bool): Defaults to True.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
@abstractmethod
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> torch.Tensor | None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
and this layer only requires the expand operation.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
@abstractmethod
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: tuple[torch.Tensor, ...] | None = None,
**kwargs,
) -> torch.Tensor | None:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
@abstractmethod
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
def moe_lora_align_block_size(
self,
topk_ids: torch.Tensor,
num_tokens: int,
block_size: int,
num_experts: int,
max_loras: int,
adapter_enabled: torch.Tensor,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
def add_lora_fused_moe(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: list[torch.Tensor],
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
shrink_config,
expand_config,
adapter_enabled: torch.Tensor,
mul_routed_weight=False,
):
"""
Performs a fused forward computation for LoRA of
Mixture-of-Experts (MoE) layer.
"""
# TODO: implement it based on torch ops
raise NotImplementedError

View File

@@ -0,0 +1,351 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.lora.ops.torch_ops import (
bgmv_expand,
bgmv_expand_slice,
bgmv_shrink,
sgmv_expand,
sgmv_expand_slice,
sgmv_shrink,
)
from .punica_base import PunicaWrapperBase
# The platforms that are compatible with the PyTorch-native implementation can
# inherit this class
class PunicaWrapperCPU(PunicaWrapperBase):
"""
PunicaWrapperCPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the pytorch punica ops.
"""
def __init__(
self,
max_num_batched_tokens: int,
max_batches: int,
device: torch.device | str,
**kwargs,
):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
def _shrink_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
# No LoRA request, so return directly
if self.no_lora:
return
sgmv_shrink(
x,
w_t_all,
y,
*self.prefill_metadata,
scale,
)
def _shrink_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
def _expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_inputs: bool,
):
# No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand(
x,
w_t_all,
y,
*self.prefill_metadata,
add_inputs,
)
def _expand_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_inputs: bool,
):
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
def _expand_slice_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
):
# No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand_slice(
x,
w_t_all,
y,
*self.prefill_metadata,
y_offset,
y_slice_size,
add_inputs,
)
def _expand_slice_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
):
bgmv_expand_slice(
x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs
)
def _apply_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool = True,
):
"""
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
computation, which is suitable for the
GEMM of lora'b.
"""
expand_slice_fun: Callable = (
self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode
)
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
def _apply_shrink(
self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float
):
"""
Perform the ` y+=x@w_t_all` computation, which is suitable for the
GEMM of lora'a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
"""
y_org = y
y = y.view(-1, y.shape[-1])
shrink_fun: Callable = (
self._shrink_prefill if self.is_prefill else self._shrink_decode
)
shrink_fun(y, x, w_t_all, scale)
y = y.view_as(y_org)
def add_shrink(
self,
y: tuple[torch.Tensor, ...] | torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
):
"""
Performs GEMM for multiple slices of lora_a.
When `is_prefill is` true, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the _shrink_decode function
should be called.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
x = x.view(-1, x.shape[-1])
# TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale)
def add_expand(
self,
y: torch.Tensor,
x: tuple[torch.Tensor, ...] | torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> None:
"""
Performs GEMM for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = offset_start
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
# Embedding layer only need expand op
expand_fun: Callable = (
self._expand_prefill if self.is_prefill else self._expand_decode
)
expand_fun(y, x, lora_b_stacked, add_inputs)
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: tuple[torch.Tensor, ...] | None = None,
**kwargs,
) -> None:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if buffer is None:
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default, consistent with the
# triton op
buffer = tuple(
torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
for _ in range(len(output_slices))
)
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
self.add_expand(
y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
)
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if buffer is None:
# We set the buffer to be float32 by default, consistent with the
# triton op
buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
# LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
bgmv_expand(buffer, lora_b_stacked, y, self.sampler_indices, add_inputs=True)
y = y.view_as(y_org)

View File

@@ -0,0 +1,422 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import final
import torch
from vllm.lora.layers import LoRAMapping
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.math_utils import round_up
if HAS_TRITON:
from vllm.lora.ops.triton_ops import (
LoRAKernelMeta,
fused_moe_lora,
lora_expand,
lora_shrink,
)
from vllm import _custom_ops as ops
from .punica_base import PunicaWrapperBase
@final
class PunicaWrapperGPU(PunicaWrapperBase):
"""
PunicaWrapperGPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica triton kernel.
"""
def __init__(
self,
max_num_batched_tokens: int,
max_batches: int,
device: torch.device | str,
**kwargs,
):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
self.max_loras = kwargs["max_loras"]
self.token_mapping_meta = LoRAKernelMeta.make(
self.max_loras, max_num_batched_tokens, device=device
)
# When speculative decoding is enabled, max_num_samples is
# max_batches * (num_speculative_decoding_tokens + 1).
# This line can be optimized by replacing max_num_batched_tokens
# to max_batches * (num_speculative_decoding_tokens + 1).
self.prompt_mapping_meta = LoRAKernelMeta.make(
self.max_loras, max_num_batched_tokens, device=device
)
def update_metadata(
self,
mapping: LoRAMapping,
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
**kwargs,
):
self.is_prefill = mapping.is_prefill
self._update_base_metadata(
mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
)
# Prepare cuda kernel metadata tensors
self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
self.prompt_mapping_meta.prepare_tensors(self.sampler_indices)
def add_shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
):
"""
Performs GEMM for multiple slices of lora_a.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (torch.Tensor): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
x = x.view(-1, x.shape[-1])
lora_shrink(
x,
lora_a_stacked,
y,
*self.token_mapping_meta.meta_args(x.size(0)),
scale,
)
def add_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> None:
"""
Performs GEMM for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
assert x.ndim == 3
assert x.size(0) == len(output_slices)
num_tokens = x.size(1) # first dimension is the num slices
lora_expand(
x,
lora_b_stacked,
y,
*self.token_mapping_meta.meta_args(num_tokens),
offset_start=offset_start,
add_inputs=True,
)
y = y.view_as(y_org)
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
lora_expand(
x.unsqueeze(dim=0),
(lora_b_stacked,),
y,
*self.token_mapping_meta.meta_args(x.size(0)),
offset_start=0,
add_inputs=add_inputs,
)
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> None:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[torch.Tensor]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
import vllm.envs as env
if env.VLLM_USE_LORA_FUSION:
import ixformer.inference.functions as ops
num_token, m = x.size(0), x.size(-1)
k, n = lora_b_stacked[0].size(-1), y.size(-1)
if len(lora_a_stacked) == 1 and ops.lora_gemv_optim_condition(num_token, m, k, n):
ops.add_lora_linear(y, x, lora_a_stacked, lora_b_stacked,
lora_bias_stacked = None, scale = 1.0, output_slices = (1,))
return
assert buffer is None, (
"To minimize overhead, the buffer should be created by "
".add_lora_linear() instead of being passed in."
)
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default, refer to:
# https://github.com/triton-lang/triton/issues/1387
# Note: buffer is zeroed inside the shrink op
buffer = torch.empty(
(len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device
)
self.add_shrink(
buffer, # type: ignore
x,
lora_a_stacked,
scale,
**kwargs,
)
self.add_expand(
y,
buffer, # type: ignore
lora_b_stacked,
output_slices,
add_inputs=True,
**kwargs,
)
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor): lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]): Default to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
assert buffer is None, (
"To minimize overhead, the buffer should be created by "
".add_lora_linear() instead of being passed in."
)
# We set the buffer to be float32 by default, refer to:
# https://github.com/triton-lang/triton/issues/1387
# Note: buffer is zeroed inside the shrink op
buffer = torch.empty((x.size(0), r), dtype=torch.float32, device=x.device)
lora_shrink(
x,
[lora_a_stacked],
buffer.unsqueeze(dim=0),
*self.prompt_mapping_meta.meta_args(x.size(0)),
scale,
)
lora_expand(
buffer.unsqueeze(dim=0),
[lora_b_stacked],
y,
*self.prompt_mapping_meta.meta_args(buffer.size(0)),
add_inputs=True,
)
y = y.view_as(y_org)
def moe_lora_align_block_size(
self,
topk_ids: torch.Tensor,
num_tokens: int,
block_size: int,
num_experts: int,
max_loras: int,
adapter_enabled: torch.Tensor,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty(
(max_loras * max_num_tokens_padded,),
dtype=torch.int32,
device=topk_ids.device,
)
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be set default to -1 to prevent a blank block
expert_ids = torch.empty(
(max_loras * max_num_m_blocks,),
dtype=torch.int32,
device=topk_ids.device,
)
num_tokens_post_pad = torch.empty(
(max_loras), dtype=torch.int32, device=topk_ids.device
)
(token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
num_tokens
)
ops.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_experts,
block_size,
max_loras,
max_num_tokens_padded,
max_num_m_blocks,
sorted_ids,
expert_ids,
num_tokens_post_pad,
adapter_enabled,
lora_ids,
)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return sorted_ids, expert_ids, num_tokens_post_pad
def add_lora_fused_moe(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: list[torch.Tensor],
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
shrink_config,
expand_config,
adapter_enabled: torch.Tensor,
mul_routed_weight=False,
):
"""
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
"""
(_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0))
fused_moe_lora(
y,
x,
lora_a_stacked,
lora_b_stacked,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
max_lora_rank,
top_k_num,
lora_ids,
adapter_enabled,
shrink_config.get("BLOCK_SIZE_M", 64),
shrink_config.get("BLOCK_SIZE_N", 64),
shrink_config.get("BLOCK_SIZE_K", 32),
shrink_config.get("GROUP_SIZE_M", 8),
shrink_config.get("NUM_WARPS", 4),
shrink_config.get("NUM_STAGES", 3),
shrink_config.get("SPLIT_K", 1),
expand_config.get("BLOCK_SIZE_M", 64),
expand_config.get("BLOCK_SIZE_N", 64),
expand_config.get("BLOCK_SIZE_K", 32),
expand_config.get("GROUP_SIZE_M", 8),
expand_config.get("NUM_WARPS", 4),
expand_config.get("NUM_STAGES", 3),
expand_config.get("SPLIT_K", 1),
mul_routed_weight,
)

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
from .punica_base import PunicaWrapperBase
logger = init_logger(__name__)
def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
punica_wrapper_qualname = current_platform.get_punica_wrapper()
punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname)
punica_wrapper = punica_wrapper_cls(*args, **kwargs)
assert punica_wrapper is not None, (
"the punica_wrapper_qualname(" + punica_wrapper_qualname + ") is wrong."
)
logger.info_once("Using %s.", punica_wrapper_qualname.rsplit(".", 1)[1])
return punica_wrapper

View File

@@ -0,0 +1,359 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
import torch_xla
from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from vllm.lora.punica_wrapper.utils import convert_mapping
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
from .punica_base import PunicaWrapperBase
class PunicaWrapperTPU(PunicaWrapperBase):
"""
PunicaWrapperTPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the pytorch punica ops.
"""
def __init__(
self,
max_num_batched_tokens: int,
max_batches: int,
device: torch.device | str,
**kwargs,
):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
# PunicaWrapperBase defines some tensors with dtype=torch.int64, which
# isn't supported by the TPU. So convert those tensors to int32.
# Not all of them are used by the TPU so only convert the useful ones.
self._token_lora_indices = self._token_lora_indices.to(dtype=torch.int32)
self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
self._sampler_indices_padded = self._sampler_indices_padded.to(
dtype=torch.int32
)
torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, True)
torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
@property
def embeddings_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA.
"""
return self._embeddings_indices[:]
@property
def sampler_indices_padded(self) -> torch.Tensor:
"""
This property provides access to padded sampler indices.
"""
return self._sampler_indices_padded[:]
def shrink(
self,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale)
def expand(
self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, add_inputs: bool
):
return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), add_inputs)
def expand_slice(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
) -> torch.Tensor:
return bgmv_expand_slice(
x,
w_t_all,
y,
self._get_token_lora_indices(x),
y_offset,
y_slice_size,
add_inputs,
)
def add_shrink(
self,
y: tuple[torch.Tensor, ...] | torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
) -> torch.Tensor | None:
"""
Performs GEMM for multiple slices of lora_a.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
torch.ops.xla.dynamo_set_buffer_donor_(y, True)
x = x.view(-1, x.shape[-1])
for slice_idx in range(len(lora_a_stacked)):
lora_s = lora_a_stacked[slice_idx]
y_s = self.shrink(x, lora_s, scale)
y[slice_idx, :, :] = y_s # type: ignore[index]
return y
def add_expand(
self,
y: torch.Tensor,
x: tuple[torch.Tensor, ...] | torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> torch.Tensor:
"""
Performs GEMM for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = 0
for slice_idx in range(len(lora_b_stacked)):
y = self.expand_slice(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_left += output_slices[slice_idx]
return y.view_as(y_org)
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> torch.Tensor:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
# Embedding layer only needs the expand op
return self.expand(y, x, lora_b_stacked, add_inputs)
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: tuple[torch.Tensor, ...] | None = None,
**kwargs,
) -> torch.Tensor:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)
Args:
y (torch.Tensor): Output tensor. Will not be changed in-place.
x (torch.Tensor): Input tensor (T, E)
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if buffer is None:
r = lora_b_stacked[0].size(-1)
T = x.size(0)
buffer = torch.zeros(
(len(output_slices), T, r),
dtype=x.dtype,
device=x.device,
)
buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
return self.add_expand(
y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
)
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale)
y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
return y.view_as(y_org)
# This performs the same tensor ops as the base method, except it does them
# on the CPU then transfers the results to the TPU
def _update_base_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
):
# Make sure we don't accidentally collect outside operations
torch_xla.sync()
# Pad the prompt mapping to avoid running into recompiles on the TPU
# TODO: Should this happen inside mapping internally? If so how can we
# avoid having backend specific LoRAMapping classes?
mapping.prompt_mapping = self._pad_prompt_mapping(mapping.prompt_mapping)
(
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
indices_len,
) = convert_mapping(
mapping,
lora_index_to_id,
max_loras,
vocab_size,
extra_vocab_size,
"cpu",
)
self._token_lora_indices = self._pad_to_shape(
base_indices, self._token_lora_indices.shape, dims=1
).to(self.device)
self._sampler_indices = self._pad_to_shape(
sampler_indices, self._sampler_indices.shape, dims=1
).to(self.device)
self._sampler_indices_padded = self._pad_to_shape(
sampler_indices_padded, self._sampler_indices_padded.shape, dims=1
).to(self.device)
self._embeddings_indices = self._pad_to_shape(
embeddings_indices, self._embeddings_indices.shape, dims=2
).to(self.device)
self.indices_len[:] = indices_len
def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None:
self.batch_size = 1
self._lora_indices_per_batch[: self.batch_size] = token_lora_tensor[
: self.batch_size
]
def _pad_prompt_mapping(self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
num_reqs = len(prompt_mapping)
# From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
# import
MIN_NUM_SEQS = 8
padded_num_reqs = max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
pad_len = padded_num_reqs - num_reqs
padding = [-1] * pad_len
return tuple(list(prompt_mapping) + padding)
def _pad_to_shape(self, src, target_shape, dims=1):
if dims == 1:
pad_len = target_shape[0] - src.shape[0]
return F.pad(src, (0, pad_len), value=0).to(torch.int32)
else:
pad_rows = target_shape[0] - src.shape[0]
pad_cols = target_shape[1] - src.shape[1]
return F.pad(src, (0, pad_cols, 0, pad_rows), value=0).to(torch.int32)

View File

@@ -0,0 +1,279 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import final
import torch
from vllm.lora.layers import LoRAMapping
from vllm.lora.ops.ipex_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from .punica_base import PunicaWrapperBase
@final
class PunicaWrapperXPU(PunicaWrapperBase):
"""
PunicaWrapperXPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica ipex kernel.
"""
def __init__(
self,
max_num_batched_tokens: int,
max_batches: int,
device: torch.device | str,
**kwargs,
):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
def update_metadata(
self,
mapping: LoRAMapping,
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
**kwargs,
):
self.is_prefill = mapping.is_prefill
self._update_base_metadata(
mapping, lora_index_to_id, max_loras, vocab_size, extra_vocab_size
)
def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
def _apply_shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x), scale)
def _apply_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
):
token_lora_indices = self._get_token_lora_indices(x)
bgmv_expand_slice(
x, w_t_all, y, token_lora_indices, y_offset, y_slice_size, add_inputs
)
def add_shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
):
"""
Performs GEMM for multiple slices of lora_a.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (torch.Tensor): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
x = x.view(-1, x.shape[-1])
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale)
def add_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> None:
"""
Performs GEMM for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
assert x.ndim == 3
assert x.size(0) == len(output_slices)
# TODO fuse these kernels
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_start,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_start += output_slices[slice_idx]
y.view_as(y_org)
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
token_lora_indices = self._get_token_lora_indices(x)
bgmv_expand(x, lora_b_stacked, y, token_lora_indices, add_inputs)
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> None:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[torch.Tensor]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if buffer is None:
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default, refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros( # type: ignore
(len(output_slices), x.size(0), r),
dtype=torch.float32,
device=x.device,
)
self.add_shrink(
buffer, # type: ignore
x,
lora_a_stacked,
scale,
**kwargs,
)
self.add_expand(
y,
buffer, # type: ignore
lora_b_stacked,
output_slices,
add_inputs=True,
**kwargs,
)
@property
def sampler_indices_padded(self) -> torch.Tensor:
"""
This property provides access to padded sampler indices.
"""
return self._sampler_indices_padded[:]
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor): lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]): Default to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if buffer is None:
# We set the buffer to be float32 by default, refer to:
# https://github.com/triton-lang/triton/issues/1387
buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale)
bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
return y.view_as(y_org)

View File

@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
import torch
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
def compute_meta(
token_lora_tensor: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
"""
Get the information required for the sgmv kernel. With the features:
1. If consecutive requests in the batch use the same LoRA, this function
will combine them into a single request, improving sgmv kernel inference
performance.
2. At the beginning of each prefill stage inference, recalculations are
needed based on the input, but only once.
"""
lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
token_lora_tensor, return_counts=True
)
cum_result = torch.cumsum(seq_length_tensor, dim=0)
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
b_seq_start_tensor[1:].copy_(cum_result[:-1])
max_length = seq_length_tensor.max().item()
token_nums = seq_length_tensor.sum().item()
batch_size = lora_indices_tensor.size(0)
no_lora = False
# -1 means no lora should be applied. Use `no_lora` to determine whether
# the current step requires LoRA. If LoRA is not needed, the prefill stage
# does not need to launch the triton kernel, which can improve performance
if batch_size == 1 and lora_indices_tensor == -1:
no_lora = True
return (
b_seq_start_tensor,
seq_length_tensor,
lora_indices_tensor,
batch_size,
max_length,
token_nums,
no_lora,
)
# TODO see if this can be vectorized
def convert_mapping(
mapping: "LoRAMapping",
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[int]]:
"""Converts LoRAMapping to index tensors.
Args:
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
lora_index_to_id: List mapping LoRA ids to LoRA indices.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
LoRA indices.
sampler_indices: Tensor of shape [batch_size] mapping requests to
LoRA indices for sampler. For generation, this will be the
same as base_indices. For prefill, this will map requests
to LoRA indices.
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to LoRA indices for sampler with padding.
Same as sampler_indices, but -1 is replaced with
max_loras.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
indices_len: List of lengths of the above tensors. It contains
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices).
"""
index_mapping_indices: list[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
prompt_mapping: list[int] = [
lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
lora_idx = (
lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0
else -1
)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
lora_indices[i] = lora_idx
indices_list: list[list[int] | torch.Tensor] = [
index_mapping_indices,
lora_indices,
embedding_indices,
]
indices = torch.tensor(indices_list, dtype=torch.long, device=device)
prompt_mapping_tensor = torch.tensor(
prompt_mapping, dtype=torch.long, device=device
)
embeddings_indices = torch.stack(
[
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size),
]
)
embeddings_indices = torch.where(
embeddings_indices == -1, max_loras - 1, embeddings_indices
)
base_indices = indices[1]
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded = torch.where(
sampler_indices_padded == -1, max_loras - 1, sampler_indices_padded
)
sampler_indices_padded = torch.arange(
0, len(sampler_indices_padded), device=device, dtype=torch.long
) + (sampler_indices_padded * len(sampler_indices_padded))
# Contain length of indices tensors. Used to index into each tensor.
indices_len = [
base_indices.shape[-1],
sampler_indices.shape[-1],
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1],
]
return (
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
indices_len,
)

100
lora/request.py Normal file
View File

@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
import msgspec
class LoRARequest(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True,
): # type: ignore[call-arg]
"""
Request for a LoRA adapter.
Note that this class should be used internally. For online
serving, it is recommended to not allow users to use this class but
instead provide another layer of abstraction to prevent users from
accessing unauthorized LoRA adapters.
lora_int_id must be globally unique for a given adapter.
This is currently not enforced in vLLM.
"""
lora_name: str
lora_int_id: int
lora_path: str = ""
lora_local_path: str | None = msgspec.field(default=None)
long_lora_max_len: int | None = None
base_model_name: str | None = msgspec.field(default=None)
tensorizer_config_dict: dict | None = None
def __post_init__(self):
if self.lora_int_id < 1:
raise ValueError(f"id must be > 0, got {self.lora_int_id}")
if self.lora_local_path:
warnings.warn(
"The 'lora_local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'lora_path' instead.",
DeprecationWarning,
stacklevel=2,
)
if not self.lora_path:
self.lora_path = self.lora_local_path or ""
# Ensure lora_path is not empty
assert self.lora_path, "lora_path cannot be empty"
@property
def adapter_id(self):
return self.lora_int_id
@property
def name(self):
return self.lora_name
@property
def path(self):
return self.lora_path
@property
def local_path(self):
warnings.warn(
"The 'local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'path' instead.",
DeprecationWarning,
stacklevel=2,
)
return self.lora_path
@local_path.setter
def local_path(self, value):
warnings.warn(
"The 'local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'path' instead.",
DeprecationWarning,
stacklevel=2,
)
self.lora_path = value
def __eq__(self, value: object) -> bool:
"""
Overrides the equality method to compare LoRARequest
instances based on lora_name. This allows for identification
and comparison lora adapter across engines.
"""
return isinstance(value, self.__class__) and self.lora_name == value.lora_name
def __hash__(self) -> int:
"""
Overrides the hash method to hash LoRARequest instances
based on lora_name. This ensures that LoRARequest instances
can be used in hash-based collections such as sets and dictionaries,
identified by their names across engines.
"""
return hash(self.lora_name)

88
lora/resolver.py Normal file
View File

@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from dataclasses import dataclass, field
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
logger = init_logger(__name__)
class LoRAResolver(ABC):
"""Base class for LoRA adapter resolvers.
This class defines the interface for resolving and fetching LoRA adapters.
Implementations of this class should handle the logic for locating and
downloading LoRA adapters from various sources (e.g. S3, cloud storage,
etc.).
"""
@abstractmethod
async def resolve_lora(
self, base_model_name: str, lora_name: str
) -> LoRARequest | None:
"""Abstract method to resolve and fetch a LoRA model adapter.
Implements logic to locate and download LoRA adapter based on the name.
Implementations might fetch from a blob storage or other sources.
Args:
base_model_name: The name/identifier of the base model to resolve.
lora_name: The name/identifier of the LoRA model to resolve.
Returns:
Optional[LoRARequest]: The resolved LoRA model information, or None
if the LoRA model cannot be found.
"""
pass
@dataclass
class _LoRAResolverRegistry:
resolvers: dict[str, LoRAResolver] = field(default_factory=dict)
def get_supported_resolvers(self) -> Set[str]:
"""Get all registered resolver names."""
return self.resolvers.keys()
def register_resolver(
self,
resolver_name: str,
resolver: LoRAResolver,
) -> None:
"""Register a LoRA resolver.
Args:
resolver_name: Name to register the resolver under.
resolver: The LoRA resolver instance to register.
"""
if resolver_name in self.resolvers:
logger.warning(
"LoRA resolver %s is already registered, and will be "
"overwritten by the new resolver instance %s.",
resolver_name,
resolver,
)
self.resolvers[resolver_name] = resolver
def get_resolver(self, resolver_name: str) -> LoRAResolver:
"""Get a registered resolver instance by name.
Args:
resolver_name: Name of the resolver to get.
Returns:
The resolver instance.
Raises:
KeyError: If the resolver is not found in the registry.
"""
if resolver_name not in self.resolvers:
raise KeyError(
f"LoRA resolver '{resolver_name}' not found. "
f"Available resolvers: {list(self.resolvers.keys())}"
)
return self.resolvers[resolver_name]
LoRAResolverRegistry = _LoRAResolverRegistry()

293
lora/utils.py Normal file
View File

@@ -0,0 +1,293 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import TYPE_CHECKING, Optional
import huggingface_hub
import regex as re
from huggingface_hub.utils import (
EntryNotFoundError,
HfHubHTTPError,
HFValidationError,
RepositoryNotFoundError,
)
from torch import nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
# being imported for _all_lora_classes below
from vllm.lora.layers import (
BaseLayerWithLoRA,
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithLoRA,
MergedQKVParallelLinearWithShardedLoRA,
QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA,
ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA,
RowParallelLinearWithShardedLoRA,
VocabParallelEmbeddingWithLoRA,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping
if TYPE_CHECKING:
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__)
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
VocabParallelEmbeddingWithLoRA,
ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA,
ReplicatedLinearWithLoRA,
LogitsProcessorWithLoRA,
ColumnParallelLinearWithShardedLoRA,
QKVParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
}
def is_moe_model(model: nn.Module) -> bool:
"""Checks if the model contains FusedMoE layers and warns the user."""
if any(isinstance(module, FusedMoE) for module in model.modules()):
logger.info_once("MoE model detected. Using fused MoE LoRA implementation.")
return True
return False
def from_layer(
layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> nn.Module:
for lora_cls in _all_lora_classes:
# specifying kwargs so they can be easily accessed in decorator
if lora_cls.can_replace_layer(
source_layer=layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
):
instance_layer = lora_cls(layer)
instance_layer.create_lora_weights(max_loras, lora_config, model_config)
return instance_layer
return layer
def from_layer_logits_processor(
layer: "LogitsProcessor",
lm_head: "ParallelLMHead",
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> LogitsProcessorWithLoRA:
ret = LogitsProcessorWithLoRA(
layer,
lm_head.embedding_dim,
lm_head.weight.dtype,
lm_head.weight.device,
lm_head.get_sharded_to_full_mapping(),
)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
def replace_submodule(
model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
"""Replace a submodule in a model with a new module."""
parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
target_name = module_name.split(".")[-1]
setattr(parent, target_name, new_module)
return new_module
def parse_fine_tuned_lora_name(
name: str, weights_mapper: Optional["WeightsMapper"] = None
) -> tuple[str, bool]:
"""Parse the name of lora weights.
args:
name: the name of the fine-tuned LoRA, e.g.
base_model.model.dense1.weight
weights_mapper: maps the name of weight, e.g.
`model.` -> `language_model.model.`,
return:
tuple(module_name, is_lora_a):
module_name: the name of the module, e.g. model.dense1,
is_lora_a whether the tensor is lora_a or lora_b.
"""
# LoRA weight qualified name usually starts with `base_model.model.`,
# so we remove the prefix `base_model.model.` to make the following
# mapping correctly.
if name.startswith("base_model.model."):
name = name.replace("base_model.model.", "")
name = weights_mapper._map_name(name) if weights_mapper else name
# recover the prefix `base_model.model.`
name = "base_model.model." + name
else:
name = weights_mapper._map_name(name) if weights_mapper else name
# In some situations, we may not start with `base_model.model.`.
# If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
# we should keep the prefix intact.
start_index = 2 if name.startswith("base_model.model.") else 0
parts = name.split(".")
if parts[-1] == "weight" and (parts[-2] == "lora_A" or parts[-2] == "lora_B"):
new_name = ".".join(parts[start_index:-2])
return new_name, parts[-2] == "lora_A"
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
new_name = ".".join(parts[start_index:-1])
return new_name, parts[-1] == "lora_embedding_A"
raise ValueError(f"{name} is unsupported LoRA weight")
def is_regex_target_modules(
load_modules: str | list[str], expected_lora_modules: list[str]
) -> bool:
"""
PEFT supports passing `target_modules` in the form of regular expressions,
such as `model.*(q_proj|k_proj|v_proj)$`. This function is mainly used to
determine whether the suffix in the regular expression is present in the
`expected_lora_modules`.
"""
def is_valid_regex(pattern):
try:
re.compile(pattern)
return True
except re.error:
return False
def is_subset(sub_list, full_list):
return set(sub_list).issubset(set(full_list))
# Similar to PEFT's processing logic, regex-related operations are only
# executed when the load_modules is a `str`.
if not isinstance(load_modules, str):
return False
if is_valid_regex(load_modules):
match = re.search(r"\((.*?)\)\$?$", load_modules)
if match:
suffix = match.group(1).split("|")
return is_subset(suffix, expected_lora_modules)
return False
def get_supported_lora_modules(model: nn.Module) -> list[str]:
"""
In vLLM, all linear layers support LoRA.
"""
supported_lora_modules: set[str] = set()
for name, module in model.named_modules():
# get the embedding modules if the module's embedding_modules
# is not empty.
embedding_modules = getattr(module, "embedding_modules", None)
if embedding_modules is not None:
for name in embedding_modules:
supported_lora_modules.add(name)
# get all the linear subfixes.
if isinstance(module, (LinearBase,)):
supported_lora_modules.add(name.split(".")[-1])
if isinstance(module, (FusedMoE,)):
supported_lora_modules.add(name.split(".")[-1])
return list(supported_lora_modules)
def get_adapter_absolute_path(lora_path: str) -> str:
"""
Resolves the given lora_path to an absolute local path.
If the lora_path is identified as a Hugging Face model identifier,
it will download the model and return the local snapshot path.
Otherwise, it treats the lora_path as a local file path and
converts it to an absolute path.
Parameters:
lora_path (str): The path to the lora model, which can be an absolute path,
a relative path, or a Hugging Face model identifier.
Returns:
str: The resolved absolute local path to the lora model.
"""
# Check if the path is an absolute path. Return it no matter exists or not.
if os.path.isabs(lora_path):
return lora_path
# If the path starts with ~, expand the user home directory.
if lora_path.startswith("~"):
return os.path.expanduser(lora_path)
# Check if the expanded relative path exists locally.
if os.path.exists(lora_path):
return os.path.abspath(lora_path)
# If the path does not exist locally, assume it's a Hugging Face repo.
try:
local_snapshot_path = huggingface_hub.snapshot_download(repo_id=lora_path)
except (
HfHubHTTPError,
RepositoryNotFoundError,
EntryNotFoundError,
HFValidationError,
):
# Handle errors that may occur during the download
# Return original path instead of throwing error here
logger.exception("Error downloading the HuggingFace model")
return lora_path
return local_snapshot_path
def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
if is_moe_model(model):
if moe_packed_mapping := get_moe_expert_mapping(model):
# This method generates and returns a dictionary mapping packed module
# names to lists of their corresponding submodule names. It includes
# both static mappings and dynamic mappings for expert layers, where
# the expert indices are expanded based on the configured number
# of routed experts.
packed_modules_mapping = get_packed_modules_mapping(model)
packed_modules_mapping["experts"] = [
weight_name.rstrip(".") for _, weight_name, _, _ in moe_packed_mapping
]
return packed_modules_mapping
else:
raise AttributeError(
"To support LoRA for MoE model, "
"'get_expert_mapping' must be implemented"
)
else:
return get_packed_modules_mapping(model)

279
lora/worker_manager.py Normal file
View File

@@ -0,0 +1,279 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import Any, Literal
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.models import (
LoRAModel,
LoRAModelManager,
LRUCacheLoRAModelManager,
create_lora_manager,
)
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
logger = init_logger(__name__)
class WorkerLoRAManager:
"""WorkerLoRAManager that manages LoRA models on the worker side.
Every request, the requested LoRAs will be loaded (unless they are already
loaded), and every other LoRA will be unloaded."""
_manager_cls: type[LoRAModelManager] = LoRAModelManager
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
embedding_modules: dict[str, str],
embedding_padding_modules: list[str],
lora_model_cls: type[LoRAModel] = LoRAModel,
):
self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules
self._cached_dummy_lora: None | Literal[False] | LoRAModel = False
self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
self.max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens
)
self.vocab_size = vllm_config.model_config.get_vocab_size()
self.lora_config = vllm_config.lora_config
# Use get_text_config() in case of multimodal models
text_config = vllm_config.model_config.hf_config.get_text_config()
self.max_position_embeddings = text_config.max_position_embeddings
self.device = device
# Lazily initialized by create_lora_manager.
self._adapter_manager: LoRAModelManager
@contextmanager
def dummy_lora_cache(self):
"""Use this context manager to reuse the dummy lora model
to avoid creating it repeatedly."""
self._cached_dummy_lora = None
yield
self._cached_dummy_lora = False
@property
def is_enabled(self) -> bool:
return True
def create_lora_manager(
self,
model: torch.nn.Module,
) -> Any:
lora_manager = create_lora_manager(
model,
max_num_seqs=self.max_num_seqs,
max_num_batched_tokens=self.max_num_batched_tokens,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
device=self.device,
lora_manager_cls=self._manager_cls,
)
self._adapter_manager = lora_manager
return lora_manager.model
def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
try:
supported_lora_modules = self._adapter_manager.supported_lora_modules
packed_modules_mapping = self._adapter_manager.packed_modules_mapping
expected_lora_modules: list[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_modules.extend(packed_modules_mapping[module])
else:
expected_lora_modules.append(module)
if module == "experts":
expected_lora_modules.append(module)
expected_lora_modules = list(set(expected_lora_modules))
lora_path = get_adapter_absolute_path(lora_request.lora_path)
peft_helper = PEFTHelper.from_local_dir(
lora_path,
self.max_position_embeddings,
lora_request.tensorizer_config_dict,
)
# Validates the LoRA configuration against requirements before
# loading weights, throwing an exception if validation fails.
peft_helper.validate_legal(self.lora_config)
# For some models like Qwen2VL, we need to use hf_to_vllm_mapper
# to ensure correct loading of lora weights.
model = self._adapter_manager.model
hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)
lora = self._lora_model_cls.from_local_checkpoint(
lora_path,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=lora_request.lora_int_id,
device="cpu",
dtype=self.lora_config.lora_dtype,
target_embedding_padding=self.vocab_size
+ self.lora_config.lora_extra_vocab_size,
embedding_modules=self.embedding_modules,
embedding_padding_modules=self.embedding_padding_modules,
tensorizer_config_dict=lora_request.tensorizer_config_dict,
weights_mapper=hf_to_vllm_mapper,
)
except FileNotFoundError as e:
# FileNotFoundError should be raised if both
# - No adapter found to download from huggingface (or in
# offline mode)
# - No local adapter files found at `lora_request.lora_path`
# For NotFoundError
raise ValueError(
f"Loading lora {lora_request.lora_name} failed: No adapter "
f"found for {lora_request.lora_path}"
) from e
except Exception as e:
# For BadRequestError
raise e
if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
raise ValueError(
f"LoRA added vocab size {lora.extra_vocab_size} "
f"is greater than lora_extra_vocab_size "
f"{self.lora_config.lora_extra_vocab_size}."
)
return lora
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
if lora_request.lora_int_id in self.list_adapters():
return False
if isinstance(self._cached_dummy_lora, LoRAModel):
dummy_lora = self._cached_dummy_lora.clone(lora_request.lora_int_id)
else:
dummy_lora = self._adapter_manager.create_dummy_lora(
lora_request.lora_int_id, rank, self.embedding_modules
)
if self._cached_dummy_lora is None:
self._cached_dummy_lora = dummy_lora
return self._adapter_manager.add_adapter(dummy_lora)
def pin_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.pin_adapter(adapter_id)
def set_active_adapters(self, requests: set[Any], mapping: Any | None) -> None:
self._apply_adapters(requests)
if mapping is not None:
self._adapter_manager.set_adapter_mapping(mapping)
def _apply_adapters(self, adapter_requests: set[Any]) -> None:
existing_adapters = self.list_adapters()
models_map = {
adapter_request.adapter_id: adapter_request
for adapter_request in adapter_requests
if adapter_request
}
if len(models_map) > self._adapter_manager.adapter_slots:
raise RuntimeError(
f"Number of requested models ({len(models_map)}) is greater "
"than the number of GPU model slots "
f"({self._adapter_manager.adapter_slots})."
)
requested_ids = set(models_map)
for adapter_id in existing_adapters - requested_ids:
self.remove_adapter(adapter_id)
for adapter_id in requested_ids - existing_adapters:
self.add_adapter(models_map[adapter_id])
def add_adapter(self, adapter_request: Any) -> bool:
if adapter_request.adapter_id in self.list_adapters():
return False
loaded_adapter = self._load_adapter(adapter_request)
loaded = self._adapter_manager.add_adapter(loaded_adapter)
self._adapter_manager.activate_adapter(loaded_adapter.id)
return loaded
def remove_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.remove_adapter(adapter_id)
def remove_all_adapters(self):
self._adapter_manager.remove_all_adapters()
def list_adapters(self) -> set[int]:
return set(self._adapter_manager.list_adapters())
class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
"""WorkerLoRAManager that manages LoRA models on the worker side.
Uses an LRU Cache. Every request, the requested LoRAs will be loaded
(unless they are already loaded) and least recently used LoRAs will
be unloaded if the cache is above capacity."""
_manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
def create_lora_manager(
self,
model: torch.nn.Module,
) -> Any:
lora_manager = create_lora_manager(
model,
lora_manager_cls=self._manager_cls,
max_num_seqs=self.max_num_seqs,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
device=self.device,
max_num_batched_tokens=self.max_num_batched_tokens,
)
self._adapter_manager = lora_manager
return lora_manager.model
def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None:
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests
if lora_request
}
if len(loras_map) > self._adapter_manager.lora_slots:
raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots "
f"({self._adapter_manager.lora_slots})."
)
for lora in loras_map.values():
self.add_adapter(lora)
def add_adapter(self, lora_request: LoRARequest) -> bool:
# Note that this method is not thread-safe. It may be invoked multiple
# times for the same adapter when using multiple API servers.
# This is ok because it's currently only called from
# the single-threaded core engine loop.
if lora_request.lora_int_id not in self.list_adapters():
# Load the new adapter first to ensure it is actually valid, before
# evicting any existing adapters.
# This may cause the # of loaded lora adapters to very temporarily
# exceed `--max-cpu-loras`.
lora = self._load_adapter(lora_request)
# Loading succeeded, now check if we will exceed cache capacity and
# evict if the oldest adapter if so
if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
assert isinstance(self._adapter_manager, LRUCacheLoRAModelManager)
self._adapter_manager.remove_oldest_adapter()
# Then add the new adapter to the cache
loaded = self._adapter_manager.add_adapter(lora)
else:
# If the lora is already loaded, just touch it to
# update its position in the caches
loaded = (
self._adapter_manager.get_adapter(lora_request.lora_int_id) is not None
)
self._adapter_manager.activate_adapter(lora_request.lora_int_id)
return loaded