Add minimal vLLM 0.16.1 build repo for BI-V150

2026-04-18 10:56:22 +08:00
commit d69657327e
1895 changed files with 615301 additions and 0 deletions

45
vllm/lora/__init__.py Normal file
@@ -0,0 +1,45 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.layers.column_parallel_linear import (
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearVariableSliceWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithLoRA,
MergedQKVParallelLinearWithShardedLoRA,
QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.fused_moe import FusedMoE3DWithLoRA, FusedMoEWithLoRA
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import (
RowParallelLinearWithLoRA,
RowParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.utils import LoRAMapping, LoRAMappingType
from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA
__all__ = [
"BaseLayerWithLoRA",
"VocabParallelEmbeddingWithLoRA",
"LogitsProcessorWithLoRA",
"ColumnParallelLinearWithLoRA",
"ColumnParallelLinearWithShardedLoRA",
"MergedColumnParallelLinearWithLoRA",
"MergedColumnParallelLinearWithShardedLoRA",
"MergedColumnParallelLinearVariableSliceWithLoRA",
"MergedQKVParallelLinearWithLoRA",
"MergedQKVParallelLinearWithShardedLoRA",
"QKVParallelLinearWithLoRA",
"QKVParallelLinearWithShardedLoRA",
"RowParallelLinearWithLoRA",
"RowParallelLinearWithShardedLoRA",
"ReplicatedLinearWithLoRA",
"LoRAMapping",
"LoRAMappingType",
"FusedMoEWithLoRA",
"FusedMoE3DWithLoRA",
]
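
As a quick orientation for readers of this diff: each exported class wraps one vLLM base layer type, and the LoRA machinery selects a wrapper by probing each class's `can_replace_layer`. The following standalone sketch only illustrates that dispatch pattern; `pick_lora_wrapper`, `candidate_classes`, and the argument objects are made-up stand-ins, not part of the vLLM API.

# Standalone illustration of the can_replace_layer dispatch idea.
# This is NOT the vLLM implementation; it only mirrors the pattern.
import torch.nn as nn

def pick_lora_wrapper(source_layer: nn.Module, candidate_classes, lora_config,
                      packed_modules_list, model_config=None):
    """Return the first wrapper class whose can_replace_layer() accepts the layer."""
    for cls in candidate_classes:
        if cls.can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
        ):
            return cls
    return None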

66
vllm/lora/layers/base.py Normal file
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
if TYPE_CHECKING:
from vllm.lora.punica_wrapper import PunicaWrapperBase
class BaseLayerWithLoRA(nn.Module):
def slice_lora_a(
self, lora_a: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
"""Slice lora a if splitting for tensor parallelism."""
...
def slice_lora_b(
self, lora_b: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
"""Slice lora b if splitting with tensor parallelism."""
...
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
...
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
...
def set_lora(
self,
index: int,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
"""Overwrites lora tensors at index."""
...
def set_mapping(
self,
punica_wrapper,
):
self.punica_wrapper: PunicaWrapperBase = punica_wrapper
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
raise NotImplementedError
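
To make the contract above concrete, here is a minimal, self-contained subclass sketch for a plain `nn.Linear`. It is illustrative only: the real vLLM layers wrap tensor-parallel linear layers, keep per-slot stacked buffers, and route the matmul through a Punica wrapper instead of computing the LoRA delta inline.

# Minimal sketch of a BaseLayerWithLoRA-style wrapper for a plain nn.Linear.
# Illustrative only -- not how vLLM's layers are implemented.
import torch
import torch.nn as nn

class TinyLinearWithLoRA(nn.Module):
    def __init__(self, base: nn.Linear, max_loras: int, rank: int):
        super().__init__()
        self.base = base
        # One (rank, in_features) A slot and (out_features, rank) B slot per LoRA index.
        self.lora_a = torch.zeros(max_loras, rank, base.in_features)
        self.lora_b = torch.zeros(max_loras, base.out_features, rank)

    def reset_lora(self, index: int):
        self.lora_a[index] = 0
        self.lora_b[index] = 0

    def set_lora(self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor):
        self.reset_lora(index)
        self.lora_a[index, : lora_a.shape[0], : lora_a.shape[1]].copy_(lora_a)
        self.lora_b[index, : lora_b.shape[0], : lora_b.shape[1]].copy_(lora_b)

    def forward(self, x: torch.Tensor, index: int) -> torch.Tensor:
        # base output plus the low-rank update for the selected adapter slot
        delta = (x @ self.lora_a[index].T) @ self.lora_b[index].T
        return self.base(x) + delta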

169
vllm/lora/layers/base_linear.py Normal file
@@ -0,0 +1,169 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed.utils import divide
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
LinearBase,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.platforms import current_platform
from .base import BaseLayerWithLoRA
from .utils import _get_lora_device
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: LinearBase):
super().__init__()
self.base_layer = base_layer
self.input_size = self.base_layer.input_size
# Ensure tp_size and tp_rank consistency with the base_layer.
self.tp_size = self.base_layer.tp_size
self.tp_rank = self.base_layer.tp_rank
self.device = _get_lora_device(self.base_layer)
self.output_slices: tuple[int, ...]
self.output_size: int
self.n_slices: int
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
self.lora_config = lora_config
#
if isinstance(self.base_layer, ReplicatedLinear):
lora_a_out_size = lora_config.max_lora_rank
lora_b_out_size = self.output_size
elif isinstance(self.base_layer, ColumnParallelLinear):
lora_a_out_size = (
lora_config.max_lora_rank
if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size)
)
lora_b_out_size = self.output_size
elif isinstance(self.base_layer, RowParallelLinear):
lora_a_out_size = lora_config.max_lora_rank
lora_b_out_size = (
self.output_size
if not lora_config.fully_sharded_loras
else divide(self.output_size, self.tp_size)
)
else:
raise NotImplementedError
self.lora_a_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_a_out_size,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self.n_slices)
)
self.lora_b_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_b_out_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self.n_slices)
)
self.output_slices = (self.lora_b_stacked[0].shape[2],)
def reset_lora(self, index: int):
for s_index in range(self.n_slices):
self.lora_a_stacked[s_index][index] = 0
self.lora_b_stacked[s_index][index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
# Except for QKVParallelLinearWithLoRA and
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
# store weights in a tuple of size 1. These two layers will
# override this function.
assert isinstance(lora_a, torch.Tensor)
assert isinstance(lora_b, torch.Tensor)
assert (
len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
)
self.reset_lora(index)
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
lora_a, non_blocking=True
)
self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
lora_b, non_blocking=True
)
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
original_shape = output.shape if output.ndim == 3 else None
# In transformers backend, x and output have extra batch dimension like
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
# therefore we need to flatten the batch dimensions.
if x.ndim == 3 and output.ndim == 3:
output = output.flatten(0, 1)
x = x.flatten(0, 1)
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear(
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
)
if not current_platform.can_update_inplace():
output = lora_output
# Reshape the flattened output back to its original shape,
# as some MM encoders cannot handle flattened inputs.
if original_shape is not None:
output = output.reshape(original_shape)
return output
@property
def weight(self) -> torch.Tensor:
# UnquantizedLinear
if hasattr(self.base_layer, "weight"):
return self.base_layer.weight
# Compressed Tensor
elif hasattr(self.base_layer, "weight_packed"):
return self.base_layer.weight_packed
# GPTQ/AWQ
elif hasattr(self.base_layer, "qweight"):
return self.base_layer.qweight
# marlin
elif hasattr(self.base_layer, "B"):
return self.base_layer.B
else:
raise ValueError(f"Unsupported base layer: {self.base_layer}")
@property
def bias(self) -> torch.Tensor | None:
if hasattr(self.base_layer, "bias"):
return self.base_layer.bias
else:
return None
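
A small shape sketch of the stacked buffers created above, using assumed sizes that are not taken from any particular model: each LoRA slot is padded to `max_lora_rank`, so a lower-rank adapter is copied into the top-left corner of its slot and the zero padding contributes nothing to the matmul.

# Shape sketch for the stacked LoRA buffers (assumed values):
# max_loras=4, max_lora_rank=16, input_size=4096, output_size=11008, n_slices=1.
import torch

max_loras, max_rank, in_size, out_size = 4, 16, 4096, 11008
lora_a_stacked = torch.zeros(max_loras, 1, max_rank, in_size)
lora_b_stacked = torch.zeros(max_loras, 1, out_size, max_rank)

# A rank-8 adapter is copied into the corner of slot `index`; the remaining
# rows/columns stay zero, so the padded rank contributes nothing.
index, r = 2, 8
lora_a = torch.randn(r, in_size)
lora_b = torch.randn(out_size, r)
lora_a_stacked[index, 0, :r, :].copy_(lora_a)
lora_b_stacked[index, 0, :, :r].copy_(lora_b)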

658
vllm/lora/layers/column_parallel_linear.py Normal file
@@ -0,0 +1,658 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.utils import divide
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
)
from vllm.platforms import current_platform
from .base_linear import BaseLinearLayerWithLoRA
from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace
def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
"""
`ColumnParallelLinearWithLoRA` and the classes that inherit from it
share the same `apply` logic.
"""
assert (
layer.n_slices
== len(layer.lora_a_stacked)
== len(layer.lora_b_stacked)
== len(layer.output_slices)
)
output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
# Since communication is needed, the buffer is directly initialized as a
# tensor rather than a tuple of tensors.
buffers = torch.zeros(
(layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
shrunk_buffers: torch.Tensor | None = layer.punica_wrapper.add_shrink(
buffers, x, layer.lora_a_stacked, 1.0
)
if not current_platform.can_update_inplace():
buffers = shrunk_buffers
buffers = tensor_model_parallel_all_gather(buffers)
lora_output: torch.Tensor | None = layer.punica_wrapper.add_expand(
output,
buffers,
layer.lora_b_stacked,
layer.output_slices,
offset_start=0,
add_input=True,
)
if not current_platform.can_update_inplace():
output = lora_output
output = output.view(*out_orig_shape)
# now have column partitioned and packed output
return output
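# Illustrative shape walk-through for _mcp_apply (added comment, not in the
# original upstream file): with tp_size=2, max_lora_rank=16,
# fully_sharded_loras=True, n_slices=2 and 8 tokens, `buffers` starts as
# (2, 8, 8) fp32; add_shrink fills it with x @ lora_a^T per slice, the
# all_gather along the last dim restores the full rank -> (2, 8, 16), and
# add_expand adds each slice's expanded LoRA delta into its column range of
# `output` before the view back to the original shape.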
class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
"""
LoRA on top of ColumnParallelLinear layer.
LoRA B is sliced for tensor parallelism.
There are two types for the `base_layer`:
1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`.
2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`.
"""
def __init__(self, base_layer: ColumnParallelLinear) -> None:
super().__init__(base_layer)
# The base_layer type is ColumnParallelLinear or
# MergedColumnParallelLinear; their weight sharding logic
# differs when TP is greater than 1.
self.is_merged_col_linear = type(base_layer) is MergedColumnParallelLinear
self.output_size = self.base_layer.output_size_per_partition
# There is only one LoRA layer
self.n_slices = 1
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
# Applicable to cases where the base_layer is
# MergedColumnParallelLinear.
if self.is_merged_col_linear:
shard_size = self.output_size // 2
offset = lora_b.shape[0] // 2
left_weight = lora_b[
self.tp_rank * shard_size : (self.tp_rank + 1) * shard_size, :
]
right_weight = lora_b[
offset + self.tp_rank * shard_size : offset
+ (self.tp_rank + 1) * shard_size,
:,
]
lora_b = torch.cat([left_weight, right_weight], dim=0)
# Applicable to cases where the base_layer is
# ColumnParallelLinear.
else:
shard_size = self.output_size
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[start_idx:end_idx, :]
return lora_b
def forward(
self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
"""Forward of ColumnParallelLinear
Args:
input_: Tensor whose last dimension is `input_size`.
Returns:
- output
- bias
"""
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
# Matrix multiply.
output_parallel = self.apply(input_, bias)
if self.base_layer.gather_output and self.tp_size > 1:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
if not self.base_layer.return_bias:
return output
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
return output, output_bias
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
if type(source_layer) is ColumnParallelLinear:
return True
if type(source_layer) is MergedColumnParallelLinear:
if len(packed_modules_list) != 1:
return False
# Exclude layers with 3+ output sizes - those are handled by
# MergedColumnParallelLinearVariableSliceWithLoRA since this
# class's slice_lora_b assumes exactly 2 slices.
return not (
hasattr(source_layer, "output_sizes")
and len(source_layer.output_sizes) >= 3
)
return False
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
packed together (e.g. gate_proj + up_proj -> gate_up_proj).
This means we have 2 LoRAs, each applied to one half of the layer.
Both slices must have the same size.
"""
def __init__(
self, base_layer: MergedColumnParallelLinear | QKVParallelLinear
) -> None:
super().__init__(base_layer)
# There are two LoRA layers
# the output_sizes in MergedColumnParallelLinear are not sharded by tp,
# so we need to divide them by tp_size to get the correct slice sizes
output_sizes = self.base_layer.output_sizes
self.output_slices = tuple(
divide(output_size, self.tp_size) for output_size in output_sizes
)
self.n_slices = len(self.output_slices)
self.output_ids = (self.tp_rank,) * self.n_slices
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""
The main reason for overriding this function is to enhance code
maintainability.
"""
self.lora_config = lora_config
lora_a_output_size_per_partition = (
lora_config.max_lora_rank
if not lora_config.fully_sharded_loras
else divide(lora_config.max_lora_rank, self.tp_size)
)
self.lora_a_stacked = tuple(
torch.zeros(
max_loras,
1,
lora_a_output_size_per_partition,
self.input_size,
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self.n_slices)
)
self.lora_b_stacked = tuple(
torch.zeros(
max_loras,
1,
output_size,
lora_config.max_lora_rank,
dtype=lora_config.lora_dtype,
device=self.device,
)
for output_size in self.output_slices
)
def slice_lora_a(
self, lora_a: list[torch.Tensor | None]
) -> list[torch.Tensor | None]:
return lora_a
def slice_lora_b(
self, lora_b: list[torch.Tensor | None]
) -> list[torch.Tensor | None]:
sliced_lora_b = [None] * self.n_slices
for i, (shard_id, shard_size) in enumerate(
zip(self.output_ids, self.output_slices)
):
if (lora_b_i := lora_b[i]) is not None:
sliced_lora_b[i] = lora_b_i[
shard_size * shard_id : shard_size * (shard_id + 1), :
]
return sliced_lora_b
def set_lora(
self,
index: int,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
self.reset_lora(index)
if self.tp_size > 1:
lora_a = self.slice_lora_a(lora_a)
lora_b = self.slice_lora_b(lora_b)
for i in range(self.n_slices):
if (lora_a_i := lora_a[i]) is not None:
self.lora_a_stacked[i][
index, 0, : lora_a_i.shape[0], : lora_a_i.shape[1]
].copy_(lora_a_i, non_blocking=True)
if (lora_b_i := lora_b[i]) is not None:
self.lora_b_stacked[i][
index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1]
].copy_(lora_b_i, non_blocking=True)
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return (
type(source_layer) is MergedColumnParallelLinear
and len(packed_modules_list) == 2
)
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"""
ColumnParallelLinear layer that is specifically designed for
qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
only contain a single LoRA within their qkv_proj layer.
During inference with Tensor Parallel, the weights of lora_b
must be accurately partitioned according to the respective ranks.
Q slice may have different shape than K and V slices (which both have
the same shape).
"""
def __init__(self, base_layer: QKVParallelLinear) -> None:
super().__init__(base_layer)
self.q_proj_total_size = (
self.base_layer.total_num_heads * self.base_layer.head_size
)
self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
self.kv_proj_shard_size = (
self.base_layer.num_kv_heads * self.base_layer.head_size
)
self.kv_proj_total_size = (
self.base_layer.total_num_kv_heads * self.base_layer.head_size
)
# There is only one LoRA layer
self.n_slices = 1
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
self.q_shard_id = self.tp_rank
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
lora_b_q = lora_b[
self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size
* (self.q_shard_id + 1),
:,
]
k_offset = self.q_proj_total_size
lora_b_k = lora_b[
k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset
+ self.kv_proj_shard_size * (self.kv_shard_id + 1),
:,
]
v_offset = k_offset + self.kv_proj_total_size
lora_b_v = lora_b[
v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset
+ self.kv_proj_shard_size * (self.kv_shard_id + 1),
:,
]
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
return lora_b
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1
class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
"""MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
packed together in qkv proj fashion
(q_proj + k_proj + v_proj -> qkv_proj).
This means we have 3 LoRAs, each applied to one slice of the layer.
Q slice may have different shape than K and V slices (which both have
the same shape).
"""
def __init__(self, base_layer: QKVParallelLinear) -> None:
super().__init__(base_layer)
# There are three LoRA layers.
self.n_slices = len(self.base_layer.output_sizes)
self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
self.kv_proj_shard_size = (
self.base_layer.num_kv_heads * self.base_layer.head_size
)
self.q_shard_id = self.tp_rank
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
self.output_slices = (
self.q_proj_shard_size,
self.kv_proj_shard_size,
self.kv_proj_shard_size,
)
self.output_ids = (
self.q_shard_id,
self.kv_shard_id,
self.kv_shard_id,
)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""
The main reason for overriding this function is to handle inconsistent
weight dimensions in qkv lora.
"""
super().create_lora_weights(max_loras, lora_config, model_config)
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3
# These following layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
"""
Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.
Based on S-LoRA, slicing happens along the rank dim.
"""
# For all LoRA layers where the `base_layer` is `ColumnParallelLinear`,
# their `lora_a` and `lora_b` have different sharding patterns. After
# completing the `lora_a` GEMM, a gather operation is performed.
# Therefore, the sharding of `lora_a` only needs to correspond with the
# gather operation.
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
lora_a = lora_a[start_idx : start_idx + shard_size, :]
return lora_a
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLoRA):
"""
Differs from MergedColumnParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(
self, lora_a: list[torch.Tensor | None]
) -> list[torch.Tensor | None]:
# NOTE: lora_a contains 2 subloras, and each sublora could be None.
output_shard_size = self.lora_a_stacked[0].shape[2]
output_start_idx = self.tp_rank * output_shard_size
lora_a = [
lora_a[0][output_start_idx : output_start_idx + output_shard_size, :]
if lora_a[0] is not None
else None,
lora_a[1][output_start_idx : output_start_idx + output_shard_size, :]
if lora_a[1] is not None
else None,
]
return lora_a
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
"""
Differs from QKVParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
shard_size = self.lora_a_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
lora_a = lora_a[start_idx : start_idx + shard_size, :]
return lora_a
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
"""
Differs from MergedQKVParallelLinearWithLoRA by slicing the
LoRA A's also.
Based on S-LoRA, slicing happens along the rank dim.
"""
def slice_lora_a(
self, lora_a: list[torch.Tensor | None]
) -> list[torch.Tensor | None]:
# NOTE: lora_a contains 3 subloras, and each sublora could be None.
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
lora_a = [
lora_a[0][start_idx[0] : start_idx[0] + shard_size[0], :]
if lora_a[0] is not None
else None,
lora_a[1][start_idx[1] : start_idx[1] + shard_size[1], :]
if lora_a[1] is not None
else None,
lora_a[2][start_idx[2] : start_idx[2] + shard_size[2], :]
if lora_a[2] is not None
else None,
]
return lora_a
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
return _mcp_apply(x, bias, self)
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
class MergedColumnParallelLinearVariableSliceWithLoRA(
MergedColumnParallelLinearWithLoRA
):
"""MergedColumnParallelLinear with variable number of slices (3+).
This handles cases where the checkpoint has a single weight for the whole
module (not split into slices), but the layer itself has multiple slices.
"""
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
# Support MergedColumnParallelLinear with 3 or more slices
# (2 slices are handled by MergedColumnParallelLinearWithLoRA)
if type(source_layer) is not MergedColumnParallelLinear:
return False
# If packed_modules_list has 3+ items, use this class
if len(packed_modules_list) >= 3:
return True
# If packed_modules_list has exactly 2 items, let
# MergedColumnParallelLinearWithLoRA handle it
if len(packed_modules_list) == 2:
return False
# If packed_modules_list is empty or has 1 item,
# check the layer's output_sizes.
# This handles cases where the checkpoint has a single weight
# but the layer has multiple slices (3+)
return (
hasattr(source_layer, "output_sizes")
and len(source_layer.output_sizes) >= 3
)
def set_lora(
self,
index: int,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
"""Override to handle single tensor weights
that need to be split into slices."""
self.reset_lora(index)
# Handle case where checkpoint has single tensor weights
# lora_a shape: (rank, input_size) - same for all slices, duplicate it
if isinstance(lora_a, torch.Tensor):
lora_a = [lora_a] * self.n_slices
# lora_b shape: (total_output_size, rank) -
# split along dim 0 based on output_sizes
if isinstance(lora_b, torch.Tensor):
output_sizes = self.base_layer.output_sizes
lora_b_list = []
start_idx = 0
for output_size in output_sizes:
end_idx = start_idx + output_size
lora_b_list.append(lora_b[start_idx:end_idx, :])
start_idx = end_idx
lora_b = lora_b_list
# Now call parent's set_lora which expects lists
super().set_lora(index, lora_a, lora_b)
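
A standalone sketch of the plain `ColumnParallelLinear` slicing rule used by `slice_lora_b` above, with assumed toy sizes: only LoRA B is sliced row-wise per tensor-parallel rank, while LoRA A stays whole because the shrink GEMM consumes the full input.

# Sketch of column-parallel LoRA B slicing (assumed sizes: tp_size=2,
# full output 8192, rank 16). Not vLLM code, just the slicing arithmetic.
import torch

tp_size, full_out, rank = 2, 8192, 16
shard_size = full_out // tp_size
lora_b = torch.randn(full_out, rank)

shards = [lora_b[r * shard_size : (r + 1) * shard_size, :] for r in range(tp_size)]
assert torch.equal(torch.cat(shards, dim=0), lora_b)  # the shards tile the full matrix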

777
vllm/lora/layers/fused_moe.py Normal file
@@ -0,0 +1,777 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm import envs
from vllm.config.lora import LoRAConfig
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.distributed.utils import divide
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import (
_get_config_dtype_str,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
MarlinExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
FusedMoEModularMethod,
)
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
UnfusedOAITritonExperts,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEModularKernel,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoEP,
)
from .utils import _get_lora_device, try_get_optimal_moe_lora_config
class FusedMoEWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: FusedMoE) -> None:
super().__init__()
self.base_layer = base_layer
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.device = _get_lora_device(base_layer)
# For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
# since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1
self._inject_lora_into_fused_moe()
def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
normalized_config = {}
for key, value in config.items():
if key.islower():
if key.startswith("block_"):
normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
else:
normalized_key = key.upper()
else:
normalized_key = key
normalized_config[normalized_key] = value
return normalized_config
def _get_lora_moe_configs(
self,
op_prefix: str,
num_loras: int,
rank: int,
num_slices: int,
M: int,
layer: FusedMoE,
top_k: int,
config_dtype: str,
):
if envs.VLLM_TUNED_CONFIG_FOLDER:
hidden_size = layer.hidden_size
intermediate_size = layer.intermediate_size_per_partition
shrink_config = get_lora_op_configs(
op_type=f"fused_moe_lora_{op_prefix}_shrink",
max_loras=num_loras,
batch=M,
hidden_size=hidden_size,
rank=rank,
num_slices=num_slices,
moe_intermediate_size=intermediate_size,
)
expand_config = get_lora_op_configs(
op_type=f"fused_moe_lora_{op_prefix}_expand",
max_loras=num_loras,
batch=M,
hidden_size=hidden_size, # lora_a_stacked.shape[-1],
rank=rank,
num_slices=num_slices,
moe_intermediate_size=intermediate_size, # lora_b_stacked.shape[-2],
)
else: # fall back to the default config
get_config_func = functools.partial(
try_get_optimal_moe_lora_config,
w1_shape=layer.w13_weight.size(),
w2_shape=layer.w2_weight.size(),
rank=rank,
top_k=top_k,
dtype=config_dtype,
M=M,
block_shape=layer.quant_method.moe_quant_config.block_shape,
)
shrink_config = get_config_func(
op_type=f"fused_moe_lora_{op_prefix}_shrink"
)
expand_config = get_config_func(
op_type=f"fused_moe_lora_{op_prefix}_expand"
)
shrink_config = self._normalize_keys(shrink_config)
expand_config = self._normalize_keys(expand_config)
return shrink_config, expand_config
def _inject_lora_into_fused_moe(self):
moe_state_dict = {}
top_k = self.base_layer.top_k
self.base_layer.ensure_moe_quant_config_init()
quant_config = self.base_layer.quant_method.moe_quant_config
if getattr(self.base_layer.quant_method, "supports_internal_mk", False):
# Use the existing modular kernel from the quant method
m_fused_moe_fn = self.base_layer.quant_method.moe_mk
# Don't let the kernel own shared experts so the runner can
# overlap them with routed experts via a separate CUDA stream.
m_fused_moe_fn.shared_experts = None
else:
# Create a new modular kernel via select_gemm_impl.
# Don't pass shared_experts to the kernel so the runner can
# overlap them with routed experts via a separate CUDA stream.
prepare_finalize = MoEPrepareAndFinalizeNoEP()
m_fused_moe_fn = FusedMoEModularKernel(
prepare_finalize,
self.base_layer.quant_method.select_gemm_impl(
prepare_finalize, self.base_layer
),
)
if quant_config.use_mxfp4_w4a16:
assert isinstance(
m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts)
)
else:
assert isinstance(m_fused_moe_fn.fused_experts, TritonExperts)
def fwd_decorator(layer, func):
def wrapper(*args, **kwargs):
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
moe_state_dict["expert_map"] = kwargs["expert_map"]
moe_state_dict["apply_router_weight_on_input"] = kwargs[
"apply_router_weight_on_input"
]
result = func(*args, **kwargs)
return result
return wrapper
def act_decorator(layer, func):
def wrapper(*args, **kwargs):
_, output, input = args
hidden_states = moe_state_dict["hidden_states"]
topk_weights = moe_state_dict["topk_weights"]
curr_topk_ids = moe_state_dict["topk_ids"]
expert_map = moe_state_dict["expert_map"]
config_dtype = _get_config_dtype_str(
dtype=hidden_states.dtype,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
)
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)
max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w13",
num_loras=self.max_loras,
rank=max_lora_rank,
num_slices=self._w13_slices,
M=M,
layer=layer,
top_k=top_k,
config_dtype=config_dtype,
)
# SPARSITY_FACTOR is a heuristic margin ensuring tokens * top_k
# activates only a small fraction of total experts * loras.
SPARSITY_FACTOR = 8
naive_block_assignment = (
expert_map is None
and num_tokens * top_k * SPARSITY_FACTOR
<= self.base_layer.local_num_experts * self.max_loras
)
# get the block size of m from customized config or default config
(
token_lora_mapping,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
) = self.punica_wrapper.moe_lora_align_block_size(
curr_topk_ids,
num_tokens,
shrink_config["BLOCK_SIZE_M"],
self.base_layer.local_num_experts,
self.max_loras,
self.adapter_enabled,
expert_map,
naive_block_assignment=naive_block_assignment,
)
moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
moe_state_dict["expert_ids_lora"] = expert_ids_lora
moe_state_dict["num_tokens_post_padded_lora"] = (
num_tokens_post_padded_lora
)
moe_state_dict["token_lora_mapping"] = token_lora_mapping
if sorted_token_ids_lora is not None:
expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(
self.max_loras, -1
)
#
self.punica_wrapper.add_lora_fused_moe(
input.view(-1, top_k, input.shape[-1]),
hidden_states,
self.w13_lora_a_stacked,
self.w13_lora_b_stacked,
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
shrink_config, ## pass the shrink config
expand_config, ## pass the expand config
self.adapter_enabled,
fully_sharded=self.fully_sharded,
token_lora_mapping=token_lora_mapping,
)
result = func(*args, **kwargs)
moe_state_dict["intermediate_cache2"] = output
return result
return wrapper
def moe_sum_decorator(layer, func):
def wrapper(*args, **kwargs):
hidden_states = moe_state_dict["hidden_states"]
topk_weights = moe_state_dict["topk_weights"]
config_dtype = _get_config_dtype_str(
dtype=hidden_states.dtype,
use_fp8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
)
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
num_tokens = hidden_states.size(0)
M = min(num_tokens, CHUNK_SIZE)
max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
shrink_config, expand_config = self._get_lora_moe_configs(
op_prefix="w2",
num_loras=self.max_loras,
rank=max_lora_rank,
num_slices=1,
M=M,
layer=layer,
top_k=top_k,
config_dtype=config_dtype,
)
sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
expert_ids_lora = moe_state_dict["expert_ids_lora"]
num_tokens_post_padded_lora = moe_state_dict[
"num_tokens_post_padded_lora"
]
token_lora_mapping = moe_state_dict.get("token_lora_mapping")
if sorted_token_ids_lora is not None:
expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
sorted_token_ids_lora = sorted_token_ids_lora.view(
self.max_loras, -1
)
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
intermediate_cache3 = args[0]
shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)
self.punica_wrapper.add_lora_fused_moe(
intermediate_cache3,
intermediate_cache2,
self.w2_lora_a_stacked,
self.w2_lora_b_stacked,
topk_weights,
sorted_token_ids_lora,
expert_ids_lora,
num_tokens_post_padded_lora,
max_lora_rank,
top_k,
shrink_config, ## pass the shrink config
expand_config, ## pass the expand config
self.adapter_enabled,
True,
fully_sharded=self.fully_sharded,
offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
token_lora_mapping=token_lora_mapping,
)
result = func(*args, **kwargs)
return result
return wrapper
fused_experts = m_fused_moe_fn.fused_experts
m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
fused_experts.activation = act_decorator(
self.base_layer, fused_experts.activation
)
fused_experts.moe_sum = moe_sum_decorator(
self.base_layer, fused_experts.moe_sum
)
# TODO(bnell): find a less intrusive way to handle this.
self.base_layer._replace_quant_method(
FusedMoEModularMethod(self.base_layer.quant_method, m_fused_moe_fn)
)
def _create_lora_a_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
):
self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
lora_config.max_lora_rank
if not self.fully_sharded
else divide(lora_config.max_lora_rank, self.tp_size),
self.base_layer.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self._w13_slices)
)
self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
lora_config.max_lora_rank,
self.base_layer.intermediate_size_per_partition,
),
dtype=lora_config.lora_dtype,
device=self.device,
),
)
def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.intermediate_size_per_partition,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self._w13_slices)
)
self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.hidden_size
if not self.fully_sharded
else divide(self.base_layer.hidden_size, self.tp_size),
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
),
)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
self.max_loras = lora_config.max_loras
self.fully_sharded = lora_config.fully_sharded_loras
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
)
self._create_lora_a_weights(max_loras, lora_config)
self._create_lora_b_weights(max_loras, lora_config)
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
# to create dummy LoRA weights.
# TODO Optimize this section
self.lora_a_stacked = []
self.lora_b_stacked = []
for lora_id in range(max_loras):
for experts_id in range(self.base_layer.local_num_experts):
# For gated MoE: gate_proj (w1), down_proj (w2), up_proj (w3)
# For non-gated MoE: up_proj (w1), down_proj (w2)
self.lora_a_stacked.append(
self.w13_lora_a_stacked[0][lora_id][experts_id]
)
self.lora_a_stacked.append(
self.w2_lora_a_stacked[0][lora_id][experts_id]
)
self.lora_b_stacked.append(
self.w13_lora_b_stacked[0][lora_id][experts_id]
)
self.lora_b_stacked.append(
self.w2_lora_b_stacked[0][lora_id][experts_id]
)
# Only add w3 (up_proj) for gated MoE (_w13_slices == 2)
if self._w13_slices == 2:
self.lora_a_stacked.append(
self.w13_lora_a_stacked[1][lora_id][experts_id]
)
self.lora_b_stacked.append(
self.w13_lora_b_stacked[1][lora_id][experts_id]
)
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
"""
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
"""
if self.tp_size == 1 or not self.fully_sharded:
return w13_lora_a
# w13_lora_a shape (num_experts,rank,input_size)
current_lora_rank = w13_lora_a.shape[1]
assert current_lora_rank % self.tp_size == 0
# Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
shard_size = self.w13_lora_a_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w13_lora_a[:, start_idx:end_idx, :]
def _slice_w13_b(self, w13_lora_b: torch.Tensor):
if self.tp_size == 1:
return w13_lora_b
# w13_lora_b shape (num_experts,output_size,rank)
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w13_lora_b[:, start_idx:end_idx, :]
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
"""
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
"""
if self.tp_size == 1:
return w2_lora_a
# w2_lora_a shape (num_experts,rank,input_size)
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w2_lora_a[:, :, start_idx:end_idx]
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
"""
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
"""
if self.tp_size == 1 or not self.fully_sharded:
return w2_lora_b
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
# w2_lora_b shape (num_experts,output_size,rank)
shard_size = self.w2_lora_b_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
return w2_lora_b[:, start_idx:end_idx, :]
def reset_lora(self, index: int):
"""Resets the lora weights at index back to 0."""
for pos in range(self._w13_slices):
self.w13_lora_a_stacked[pos][index] = 0
self.w13_lora_b_stacked[pos][index] = 0
self.w2_lora_a_stacked[0][index] = 0
self.w2_lora_b_stacked[0][index] = 0
self.adapter_enabled[index] = 0
#
def set_lora(
self,
index: int,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
"""Overwrites lora tensors at index."""
# Make mypy happy
assert isinstance(lora_a, list)
assert isinstance(lora_b, list)
self.reset_lora(index)
self.adapter_enabled[index] = 1
num_experts = self.w13_lora_a_stacked[0].shape[1]
w1_lora_a, w2_lora_a, w3_lora_a = lora_a
w1_lora_b, w2_lora_b, w3_lora_b = lora_b
assert (
num_experts
== w1_lora_a.shape[0]
== w2_lora_a.shape[0]
== w3_lora_a.shape[0]
)
sliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
sliced_w1_lora_b = self._slice_w13_b(w1_lora_b)
sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
self.w13_lora_a_stacked[0][
index, :, : sliced_w1_lora_a.shape[1], : sliced_w1_lora_a.shape[2]
].copy_(sliced_w1_lora_a, non_blocking=True)
self.w13_lora_b_stacked[0][
index, :, : sliced_w1_lora_b.shape[1], : sliced_w1_lora_b.shape[2]
].copy_(sliced_w1_lora_b, non_blocking=True)
# Only copy w3 (up_proj) for gated MoE (_w13_slices == 2)
if self._w13_slices == 2:
sliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
sliced_w3_lora_b = self._slice_w13_b(w3_lora_b)
self.w13_lora_a_stacked[1][
index, :, : sliced_w3_lora_a.shape[1], : sliced_w3_lora_a.shape[2]
].copy_(sliced_w3_lora_a, non_blocking=True)
self.w13_lora_b_stacked[1][
index, :, : sliced_w3_lora_b.shape[1], : sliced_w3_lora_b.shape[2]
].copy_(sliced_w3_lora_b, non_blocking=True)
self.w2_lora_a_stacked[0][
index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
].copy_(sliced_w2_lora_a, non_blocking=True)
self.w2_lora_b_stacked[0][
index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
].copy_(sliced_w2_lora_b, non_blocking=True)
def forward(self, *args, **kwargs):
return self.base_layer.forward(*args, **kwargs)
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)
@property
def _shared_experts(self):
return self.base_layer._shared_experts
@property
def quant_method(self):
return self.base_layer.quant_method
@property
def is_internal_router(self) -> bool:
return self.base_layer.is_internal_router
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
# source_layer is FusedMoE or SharedFusedMoE
return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2
class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
def __init__(self, base_layer):
super().__init__(base_layer)
self._w13_slices = 1
def _create_lora_b_weights(self, max_loras, lora_config):
self.w13_lora_b_stacked: tuple[torch.Tensor] = tuple(
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.intermediate_size_per_partition * 2,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
for _ in range(self._w13_slices)
)
self.w2_lora_b_stacked: tuple[torch.Tensor] = (
torch.zeros(
(
max_loras,
self.base_layer.local_num_experts,
self.base_layer.hidden_size
if not self.fully_sharded
else divide(self.base_layer.hidden_size, self.tp_size),
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
),
)
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
"""Initializes lora matrices."""
assert isinstance(model_config, PretrainedConfig)
self._base_model = model_config.architectures[0]
self.max_loras = lora_config.max_loras
self.fully_sharded = lora_config.fully_sharded_loras
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
)
self._create_lora_a_weights(max_loras, lora_config)
self._create_lora_b_weights(max_loras, lora_config)
def _slice_w13_b(self, w13_lora_b: torch.Tensor):
if self.tp_size == 1:
return w13_lora_b
# w13_lora_b shape (num_experts,output_size,rank)
shard_size = self.base_layer.intermediate_size_per_partition
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
# HACK: Currently, only GPT-OSS is in interleaved order
if self._base_model == "GptOssForCausalLM":
# For models like GPT-OSS, the w1 (gate_proj) and w3 (up_proj) weights are
# stored interleaved, so the corresponding LoRA weights must be sliced accordingly.
w1_lora_b = w13_lora_b[:, ::2, :]
w3_lora_b = w13_lora_b[:, 1::2, :]
sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]
return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
1, 2
)
else:
slice_size = w13_lora_b.shape[1] // 2
w1_lora_b = w13_lora_b[:, :slice_size, :]
w3_lora_b = w13_lora_b[:, slice_size:, :]
sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]
return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)
def set_lora(
self,
index: int,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
"""Overwrites lora tensors at index."""
# Make mypy happy
assert isinstance(lora_a, list)
assert isinstance(lora_b, list)
assert len(lora_a) == len(lora_b) == 2
self.reset_lora(index)
self.adapter_enabled[index] = 1
w13_lora_a, w2_lora_a = lora_a
w13_lora_b, w2_lora_b = lora_b
sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)
sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
self.w13_lora_a_stacked[0][
index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
].copy_(sliced_w13_lora_a, non_blocking=True)
self.w2_lora_a_stacked[0][
index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
].copy_(sliced_w2_lora_a, non_blocking=True)
self.w13_lora_b_stacked[0][
index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
].copy_(sliced_w13_lora_b, non_blocking=True)
self.w2_lora_b_stacked[0][
index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
].copy_(sliced_w2_lora_b, non_blocking=True)
@property
def w13_input_size(self):
"""
Full size
"""
return self.w13_lora_a_stacked[0].shape[-1]
@property
def w13_output_size(self):
"""
Full size
"""
return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size
@property
def w2_input_size(self):
"""
Full size
"""
return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size
@property
def w2_output_size(self):
"""
Full size
"""
return self.base_layer.hidden_size
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
"""Returns True if the layer can be replaced by this LoRA layer."""
# source_layer is FusedMoE or SharedFusedMoE
return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1
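
The `_inject_lora_into_fused_moe` method above relies on wrapping existing kernel methods so that selected keyword arguments are stashed in a shared dict before the original code runs. The snippet below is a made-up, self-contained illustration of that decorator-injection pattern only; `capture_kwargs` and `FakeKernel` are not vLLM names.

# Standalone sketch of the decorator-injection pattern (not vLLM code).
def capture_kwargs(state: dict, keys: tuple, func):
    def wrapper(*args, **kwargs):
        for key in keys:
            state[key] = kwargs.get(key)
        return func(*args, **kwargs)
    return wrapper

class FakeKernel:
    def forward(self, *, hidden_states=None, topk_ids=None):
        return hidden_states

state = {}
kernel = FakeKernel()
# Rebind the bound method with a wrapper that records the kwargs it saw.
kernel.forward = capture_kwargs(state, ("hidden_states", "topk_ids"), kernel.forward)
kernel.forward(hidden_states="h", topk_ids="t")
assert state == {"hidden_states": "h", "topk_ids": "t"}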

206
vllm/lora/layers/logits_processor.py Normal file
@@ -0,0 +1,206 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform
from .base import BaseLayerWithLoRA
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
"""
LoRA wrapper for LogitsProcessor, with extra logic to handle the
application of the LoRA adapter and added LoRA vocabulary.
Args:
base_layer: LogitsProcessor layer
hidden_size: hidden size of the model
dtype: data type of the model
device: device of the model
sharded_to_full_mapping: index mapping from sharded vocab to full vocab
received from base_layer.get_sharded_to_full_mapping(). If None,
no reindexing will be done.
"""
def __init__(
self,
base_layer: LogitsProcessor,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
sharded_to_full_mapping: list[int] | None,
) -> None:
super().__init__()
self.base_layer = base_layer
self.hidden_size = hidden_size
self.dtype = dtype
self.device = device
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.sharded_to_full_mapping = sharded_to_full_mapping
@property
def logits_as_input(self):
return self.base_layer.logits_as_input
@property
def vocab_size(self):
return self.base_layer.vocab_size
@property
def scale(self):
return self.base_layer.scale
@property
def soft_cap(self):
return self.base_layer.soft_cap
@property
def use_all_gather(self):
return self.base_layer.use_all_gather
@property
def org_vocab_size(self):
return self.base_layer.org_vocab_size
@property
def include_gpu_probs_tensor(self):
return self.base_layer.include_gpu_probs_tensor
@property
def should_modify_greedy_probs_inplace(self):
return self.base_layer.should_modify_greedy_probs_inplace
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
# TODO: Verify if this condition can be further relaxed
if self.base_layer.vocab_size <= 32000 or self.base_layer.vocab_size > 258048:
raise ValueError(
"When using LoRA, vocab size must be > 32000 and <= 258048"
)
self.lora_a_stacked = torch.zeros(
(
max_loras,
1,
lora_config.max_lora_rank,
self.hidden_size,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
self.base_layer.vocab_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.device,
)
if self.sharded_to_full_mapping is not None:
self.sharded_to_full_mapping_gpu = torch.tensor(
self.sharded_to_full_mapping, device=self.device, dtype=torch.long
)
else:
self.sharded_to_full_mapping_gpu = None
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
assert isinstance(lora_a, torch.Tensor)
assert isinstance(lora_b, torch.Tensor)
self.reset_lora(index)
self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
lora_a, non_blocking=True
)
self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
lora_b, non_blocking=True
)
def _get_logits(
self,
hidden_states: torch.Tensor,
lm_head: VocabParallelEmbedding,
embedding_bias: torch.Tensor | None = None,
) -> torch.Tensor | None:
# Get the logits for the next tokens.
if hasattr(lm_head, "base_layer"):
actual_lm_head = lm_head.base_layer
else:
actual_lm_head = lm_head
logits = actual_lm_head.quant_method.apply(actual_lm_head, hidden_states)
if embedding_bias is not None:
logits += embedding_bias
# Gather logits for TP
logits = self.base_layer._gather_logits(logits)
if logits is None:
return None
if self.sharded_to_full_mapping_gpu is not None:
# Reindex full logits tensor to ensure 1:1 mapping between
# index and token_id
# Example for:
# org_vocab_size = 4
# added_vocab_size = 2
# pad_to_size = 8
# tp_size = 2
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 4, -1, 2, 3, 5, -1]
# Therefore, the mapping is expected to be:
# [0, 1, 4, 5, 2, 6, 3, 7] so that when we reindex,
# we get:
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 2, 3, 4, 5, -1, -1]
logits = logits[:, self.sharded_to_full_mapping_gpu]
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits(
logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0
)
if not current_platform.can_update_inplace():
logits = lora_output
# Remove paddings in vocab (if any).
logits = logits[:, : self.base_layer.vocab_size]
return logits
def forward(self, *args, **kwargs):
return type(self.base_layer).forward(self, *args, **kwargs)
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
# Special handling for the LogitsProcessor.
return False
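
A standalone numeric check of the reindexing example in `_get_logits` above, using the mapping from the comment (org_vocab_size=4, added_vocab_size=2, padded to 8, tp_size=2). The logit values are chosen to equal the token id each column represents, with -1.0 marking padding columns.

# Gather the sharded logits with the mapping; columns line up with token ids
# 0..5 followed by padding. Illustration only, not vLLM code.
import torch

sharded_logits = torch.tensor([[0.0, 1.0, 4.0, -1.0, 2.0, 3.0, 5.0, -1.0]])
mapping = torch.tensor([0, 1, 4, 5, 2, 6, 3, 7])
full_logits = sharded_logits[:, mapping]
assert full_logits.tolist() == [[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, -1.0, -1.0]]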

70
vllm/lora/layers/replicated_linear.py Normal file
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.model_executor.layers.linear import ReplicatedLinear
from .base_linear import BaseLinearLayerWithLoRA
class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: ReplicatedLinear) -> None:
super().__init__(
base_layer,
)
self.output_size = self.base_layer.output_size
# To ensure interface compatibility, n_slices is always set to 1.
self.n_slices = 1
def forward(
self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
"""Forward of ReplicatedLinearWithLoRA
Args:
input_: Tensor whose last dimension is `input_size`.
Returns:
- output
- bias
"""
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
# Matrix multiply.
output = self.apply(input_, bias)
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
if not self.base_layer.return_bias:
return output
return output, output_bias
# ReplicatedLinear should always be replaced, regardless of the fully
# sharded LoRAs setting, because it is, by definition, copied per GPU.
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is ReplicatedLinear
def slice_lora_a(
self, lora_a: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
"""Slice lora a if splitting for tensor parallelism."""
return lora_a
def slice_lora_b(
self, lora_b: torch.Tensor | list[torch.Tensor | None]
) -> torch.Tensor | list[torch.Tensor | None]:
"""Slice lora b if splitting with tensor parallelism."""
return lora_b

176
vllm/lora/layers/row_parallel_linear.py Normal file
@@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.distributed import (
split_tensor_along_last_dim,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform
from .base_linear import BaseLinearLayerWithLoRA
from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
def __init__(self, base_layer: RowParallelLinear) -> None:
super().__init__(base_layer)
# reset input_size
self.input_size = self.base_layer.input_size_per_partition
self.output_size = self.base_layer.output_size
# There is only one LoRA layer.
self.n_slices = 1
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
shard_size = self.input_size
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_a = lora_a[:, start_idx:end_idx]
return lora_a
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
return lora_b
def forward(
self, input_: torch.Tensor
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
"""Forward of RowParallelLinear
Args:
input_: tensor whose last dimension is `input_size`. If
`input_is_parallel` is set, then the last dimension
is `input_size // tp_size`.
Returns:
- output
- bias
"""
# set up backprop all-reduce.
if self.base_layer.input_is_parallel:
input_parallel = input_
else:
# TODO: simplify code below
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size
)
input_parallel = splitted_input[self.tp_rank].contiguous()
# Matrix multiply.
bias_ = (
None
if (self.tp_rank > 0 or self.base_layer.skip_bias_add)
else self.base_layer.bias
)
output_parallel = self.apply(input_parallel, bias_)
if self.base_layer.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
if not self.base_layer.return_bias:
return output
return output, output_bias
@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is RowParallelLinear
# The following layer is based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
"""
Differs from RowParallelLinearWithLoRA by slicing the
LoRA B's also.
Based on S-LoRA, slicing happens along the output dim.
This yields a combined partial sum from the row parallel base
layer and column partitioned output from the LoRA.
"""
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
shard_size = self.lora_b_stacked[0].shape[2]
start_idx = self.tp_rank * shard_size
end_idx = (self.tp_rank + 1) * shard_size
lora_b = lora_b[start_idx:end_idx, :]
return lora_b
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
x = x.view(-1, x.shape[-1])
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
buffer = torch.zeros(
(self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
dtype=torch.float32,
device=x.device,
)
shrunk_buffer: torch.Tensor | None = self.punica_wrapper.add_shrink(
buffer, x, self.lora_a_stacked, 1.0
)
if not current_platform.can_update_inplace():
buffer = shrunk_buffer
if self.tp_size > 1:
buffer = tensor_model_parallel_all_reduce(buffer)
# following S-LoRA, allows the fusing of all_gather and all_reduce
# by adding the column partitioned lora output to a slice of output
# tensor, which is a partial sum due to row parallel. All that
# remains is a standard all_reduce. User should be aware though that
# the output is not the same as a normal row_parallel, it should be
# reduced before being used
        # NOTE: offsets are based on the rank.
shard_size = self.lora_b_stacked[0].shape[2]
offset_start = self.tp_rank * shard_size
lora_output: torch.Tensor | None = self.punica_wrapper.add_expand(
output,
buffer,
self.lora_b_stacked,
self.output_slices,
offset_start=offset_start,
add_input=True,
)
if not current_platform.can_update_inplace():
output = lora_output
output = output.view(*out_orig_shape)
return output
@classmethod
@_fully_sharded_can_replace
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
# specifying kwargs so they can be easily accessed in decorator
return super().can_replace_layer(
source_layer=source_layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
decorate=False,
)
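
The sharded variant above follows the S-LoRA scheme: each rank shrinks its input shard with its slice of LoRA A, the low-rank buffer is all-reduced, and each rank then expands with its output slice of LoRA B into its own slice of the still-partial row-parallel output, so a single final all-reduce finishes both the base layer and the LoRA. A single-process sketch of that arithmetic (sizes are illustrative assumptions and collectives are simulated by plain sums):

```python
import torch

# Single-process sketch of the S-LoRA sharded row-parallel scheme; toy sizes.
tp_size, d_in, d_out, rank, n_tok = 2, 8, 6, 4, 3
torch.manual_seed(0)

x = torch.randn(n_tok, d_in)        # full activations (before input sharding)
w = torch.randn(d_out, d_in)        # base row-parallel weight (full copy here)
lora_a = torch.randn(rank, d_in)
lora_b = torch.randn(d_out, rank)

# Reference result computed on a single device.
ref = x @ w.T + x @ lora_a.T @ lora_b.T

shard_in, shard_out = d_in // tp_size, d_out // tp_size

# 1) shrink: every rank multiplies its input shard by its slice of lora_a;
#    the all-reduce over the low-rank buffer is simulated with sum().
buffer = sum(
    x[:, r * shard_in:(r + 1) * shard_in]
    @ lora_a[:, r * shard_in:(r + 1) * shard_in].T
    for r in range(tp_size)
)

# 2) expand: every rank adds its column slice of the LoRA output to its own
#    slice of the (still partial-sum) row-parallel base output.
partials = []
for r in range(tp_size):
    x_r = x[:, r * shard_in:(r + 1) * shard_in]
    w_r = w[:, r * shard_in:(r + 1) * shard_in]
    out_r = x_r @ w_r.T                              # partial base output
    b_r = lora_b[r * shard_out:(r + 1) * shard_out]  # slice_lora_b
    out_r[:, r * shard_out:(r + 1) * shard_out] += buffer @ b_r.T
    partials.append(out_r)

# 3) one final all-reduce (here: a sum) yields the combined result.
out = sum(partials)
assert torch.allclose(out, ref, atol=1e-5)
```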

112
vllm/lora/layers/utils.py Normal file
View File

@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from enum import Enum
import torch
import torch.nn as nn
from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config
from vllm.utils.math_utils import next_power_of_2
class LoRAMappingType(Enum):
LANGUAGE = 1
TOWER = 2
CONNECTOR = 3
@dataclass
class LoRAMapping:
index_mapping: tuple[int, ...]
prompt_mapping: tuple[int, ...]
is_prefill: bool = False
type: LoRAMappingType = LoRAMappingType.LANGUAGE
def __post_init__(self):
self.index_mapping = tuple(self.index_mapping)
self.prompt_mapping = tuple(self.prompt_mapping)
def _get_lora_device(base_layer: nn.Module) -> torch.device:
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
"""Returns the device for where to place the LoRA tensors."""
    # UnquantizedLinear
if hasattr(base_layer, "weight"):
return base_layer.weight.device
# Compressed Tensor
elif hasattr(base_layer, "weight_packed"):
return base_layer.weight_packed.device
# GPTQ/AWQ
elif hasattr(base_layer, "qweight"):
return base_layer.qweight.device
# MoE layer
elif hasattr(base_layer, "w2_weight"):
return base_layer.w2_weight.device
# MoE Compressed Tensor
elif hasattr(base_layer, "w2_weight_packed"):
return base_layer.w2_weight_packed.device
# MoE GPTQ/AWQ/GGUF
elif hasattr(base_layer, "w2_qweight"):
return base_layer.w2_qweight.device
else:
raise ValueError(f"Unsupported base layer: {base_layer}")
def _not_fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of not using fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
condition = not kwargs["lora_config"].fully_sharded_loras if decorate else True
return can_replace(*args, **kwargs) and condition
return dec
def _fully_sharded_can_replace(can_replace):
"""
decorator which adds the condition of fully sharded loras
intended to wrap can_replace_layer()
"""
def dec(*args, **kwargs):
return (
can_replace(*args, **kwargs) and kwargs["lora_config"].fully_sharded_loras
)
return dec
def try_get_optimal_moe_lora_config(
op_type: str,
w1_shape: tuple[int, ...],
w2_shape: tuple[int, ...],
rank: int,
top_k: int,
dtype: str | None,
M: int,
block_shape: list[int] | None = None,
) -> dict[str, int | None]:
config = try_get_optimal_moe_config(
w1_shape, w2_shape, top_k, dtype, M, block_shape
).copy()
if op_type in [
"fused_moe_lora_w13_shrink",
"fused_moe_lora_w2_shrink",
]:
config["BLOCK_SIZE_N"] = min(
config.get("BLOCK_SIZE_N", 64), next_power_of_2(rank)
)
elif op_type in [
"fused_moe_lora_w13_expand",
"fused_moe_lora_w2_expand",
]:
config["BLOCK_SIZE_K"] = max(
16, min(config.get("BLOCK_SIZE_K", 32), next_power_of_2(rank))
)
return config
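
The two decorators above simply layer a `fully_sharded_loras` condition on top of a layer's `can_replace_layer` check, with `decorate=False` as an escape hatch used by the fully sharded subclasses. A hypothetical illustration using a minimal stand-in for `LoRAConfig`:

```python
from types import SimpleNamespace

def base_can_replace(*args, **kwargs):
    # stands in for a layer's undecorated can_replace_layer()
    return True

guarded = _not_fully_sharded_can_replace(base_can_replace)

plain_cfg = SimpleNamespace(fully_sharded_loras=False)
sharded_cfg = SimpleNamespace(fully_sharded_loras=True)

assert guarded(lora_config=plain_cfg) is True      # replaced only when not fully sharded
assert guarded(lora_config=sharded_cfg) is False
assert guarded(lora_config=sharded_cfg, decorate=False) is True  # bypass the extra condition
```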

View File

@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.config.lora import LoRAConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform
from .base import BaseLayerWithLoRA
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
super().__init__()
self.base_layer = base_layer
self.embeddings_slice: tuple[int, int] | None
self.embeddings_weights: torch.Tensor | None
def create_lora_weights(
self,
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> None:
if self.base_layer.num_added_embeddings_per_partition > 0:
# We can start adding lora weights
self.embeddings_weights = self.base_layer.weight.data[
self.base_layer.num_org_embeddings_per_partition : self.base_layer.num_org_embeddings_per_partition # noqa: E501
+ self.base_layer.num_added_embeddings_per_partition
]
self.embeddings_slice = (
self.base_layer.shard_indices.added_vocab_start_index
- self.base_layer.org_vocab_size,
self.base_layer.shard_indices.added_vocab_end_index
- self.base_layer.org_vocab_size,
)
self.base_layer.weight.data[
self.base_layer.num_org_embeddings_per_partition :
].fill_(0)
else:
self.embeddings_slice = None
self.embeddings_weights = None
self.lora_a_stacked = torch.zeros(
(
max_loras,
self.base_layer.org_vocab_size,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_b_stacked = torch.zeros(
(
max_loras,
1,
self.base_layer.embedding_dim,
lora_config.max_lora_rank,
),
dtype=lora_config.lora_dtype,
device=self.base_layer.weight.device,
)
self.lora_a_stacked_2d = self.lora_a_stacked.view(
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
self.lora_a_stacked.shape[2],
)
def reset_lora(self, index: int):
self.lora_a_stacked[index] = 0
self.lora_b_stacked[index] = 0
def set_lora(
self,
index: int,
lora_a: torch.Tensor | list[torch.Tensor],
lora_b: torch.Tensor | list[torch.Tensor],
):
assert isinstance(lora_a, torch.Tensor)
assert isinstance(lora_b, torch.Tensor)
self.reset_lora(index)
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
# so we need transpose here
self.lora_a_stacked[index, : lora_a.shape[1], : lora_a.shape[0]].copy_(
lora_a.T, non_blocking=True
)
self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
lora_b, non_blocking=True
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NB: Don't use torch.narrow here. torch.narrow triggers some
# Dynamic Shape specialization in torch.compile
num_tokens = x.shape[0]
indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
full_lora_a_embeddings = F.embedding(
x + indices_1,
self.lora_a_stacked_2d,
)
full_output = self.base_layer.forward(x)
full_output_org = full_output
if full_output.ndim == 3:
full_output = full_output.view(
full_output.shape[0] * full_output.shape[1], -1
)
if full_lora_a_embeddings.ndim == 3:
full_lora_a_embeddings = full_lora_a_embeddings.view(
full_lora_a_embeddings.shape[0] * full_lora_a_embeddings.shape[1],
-1,
)
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_embedding(
full_output, full_lora_a_embeddings, self.lora_b_stacked, add_input=True
)
if not current_platform.can_update_inplace():
full_output = lora_output
return full_output.view_as(full_output_org)
@classmethod
def can_replace_layer(
cls,
source_layer: nn.Module,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> bool:
return type(source_layer) is VocabParallelEmbedding
@property
def weight(self):
return self.base_layer.weight
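
The LoRA-A lookup above works by flattening the per-adapter embedding tables into one 2D table and shifting each token id by `adapter_id * vocab_size`, so a single `F.embedding` call serves a batch that mixes adapters. A standalone sketch of that indexing (shapes and offsets are illustrative assumptions, not the real punica metadata):

```python
import torch
import torch.nn.functional as F

# Sketch of the flattened embedding-LoRA lookup; toy shapes only.
max_loras, vocab, rank = 2, 10, 4
torch.manual_seed(0)
lora_a_stacked = torch.randn(max_loras, vocab, rank)
lora_a_2d = lora_a_stacked.view(max_loras * vocab, rank)   # lora_a_stacked_2d

token_ids = torch.tensor([3, 7, 1])
adapter_ids = torch.tensor([0, 1, 1])       # which LoRA serves each token
offsets = adapter_ids * vocab               # analogous to _embeddings_indices[1]

looked_up = F.embedding(token_ids + offsets, lora_a_2d)
expected = torch.stack(
    [lora_a_stacked[a, t] for a, t in zip(adapter_ids, token_ids)]
)
assert torch.allclose(looked_up, expected)
```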

243
vllm/lora/lora_model.py Normal file
View File

@@ -0,0 +1,243 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import safetensors
import torch
from vllm.logger import init_logger
from vllm.lora.lora_weights import LoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.utils import (
get_lora_id,
is_base_embeddding_weights,
parse_fine_tuned_lora_name,
)
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.model_executor.models.utils import WeightsMapper
from vllm.utils.platform_utils import is_pin_memory_available
logger = init_logger(__name__)
class LoRAModel:
"""A LoRA fine-tuned model."""
def __init__(
self,
lora_model_id: int,
rank: int,
loras: dict[str, LoRALayerWeights],
) -> None:
"""
Args:
lora_model_id: The integer id for the lora model.
rank: lora rank.
loras: module name -> weights for lora-replaced layers.
"""
self.id = lora_model_id
assert lora_model_id > 0, (
f"a valid lora id should be greater than 0, got {self.id}"
)
self.rank = rank
self.loras: dict[str, LoRALayerWeights] = loras
def clone(self, lora_model_id: int) -> "LoRAModel":
"""Return a copy of the object with different ids.
Will share the underlying tensors."""
return self.__class__(
lora_model_id,
rank=self.rank,
loras=self.loras.copy(),
)
def get_lora(self, module_name: str) -> LoRALayerWeights | None:
"""Get LoRA for a given module by name"""
return self.loras.get(module_name, None)
def check_lora_name(self, lora_name: str) -> bool:
return lora_name in self.loras
@staticmethod
def _should_skip_module(module_name: str, skip_prefixes: list[str]) -> bool:
"""Check if a module should be skipped based on skip prefixes"""
for prefix in skip_prefixes:
if f".{prefix}" in module_name or module_name.startswith(prefix):
return True
return False
@classmethod
def from_lora_tensors(
cls,
lora_model_id: int,
tensors: dict[str, torch.Tensor],
peft_helper: PEFTHelper,
device: str = "cuda",
dtype: torch.dtype | None = None,
model_vocab_size: int | None = None,
weights_mapper: WeightsMapper | None = None,
skip_prefixes: list[str] | None = None,
) -> "LoRAModel":
"""Create a LoRAModel from a dictionary of tensors."""
pin_memory = str(device) == "cpu" and is_pin_memory_available()
loras: dict[str, LoRALayerWeights] = {}
for tensor_name, tensor in tensors.items():
if is_base_embeddding_weights(tensor_name):
continue
# Skip modules based on model-defined prefixes (e.g., MTP layers)
if skip_prefixes and cls._should_skip_module(tensor_name, skip_prefixes):
continue
module_name, is_lora_a = parse_fine_tuned_lora_name(
tensor_name, weights_mapper
)
if module_name not in loras:
loras[module_name] = LoRALayerWeights.from_config(
module_name, peft_helper
)
if is_lora_a:
if (
"lora_embedding_A" in tensor_name
and model_vocab_size is not None
and model_vocab_size != tensor.shape[1]
):
raise RuntimeError(
f"The embedding LoRA size({tensor.shape[1]}) must be consistent"
f" with the base model's vocabulary size({model_vocab_size})."
)
loras[module_name].lora_a = tensor.to(device=device, dtype=dtype)
if pin_memory:
loras[module_name].lora_a = loras[module_name].lora_a.pin_memory()
else:
loras[module_name].lora_b = tensor.to(device=device, dtype=dtype)
if pin_memory:
loras[module_name].lora_b = loras[module_name].lora_b.pin_memory()
return cls(lora_model_id, peft_helper.r, loras)
@classmethod
def from_local_checkpoint(
cls,
lora_dir: str,
expected_lora_modules: set[str],
peft_helper: PEFTHelper,
*,
lora_model_id: int | None = None,
device: str = "cuda",
dtype: torch.dtype | None = None,
model_vocab_size: int | None = None,
weights_mapper: WeightsMapper | None = None,
tensorizer_config_dict: dict | None = None,
skip_prefixes: list[str] | None = None,
) -> "LoRAModel":
"""Create a LoRAModel from a local checkpoint.
Args:
lora_dir: The local path that has lora data.
expected_lora_modules: Name of modules that are expected to be
replaced by lora.
peft_helper: Loaded lora configuration information.
lora_model_id: LoRA model id. If not given, automatically set by
a global counter.
device: Device where the lora model is loaded.
dtype: dtype of the lora model weights.
skip_prefixes: List of module name prefixes to skip during loading.
Models can define this to skip modules not used in inference
(e.g., MTP layers). Format: ["mtp."]
Returns:
Loaded LoRA Model.
"""
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
lora_pt_file_path = os.path.join(lora_dir, "adapter_model.pt")
tensors: dict[str, torch.Tensor] = {}
unexpected_modules: list[list[str] | str] = []
def check_unexpected_modules(modules: dict):
for lora_module in modules.keys(): # noqa
if is_base_embeddding_weights(lora_module):
continue
# Handle PEFT file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
if "base_layer" in lora_module:
continue
# Skip modules based on model-defined prefixes
if skip_prefixes and cls._should_skip_module(
lora_module, skip_prefixes
):
continue
module_name, _ = parse_fine_tuned_lora_name(lora_module, weights_mapper)
# Case for expert lora weights
if ".experts" in module_name:
expert_idx = module_name.find(".experts")
expert_suffix = module_name[expert_idx + 1 :]
if expert_suffix not in expected_lora_modules:
unexpected_modules.append(module_name)
elif module_name.rsplit(".", 1)[-1] not in expected_lora_modules:
unexpected_modules.append(module_name)
if unexpected_modules:
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct"
)
if tensorizer_config_dict:
from tensorizer import TensorDeserializer
tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
lora_tensor_path = os.path.join(
tensorizer_config.tensorizer_dir, "adapter_model.tensors"
)
tensorizer_args = tensorizer_config._construct_tensorizer_args()
tensors = TensorDeserializer(
lora_tensor_path,
dtype=tensorizer_config.dtype,
**tensorizer_args.deserialization_kwargs,
)
check_unexpected_modules(tensors)
elif os.path.isfile(lora_tensor_path):
# Find unexpected modules.
# Use safetensor key as a source of truth to find expected modules.
            # In PEFT, if you have target_modules A, B, C and C does not exist
            # in the model, it won't error; the model will be trained with A and B
            # LoRA-ified. C won't exist in the safetensors file, but it will still
            # appear in the target_modules of adapter_config.json.
unexpected_modules = []
with safetensors.safe_open(lora_tensor_path, framework="pt") as f: # type: ignore
# Load tensors if there are only expected modules.
check_unexpected_modules(f)
for module in f.keys(): # noqa
tensors[module] = f.get_tensor(module)
elif os.path.isfile(lora_bin_file_path) or os.path.isfile(lora_pt_file_path):
lora_file_path = (
lora_bin_file_path
if os.path.isfile(lora_bin_file_path)
else lora_pt_file_path
)
tensors = torch.load(lora_file_path, map_location=device, weights_only=True)
check_unexpected_modules(tensors)
else:
raise ValueError(f"{lora_dir} doesn't contain tensors")
return cls.from_lora_tensors(
lora_model_id=get_lora_id() if lora_model_id is None else lora_model_id,
tensors=tensors,
peft_helper=peft_helper,
device=device,
dtype=dtype,
model_vocab_size=model_vocab_size,
weights_mapper=weights_mapper,
skip_prefixes=skip_prefixes,
)
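
A hypothetical usage sketch of the checkpoint loader above. The adapter path, the expected module names, and the `PEFTHelper.from_local_dir(...)` constructor (defined in the accompanying peft_helper module, not shown here) are assumptions for illustration and may differ in your setup:

```python
from vllm.lora.lora_model import LoRAModel
from vllm.lora.peft_helper import PEFTHelper

# Assumed: the directory contains adapter_config.json plus
# adapter_model.safetensors, and PEFTHelper can be built from it.
adapter_dir = "/path/to/adapter"
peft_helper = PEFTHelper.from_local_dir(adapter_dir, max_position_embeddings=4096)

lora = LoRAModel.from_local_checkpoint(
    adapter_dir,
    expected_lora_modules={"q_proj", "k_proj", "v_proj", "o_proj"},
    peft_helper=peft_helper,
    lora_model_id=1,
    device="cpu",
)
print(lora.rank, sorted(lora.loras)[:3])
```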

249
vllm/lora/lora_weights.py Normal file
View File

@@ -0,0 +1,249 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence as GenericSequence
import torch
import torch.types
from vllm.lora.peft_helper import PEFTHelper
from vllm.utils.platform_utils import is_pin_memory_available
class LoRALayerWeights:
"""LoRA weights for a layer composed of two low rank matrixes."""
def __init__(
self,
module_name: str,
rank: int,
lora_alpha: int,
lora_a: torch.Tensor,
lora_b: torch.Tensor,
scaling: float | None = None,
) -> None:
self.module_name = module_name
self.rank = rank
self.lora_alpha = lora_alpha
self.lora_a = lora_a
self.lora_b = lora_b
if scaling is None:
self.scaling = self.lora_alpha / self.rank
else:
self.scaling = scaling
def optimize(self) -> "LoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
if self.scaling == 1:
return self
self.lora_b *= self.scaling
self.scaling = 1
return self
@property
def input_dim(self) -> int:
return self.lora_a.shape[1]
@property
def output_dim(self) -> int:
return self.lora_b.shape[0]
@property
def is_packed(self) -> bool:
return False
@classmethod
def from_config(
cls,
module_name: str,
peft_helper: PEFTHelper,
) -> "LoRALayerWeights":
# lora_a and lora_b are set to None for config-based construction
return cls(
module_name,
peft_helper.r,
peft_helper.lora_alpha,
None,
None,
peft_helper.vllm_lora_scaling_factor,
)
@classmethod
def create_dummy_lora_weights(
cls,
module_name: str,
input_dim: int,
output_dim: int,
rank: int,
dtype: torch.dtype,
device: torch.types.Device,
) -> "LoRALayerWeights":
pin_memory = str(device) == "cpu" and is_pin_memory_available()
lora_a = torch.zeros(
[rank, input_dim], dtype=dtype, device=device, pin_memory=pin_memory
)
lora_b = torch.zeros(
[output_dim, rank], dtype=dtype, device=device, pin_memory=pin_memory
)
return cls(
module_name,
rank=rank,
lora_alpha=1,
lora_a=lora_a,
lora_b=lora_b,
)
class PackedLoRALayerWeights(LoRALayerWeights):
"""LoRA used for packed layers (eg. qkv_proj)."""
def __init__(
self,
module_name: str,
rank: int,
lora_alphas: list[int | None],
lora_a: list[torch.Tensor | None],
lora_b: list[torch.Tensor | None],
scaling: list[float] | None = None,
) -> None:
super().__init__(
module_name=module_name,
rank=rank,
lora_alpha=0,
lora_a=lora_a,
lora_b=lora_b,
scaling=scaling, # type: ignore
)
self.lora_alphas = lora_alphas
if scaling is None:
self.scaling = [ # type: ignore
lora_alpha / self.rank # type: ignore # noqa
for lora_alpha in self.lora_alphas
]
@classmethod
def pack(
cls, loras: GenericSequence["LoRALayerWeights | None"]
) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
"""
first_lora = next(lora for lora in loras if lora is not None)
for lora in loras:
if lora is None:
continue
lora.optimize()
rank = first_lora.rank
module_name = first_lora.module_name
obj = cls(
module_name,
rank,
[lora.lora_alpha if lora is not None else None for lora in loras],
[lora.lora_a if lora is not None else None for lora in loras],
[lora.lora_b if lora is not None else None for lora in loras],
scaling=[
1 if lora is not None else None # type: ignore
for lora in loras
],
)
return obj
@classmethod
def pack_moe(
cls,
loras: GenericSequence["LoRALayerWeights | None"],
module_name: str,
is_non_gated_moe: bool = False,
) -> "PackedLoRALayerWeights":
"""Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
"""
first_lora = next(lora for lora in loras if lora is not None)
assert first_lora is not None
rank = first_lora.rank
lora_alpha = first_lora.lora_alpha
assert len(loras) % 3 == 0
w1_lora_a_lst = []
w2_lora_a_lst = []
w3_lora_a_lst = []
w1_lora_b_lst = []
w2_lora_b_lst = []
w3_lora_b_lst = []
# TODO: Consider the case where some experts don't have LoRA added.
for eid in range(len(loras) // 3):
w1_lora = loras[eid * 3]
w2_lora = loras[eid * 3 + 1]
w3_lora = loras[eid * 3 + 2]
# For non-gated MoE, w3 is not used, so we use w1's LoRA weights
# This is determined by checking the expert mapping (get_expert_mapping)
# which indicates when ckpt_up_proj_name is empty.
if w3_lora is None and is_non_gated_moe:
w3_lora = w1_lora
assert w1_lora is not None
assert w2_lora is not None
assert w3_lora is not None
w1_lora_a_lst.append(w1_lora.lora_a)
w2_lora_a_lst.append(w2_lora.lora_a)
w3_lora_a_lst.append(w3_lora.lora_a)
w1_lora_b_lst.append(w1_lora.lora_b)
w2_lora_b_lst.append(w2_lora.lora_b)
w3_lora_b_lst.append(w3_lora.lora_b)
w1_lora_a = torch.stack(w1_lora_a_lst, dim=0) # (num_experts,rank,input_size)
w2_lora_a = torch.stack(w2_lora_a_lst, dim=0)
w1_lora_b = torch.stack(w1_lora_b_lst, dim=0) # (num_experts,output_size,rank)
w2_lora_b = torch.stack(w2_lora_b_lst, dim=0)
# All w1, w2, w3 have the same scaling factor.
scaling = lora_alpha / rank
last_scaling = scaling
if is_non_gated_moe:
# For non-gated MoE, reuse w1 tensors for w3 to avoid memory waste
# w3_lora_a_lst and w3_lora_b_lst are not relevant in this case
w3_lora_a = w1_lora_a
w3_lora_b = w1_lora_b
# For non-gated MoE, avoid double-scaling by setting w3's scaling to 1.
last_scaling = 1.0
else:
w3_lora_a = torch.stack(w3_lora_a_lst, dim=0)
w3_lora_b = torch.stack(w3_lora_b_lst, dim=0)
obj = cls(
module_name,
rank,
[lora_alpha, lora_alpha, lora_alpha],
[w1_lora_a, w2_lora_a, w3_lora_a],
[w1_lora_b, w2_lora_b, w3_lora_b],
scaling=[scaling, scaling, last_scaling],
)
return obj
def optimize(self) -> "PackedLoRALayerWeights":
"""Optimize the LoRA by merging the scaling into lora_b."""
for i in range(len(self.lora_b)):
if self.scaling[i] == 1 or self.lora_b[i] is None: # type: ignore
continue
self.lora_b[i] *= self.scaling[i] # type: ignore
self.scaling[i] = 1 # type: ignore
return self
@property
def input_dim(self) -> int:
raise NotImplementedError()
@property
def output_dim(self) -> int:
raise NotImplementedError()
@property
def is_packed(self) -> bool:
return True
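
`optimize()` folds the `lora_alpha / rank` scaling into `lora_b` so runtime kernels can skip the extra multiply. A quick numeric check using the `LoRALayerWeights` class defined above (sizes are arbitrary):

```python
import torch

# Check that optimize() preserves the effective LoRA delta W = B @ A * scaling.
rank, alpha, d_in, d_out = 4, 8, 16, 12
torch.manual_seed(0)
lora_a = torch.randn(rank, d_in)
lora_b = torch.randn(d_out, rank)

w = LoRALayerWeights("layers.0.mlp.down_proj", rank, alpha,
                     lora_a.clone(), lora_b.clone())
delta_before = w.lora_b @ w.lora_a * w.scaling   # scaling = alpha / rank = 2.0
w.optimize()                                     # folds scaling into lora_b
delta_after = w.lora_b @ w.lora_a * w.scaling    # scaling is now 1
assert torch.allclose(delta_before, delta_after)
```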

905
vllm/lora/model_manager.py Normal file
View File

@@ -0,0 +1,905 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Callable
from typing import TypeVar
import regex as re
import torch
from torch import nn
from vllm.config import VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import (
BaseLayerWithLoRA,
FusedMoE3DWithLoRA,
LoRAMapping,
LoRAMappingType,
)
from vllm.lora.lora_model import LoRAModel
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.punica_wrapper import PunicaWrapperBase, get_punica_wrapper
from vllm.lora.utils import (
from_layer,
from_layer_logits_processor,
get_supported_lora_modules,
is_moe_model,
process_packed_modules_mapping,
replace_submodule,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.encoder_budget import MultiModalBudget
from vllm.utils.cache import LRUCache
from vllm.utils.platform_utils import is_pin_memory_available
logger = init_logger(__name__)
T = TypeVar("T")
DEFAULT_LANGUAGE_WRAPPER_KEY = "language_model"
class AdapterLRUCache(LRUCache[int, T]):
def __init__(self, capacity: int, deactivate_fn: Callable[[int], object]):
super().__init__(capacity)
self.deactivate_fn = deactivate_fn
def _on_remove(self, key: int, value: T | None):
logger.debug("Removing adapter int id: %d", key)
self.deactivate_fn(key)
return super()._on_remove(key, value)
class LoRAModelManager:
"""A manager that manages multiple LoRA-fine-tuned models."""
def __init__(
self,
model: SupportsLoRA,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
vllm_config: VllmConfig | None = None,
):
"""Create a LoRAModelManager and adapter for a given model.
Args:
model: the model to be adapted.
max_num_seqs: the maximum number of sequences model can run in a
single batch.
max_num_batched_tokens: the maximum number of tokens model can run
in a single batch.
vocab_size: the vocab size of the model.
lora_config: the LoRA configuration.
"""
self.model: SupportsLoRA = model
self.supported_lora_modules = get_supported_lora_modules(self.model)
assert self.supported_lora_modules, (
f"No supported LoRA modules found in {self.model.__class__.__name__}."
)
self._registered_adapters: dict[int, LoRAModel] = {}
# Dict instead of a set for compatibility with LRUCache.
self._active_adapters: dict[int, None] = {}
self.adapter_type = "LoRA"
self.lora_config = lora_config
self.device = device
self.max_num_seqs = max_num_seqs
assert self.capacity >= self.lora_slots
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
self.lora_index_to_id: list[int | None] = [None] * self.lora_slots
self.vocab_size = vocab_size
self.packed_modules_mapping = process_packed_modules_mapping(self.model)
self.is_pooling_model = is_pooling_model(self.model)
self.packed_modules: dict[str, list[str]] = {}
self.modules: dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a set for compatibility with LRUCache.
self._last_mapping: LoRAMapping | None = None
is_moe = is_moe_model(self.model)
self._is_3d_moe_model = is_moe and self.model.is_3d_moe_weight
self._is_non_gated_moe = is_moe and self.model.is_non_gated_moe
self._init_punica_wrapper(max_num_batched_tokens, vllm_config)
self._create_lora_modules()
self.model.lora_manager = self
def _init_punica_wrapper(
self, max_num_batched_tokens: int, vllm_config: VllmConfig
) -> None:
# Used to indicate whether the model is a multimodal model
self.supports_mm: bool = (
supports_multimodal(self.model)
# In case the model only supports LoRA for
# text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping")
)
self.punica_wrapper_mapping: dict[str, PunicaWrapperBase] = {}
if self.supports_mm:
self._maybe_init_mm(vllm_config, max_num_batched_tokens)
else:
llm_punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
lora_config=self.lora_config,
)
self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY] = (
llm_punica_wrapper
)
def _maybe_init_mm(
self,
vllm_config: VllmConfig,
max_num_batched_tokens: int,
) -> None:
mm_registry = MULTIMODAL_REGISTRY
self.supports_tower_connector_lora = False
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
# Only one language model can be included in the model.
assert len(self.mm_mapping.language_model) == 1
# Language model punica wrapper
llm_punica_wrapper = get_punica_wrapper(
max_num_batched_tokens,
max_batches=self.max_num_seqs,
device=self.device,
lora_config=self.lora_config,
)
lm_prefix = self.mm_mapping.language_model[0]
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
if self.lora_config.enable_tower_connector_lora:
self.supports_tower_connector_lora = self.supports_mm and hasattr(
self.model, "get_num_mm_encoder_tokens"
)
if not self.supports_tower_connector_lora:
return
logger.warning(
"LoRA for the tower and connector of multimodal models is "
"experimental and may contain bugs. Please report any related issues on "
"GitHub if you encounter them."
)
mm_budget = MultiModalBudget(vllm_config, mm_registry)
limit_per_prompt = max(mm_budget.mm_max_items_per_prompt.values())
num_encoder_tokens = self.model.get_num_mm_encoder_tokens(
mm_budget.get_encoder_budget()
)
# Tower wrappers
tower_punica_wrapper = get_punica_wrapper(
num_encoder_tokens,
max_batches=self.max_num_seqs * limit_per_prompt,
device=self.device,
lora_config=self.lora_config,
)
for prefix in self.mm_mapping.tower_model:
self.punica_wrapper_mapping[prefix] = tower_punica_wrapper
# Use wrapper for connector if present.
if self.mm_mapping.connector:
if hasattr(self.model, "get_num_mm_connector_tokens"):
connector_tokens = self.model.get_num_mm_connector_tokens(
num_encoder_tokens
)
connector_punica_wrapper = get_punica_wrapper(
connector_tokens,
max_batches=self.max_num_seqs * limit_per_prompt,
device=self.device,
lora_config=self.lora_config,
)
for prefix in self.mm_mapping.connector:
self.punica_wrapper_mapping[prefix] = connector_punica_wrapper
else:
logger.warning_once(
"Connector LoRA support disabled: model does not implement "
"get_num_mm_connector_tokens(). This method is required to "
"determine the connector's token budget for LoRA operations."
)
def __len__(self) -> int:
return len(self._registered_adapters)
@property
def capacity(self) -> int:
return self.lora_config.max_cpu_loras
@property
def lora_slots(self) -> int:
return self.lora_config.max_loras
@property
def adapter_slots(self) -> int:
return self.lora_slots
def activate_adapter(
self,
lora_id: int,
) -> bool:
"""Move LoRA into a GPU buffer to be used in the forward pass."""
if lora_id in self._active_adapters:
return False
first_free_slot = next(
(
(i, lora_id)
for i, lora_id in enumerate(self.lora_index_to_id)
if lora_id is None
),
None,
)
if first_free_slot is None:
raise ValueError("No free lora slots")
index, _ = first_free_slot
self._active_adapters[lora_id] = None
lora_model = self._registered_adapters[lora_id]
logger.debug(
"Activating LoRA. int id: %d, slot index: %d", lora_model.id, index
)
self.lora_index_to_id[index] = lora_model.id
for module_name, module in self.modules.items():
module_lora = self._get_lora_layer_weights(lora_model, module_name)
if not module_lora:
module.reset_lora(index)
continue
module.set_lora(
index,
module_lora.lora_a,
module_lora.lora_b,
)
return True
def _deactivate_adapter(self, lora_id: int):
try:
index = self.lora_index_to_id.index(lora_id)
self.lora_index_to_id[index] = None
except ValueError:
pass
def _add_adapter(self, lora: LoRAModel):
self._create_merged_loras_inplace(lora)
self._registered_adapters[lora.id] = lora
def pin_adapter(self, lora_id: int) -> bool:
"""Pin a LoRAModel in the manager cache."""
raise NotImplementedError(
"Pinning is not supported in LoRAModelManager. "
"Use LRUCacheLoRAModelManager for pinning"
) # type: ignore
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
# Default to the main language model wrapper
if not (self.supports_mm and self.supports_tower_connector_lora):
target_prefix = (
self.mm_mapping.language_model[0]
if self.supports_mm
else DEFAULT_LANGUAGE_WRAPPER_KEY
)
elif mapping.type == LoRAMappingType.TOWER and self.mm_mapping.tower_model:
target_prefix = self.mm_mapping.tower_model[0]
elif mapping.type == LoRAMappingType.CONNECTOR and self.mm_mapping.connector:
target_prefix = self.mm_mapping.connector[0]
else:
target_prefix = self.mm_mapping.language_model[0]
punica_wrapper = self._get_punica_wrapper(target_prefix)
assert punica_wrapper is not None
punica_wrapper.update_metadata(
mapping,
self.lora_index_to_id,
self.lora_slots + 1,
self.vocab_size,
)
def remove_all_adapters(self):
"""Remove all LoRAModels from the manager."""
self._registered_adapters.clear()
self.lora_index_to_id = [None] * self.lora_slots
self._active_adapters.clear()
def _create_lora_modules(self):
def _parent_module(module_name: str) -> str:
# module name is a dot separated name.
# for example:
# - given an input 'x.y.z' return 'x.y'
# - given an input 'x' return ''
return module_name.rpartition(".")[0]
for module_name, module in self.model.named_modules(remove_duplicate=False):
if isinstance(module, PPMissingLayer):
continue
if not self._match_target_modules(module_name):
continue
punica_wrapper = self._get_punica_wrapper(module_name)
if punica_wrapper is None:
logger.warning(
"Regarding %s, vLLM currently only supports adding LoRA to"
" language model, %s will be ignored.",
self.model.__class__.__name__,
module_name,
)
continue
# TODO: Remove this restriction
# peft error when generating LoRA adapter with "gate" module:
# "Target module NemotronHTopkRouter() is not supported."
# Working LoRA adapter was created using peft with:
# LoraConfig(target_modules="all-linear", ...)
if self._is_non_gated_moe and module_name.endswith("mixer.gate"):
logger.debug_once(
"LoRA is not supported for non-gated MoE gate module."
" %s will be ignored.",
module_name,
scope="local",
)
continue
parts = module_name.split(".")[-1]
packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
if isinstance(module, FusedMoE):
# packed_moduled_lst is used here to just determine whether to
# instantiate FusedMoE3DWithLoRA or FusedMoEWithLoRA, and the
# difference between these two LoRA layers is whether the
# LoRA weights of w1 and w3 have already been fused on disk.
packed_moduled_lst = ["w13"] if self._is_3d_moe_model else ["w1", "w3"]
new_module = replace_submodule(
self.model,
module_name,
from_layer(
module,
self.lora_slots,
self.lora_config,
packed_moduled_lst,
self.model.config,
),
)
# (yard1): TODO make this more robust
if "lm_head" in module_name:
logits_processor_module_name = "logits_processor"
parent_module = _parent_module(module_name)
if parent_module:
logits_processor_module_name = (
f"{parent_module}.{logits_processor_module_name}"
)
logits_processor_module = self.model.get_submodule(
logits_processor_module_name
)
new_module = replace_submodule(
self.model,
logits_processor_module_name,
from_layer_logits_processor(
logits_processor_module,
module,
self.lora_slots,
self.lora_config,
self.model.config,
),
)
# In some models, especially multimodal ones, layers with the same
# name may have different types, such as nn.Linear and
# ReplicatedLinear. The nn.Linear layers cannot be replaced with
# LoRA layers, leading to assertion error. The following check
# aims to prevent this error
if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA):
continue
self.register_module(module_name, new_module)
self._register_packed_modules(module_name)
# All lora layers share the same punica_wrapper based on reference.
new_module.set_mapping(punica_wrapper)
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
assert isinstance(module, BaseLayerWithLoRA), (
f"Module {module_name} must be a BaseLayerWithLoRA instance, "
f"got {type(module)}"
)
self.modules[module_name] = module
@staticmethod
def _pad_lora_pairs_to_triplets(
loras: list[LoRALayerWeights | None],
) -> list[LoRALayerWeights | None]:
"""Pad LoRA weight pairs to triplets for non-gated MoE.
For non-gated MoE, each expert has 2 entries (w1, w2) that need to be
padded to triplets (w1, w2, None) to match pack_moe expectations.
"""
assert len(loras) % 2 == 0, "Expected pairs of LoRA weights for non-gated MoE."
padded: list[LoRALayerWeights | None] = []
for i in range(0, len(loras), 2):
padded.extend(loras[i : i + 2])
padded.append(None)
return padded
def create_dummy_lora(
self,
lora_id: int,
rank: int,
embedding_modules: dict[str, str] | None = None,
) -> LoRAModel:
"""Create zero-initialized LoRAModel for warmup."""
model = LoRAModel(lora_id, rank, {})
for module_name, module in self.model.named_modules():
if (
not self._match_target_modules(module_name)
or not isinstance(module, BaseLayerWithLoRA)
or self._get_punica_wrapper(module_name) is None
):
continue
parts = module_name.split(".")
if module_name not in self.packed_modules:
assert embedding_modules is not None
if parts[-1] in embedding_modules:
# Special-case lm_head: wrapped by LogitsProcessorWithLoRA.
# LoRA input dim is hidden_size, output dim is vocab size.
# LogitsProcessorWithLoRA handles extra vocab size directly.
if parts[-1] == "lm_head":
input_dim = module.lora_a_stacked[0].shape[-1]
output_dim = module.lora_b_stacked[0].shape[-2]
else:
input_dim = (
module.base_layer.org_vocab_size
if hasattr(module.base_layer, "org_vocab_size")
else module.base_layer.weight.shape[1]
)
output_dim = (
module.base_layer.embedding_dim
if hasattr(module.base_layer, "embedding_dim")
else module.base_layer.weight.shape[0]
)
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
input_dim,
output_dim,
rank,
module.lora_a_stacked[0].dtype,
"cpu",
)
model.loras[module_name] = lora
elif module.__class__.__name__ == "FusedMoE3DWithLoRA":
# Case for 3D moe model
# w2
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
module.w2_input_size,
module.w2_output_size,
rank * module.w2_lora_a_stacked[0].shape[1], # rank*num_experts
module.w2_lora_a_stacked[0].dtype,
"cpu",
)
model.loras[module_name] = lora
# w13
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
module.w13_input_size,
module.w13_output_size,
rank
* module.w13_lora_a_stacked[0].shape[1], # rank*num_experts
module.w13_lora_a_stacked[0].dtype,
"cpu",
)
model.loras[module_name + ".base_layer"] = lora
else:
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name,
module.lora_a_stacked[0].shape[-1],
module.lora_b_stacked[0].shape[-2],
rank,
module.lora_a_stacked[0].dtype,
"cpu",
)
model.loras[module_name] = lora
else:
parts = module_name.split(".")
replacements = self.packed_modules_mapping[parts[-1]]
subloras: list[LoRALayerWeights | None] = []
for i, r in enumerate(replacements):
lora = LoRALayerWeights.create_dummy_lora_weights(
module_name + "." + r,
module.lora_a_stacked[i].shape[-1],
module.lora_b_stacked[i].shape[-2],
rank,
module.lora_a_stacked[i].dtype,
"cpu",
)
subloras.append(lora)
if module.__class__.__name__ == "FusedMoEWithLoRA":
# For non-gated MoE, pad subloras to 3 elements per expert
# to match pack_moe expectations (w1, w2, None for w3)
if self._is_non_gated_moe and len(subloras) > 0:
subloras = self._pad_lora_pairs_to_triplets(subloras)
lora = PackedLoRALayerWeights.pack_moe(
subloras, module_name, is_non_gated_moe=self._is_non_gated_moe
)
else:
lora = PackedLoRALayerWeights.pack(subloras)
model.loras[module_name] = lora
return model
def _match_target_modules(self, module_name: str):
return any(
re.match(
r".*\.{target_module}$".format(target_module=target_module), module_name
)
or target_module == module_name
for target_module in self.supported_lora_modules
)
def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
"""
Determine whether this module supports LoRA and which wrapper to use.
"""
# For language model (early return)
if not self.supports_mm:
return self.punica_wrapper_mapping[DEFAULT_LANGUAGE_WRAPPER_KEY]
# For multimodal model
# NOTE Sort by prefix length (descending) to match the longest prefix first
# e.g., 'visual.merger' should match 'visual.merger' instead of 'visual.'
for prefix in sorted(self.punica_wrapper_mapping.keys(), key=len, reverse=True):
if module_name.startswith(prefix):
return self.punica_wrapper_mapping[prefix]
return None
def _register_packed_modules(self, module_full_name: str) -> None:
parts = module_full_name.split(".")
module_name = parts[-1]
replacements = self.packed_modules_mapping.get(module_name, [])
# When replacements is less than or equal to 1, it indicates that this
# module is not a packed module.
if len(replacements) <= 1:
return
prefix = ".".join(parts[:-1])
self.packed_modules[module_full_name] = [
prefix + "." + r if prefix else r for r in replacements
]
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
for module_name, new_module_names in self.packed_modules.items():
replacement_loras: list[LoRALayerWeights | None] = []
replaced_module: set[str] = set()
has_replacement = False
for r in new_module_names:
lora = self._get_lora_layer_weights(lora_model, r)
replacement_loras.append(lora)
if lora:
has_replacement = True
replaced_module.add(r)
if not has_replacement:
continue
for i in range(len(replacement_loras)):
if replacement_loras[i]:
continue
replacement_loras[i] = None
# HACK Temporary solution for the pool model.
if self.is_pooling_model and not lora_model.check_lora_name(module_name):
replaced_module_name = module_name.replace("model.", "")
                if lora_model.check_lora_name(replaced_module_name):
module_name = replaced_module_name
if module_name.endswith(".experts"):
if self._is_non_gated_moe and len(replacement_loras) > 0:
replacement_loras = self._pad_lora_pairs_to_triplets(
replacement_loras
)
lora_model.loras[module_name] = PackedLoRALayerWeights.pack_moe(
replacement_loras,
module_name,
is_non_gated_moe=self._is_non_gated_moe,
)
else:
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras
)
# Remove the modules that have been replaced.
for module in replaced_module:
lora_model.loras.pop(module, None)
for lora in lora_model.loras.values():
lora.optimize()
for module_name, module in self.modules.items():
if isinstance(module, FusedMoE3DWithLoRA):
self._stack_moe_lora_weights(lora_model, module, module_name)
first_lora: LoRALayerWeights = next(iter(lora_model.loras.values()))
assert first_lora.lora_a is not None
if isinstance(first_lora.lora_a, list):
            lora_device = next(iter(first_lora.lora_a)).device
else:
lora_device = first_lora.lora_a.device
# Execute pin_memory after LoRA weight merging, mainly because:
# 1. Some MoE models have a large number of LoRA weights. If we
        # perform pin_memory immediately after loading weights, the
# overhead is significant.
# 2. The weight packing above (e.g., pack_moe) may invalidate the
# pin_memory allocation, so we execute it after packing.
pin_memory = str(lora_device) == "cpu" and is_pin_memory_available()
if pin_memory:
for lora in lora_model.loras.values():
if isinstance(lora.lora_a, list):
for index in range(len(lora.lora_a)):
if lora.lora_a[index] is None:
continue
lora.lora_a[index] = lora.lora_a[index].pin_memory()
lora.lora_b[index] = lora.lora_b[index].pin_memory()
else:
lora.lora_a = lora.lora_a.pin_memory()
lora.lora_b = lora.lora_b.pin_memory()
def _stack_moe_lora_weights(
self, lora_model: LoRAModel, module: FusedMoE3DWithLoRA, module_name: str
):
module_lora = self._get_lora_layer_weights(lora_model, module_name)
# Note (gnovack) - If MOE lora weights are not split into
# num_experts chunks, we split them here
if module_lora and torch.is_tensor(module_lora.lora_a):
# Handle PEFT file format where experts.base_layer is the
# gate_up_proj and experts is the down_proj
gate_up_proj_lora = self._get_lora_layer_weights(
lora_model, module_name + ".base_layer"
)
down_proj_lora = module_lora
# FIXME Edge case where LoRA is not added to gate_up_proj
# or down_proj
assert gate_up_proj_lora is not None
assert down_proj_lora is not None
if self._is_3d_moe_model:
num_experts = module.w13_lora_a_stacked[0].shape[1]
# (num_experts,rank,input_size)
gate_up_proj_lora.lora_a = gate_up_proj_lora.lora_a.reshape(
num_experts, -1, gate_up_proj_lora.lora_a.shape[-1]
)
down_proj_lora.lora_a = down_proj_lora.lora_a.reshape(
num_experts, -1, down_proj_lora.lora_a.shape[-1]
)
# (output_size,rank,num_experts)
gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.reshape(
gate_up_proj_lora.lora_b.shape[0], -1, num_experts
)
down_proj_lora.lora_b = down_proj_lora.lora_b.reshape(
down_proj_lora.lora_b.shape[0], -1, num_experts
)
# (num_experts,output_size,rank)
gate_up_proj_lora.lora_b = gate_up_proj_lora.lora_b.permute(
2, 0, 1
).contiguous()
down_proj_lora.lora_b = down_proj_lora.lora_b.permute(
2, 0, 1
).contiguous()
module_lora.lora_a = [
gate_up_proj_lora.lora_a,
down_proj_lora.lora_a,
]
module_lora.lora_b = [
gate_up_proj_lora.lora_b,
down_proj_lora.lora_b,
]
else:
# Some 3D MoE models haven't added the `is_3d_moe_weight`
# attribute yet, so fallback here
num_experts = module_lora.lora_a.shape[0] // module_lora.rank
gate_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
up_proj_a = gate_up_proj_lora.lora_a.chunk(num_experts, dim=0)
gate_proj_b = gate_up_proj_lora.lora_b[::2, ...].chunk(
num_experts, dim=-1
)
up_proj_b = gate_up_proj_lora.lora_b[1::2, ...].chunk(
num_experts, dim=-1
)
down_proj_a = down_proj_lora.lora_a.chunk(num_experts, dim=0)
down_proj_b = down_proj_lora.lora_b.chunk(num_experts, dim=-1)
lora_a = []
lora_b = []
for i in range(num_experts):
lora_a.append(gate_proj_a[i])
lora_a.append(down_proj_a[i])
lora_a.append(up_proj_a[i])
lora_b.append(gate_proj_b[i])
lora_b.append(down_proj_b[i])
lora_b.append(up_proj_b[i])
module_lora.lora_a = lora_a
module_lora.lora_b = lora_b
def _get_lora_layer_weights(
self, lora_model: LoRAModel, module_name: str
) -> LoRALayerWeights | None:
org_module_name = module_name
if self.is_pooling_model and not lora_model.check_lora_name(module_name):
# If it's a pool model, and the layer name is not found,
# remove the prefix 'model.' and search again.
module_name = module_name.replace("model.", "")
if lora_model.check_lora_name(module_name):
org_module_name = module_name
logger.info_once(
"For the pool model, successfully loaded the LoRA weights "
"after removing the prefix 'model.'."
)
return lora_model.get_lora(org_module_name)
def deactivate_adapter(self, adapter_id: int) -> bool:
if adapter_id not in self._active_adapters:
return False
self._deactivate_adapter(adapter_id)
self._active_adapters.pop(adapter_id, None)
return True
def add_adapter(self, adapter: LoRAModel) -> bool:
logger.debug("Adding lora. Model id: %d, int id: %d", adapter.id, adapter.id)
if adapter.id in self._registered_adapters:
return False
if len(self._registered_adapters) >= self.capacity:
raise RuntimeError("No free adapter slots.")
self._add_adapter(adapter)
return True
def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
if self._last_mapping != mapping:
self._set_adapter_mapping(mapping)
self._last_mapping = mapping
def remove_adapter(self, adapter_id: int) -> bool:
self.deactivate_adapter(adapter_id)
if adapter_id not in self._registered_adapters:
return False
self._registered_adapters.pop(adapter_id, None)
return True
def list_adapters(self) -> dict[int, LoRAModel]:
return dict(self._registered_adapters)
def get_adapter(self, adapter_id: int) -> LoRAModel | None:
return self._registered_adapters.get(adapter_id)
class LoRALRUCache(AdapterLRUCache[LoRAModel]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]):
super().__init__(capacity, deactivate_lora_fn)
class LRUCacheLoRAModelManager(LoRAModelManager):
"""A model manager that manages multiple LoRAs with LRU cache."""
def __init__(
self,
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
device: torch.device,
vllm_config: VllmConfig | None = None,
):
super().__init__(
model,
max_num_seqs,
max_num_batched_tokens,
vocab_size,
lora_config,
device,
vllm_config,
)
self._registered_adapters: LoRALRUCache = LoRALRUCache(
self.capacity, self.deactivate_adapter
)
self._active_adapters: LoRALRUCache = LoRALRUCache(
self.lora_slots, self._deactivate_adapter
)
def list_adapters(self) -> dict[int, LoRAModel]:
"""List all registered LoRAModels."""
return dict(self._registered_adapters.cache)
def add_adapter(self, lora: LoRAModel) -> bool:
"""Add a LoRAModel to the manager."""
logger.debug("Adding lora. Model id: %d, int id: %d", lora.id, lora.id)
if lora.id not in self._registered_adapters:
self._add_adapter(lora)
was_added = True
else:
# We always touch to update the LRU cache order
self._registered_adapters.touch(lora.id)
was_added = False
return was_added
def activate_adapter(
self,
lora_id: int,
) -> bool:
if (
lora_id not in self._active_adapters
and len(self._active_adapters) >= self.lora_slots
):
self._active_adapters.remove_oldest()
result = super().activate_adapter(lora_id)
# We always touch to update the LRU cache order
self._active_adapters.touch(lora_id)
return result
def remove_oldest_adapter(self) -> bool:
if len(self._registered_adapters) > 0:
self._registered_adapters.remove_oldest()
return True
return False
def pin_adapter(self, lora_id: int) -> bool:
"""Pin a LoRAModel in the manager cache."""
self._pin_lora_in_cpu_cache(lora_id)
self._pin_lora_in_gpu_cache(lora_id)
return True
def _pin_lora_in_cpu_cache(self, lora_id: int):
try:
self._registered_adapters.pin(lora_id)
except ValueError as err:
raise ValueError(
f"Pinning failed. LoRA {lora_id} is not registered."
) from err
def _pin_lora_in_gpu_cache(self, lora_id: int):
if lora_id not in self._active_adapters:
# move lora to gpu if not already active
self.activate_adapter(lora_id)
self._active_adapters.pin(lora_id)
def create_lora_manager(
model: nn.Module,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
vllm_config: VllmConfig,
device: torch.device,
lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
**kwargs,
) -> LoRAModelManager:
"""Create a LoRA adapter for a given model."""
if not isinstance(model, SupportsLoRA):
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
lora_manager = lora_manager_cls(
model=model,
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
vocab_size=vocab_size,
lora_config=lora_config,
vllm_config=vllm_config,
device=device,
**kwargs,
)
return lora_manager
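
For multimodal models, `_get_punica_wrapper` above picks a wrapper by the longest matching module-name prefix, so a connector prefix such as `visual.merger` wins over the broader tower prefix `visual`. A standalone sketch of that lookup (the prefixes and module names are illustrative assumptions):

```python
# Longest-prefix wrapper selection, mirroring the sorted-by-length loop above.
punica_wrapper_mapping = {
    "language_model": "llm_wrapper",
    "visual": "tower_wrapper",
    "visual.merger": "connector_wrapper",
}

def pick_wrapper(module_name: str):
    for prefix in sorted(punica_wrapper_mapping, key=len, reverse=True):
        if module_name.startswith(prefix):
            return punica_wrapper_mapping[prefix]
    return None  # module does not get LoRA

assert pick_wrapper("visual.merger.mlp.0") == "connector_wrapper"
assert pick_wrapper("visual.blocks.0.attn.qkv") == "tower_wrapper"
assert pick_wrapper("language_model.layers.0.mlp.gate_up_proj") == "llm_wrapper"
assert pick_wrapper("lm_head") is None
```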

View File

View File

@@ -0,0 +1,20 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.ops.torch_ops.lora_ops import (
bgmv_expand, # noqa: F401
bgmv_expand_slice,
bgmv_shrink,
sgmv_expand,
sgmv_expand_slice,
sgmv_shrink,
)
__all__ = [
"bgmv_expand",
"bgmv_expand_slice",
"bgmv_shrink",
"sgmv_expand",
"sgmv_expand_slice",
"sgmv_shrink",
]

View File

@@ -0,0 +1,128 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
def sgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
add_inputs: bool = False,
):
exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices, add_inputs)
def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
):
selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)
inputs = inputs.to(dtype=output_tensor.dtype)
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
limit = output_tensor.shape[0]
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
limit = 1
# LoRA adapter and model may add different amounts of padding to output
common_len = min(outputs.shape[1], output_tensor.shape[1])
if add_inputs:
output_tensor[:, :common_len] += outputs[:limit, :common_len]
else:
output_tensor[:, :common_len] = outputs[:limit, :common_len]
def sgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
scaling: float,
):
exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices, scaling)
def bgmv_shrink(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
):
selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)
inputs = inputs.to(dtype=output_tensor.dtype)
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
output_tensor[:, : outputs.shape[1]] = scaling * outputs[:]
def sgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
b_seq_start_loc: torch.Tensor,
seq_len_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
batches: int,
max_seq_length: int,
token_nums: int,
slice_offset: int,
slice_size: int,
add_inputs: bool = False,
):
exploded_indices = torch.repeat_interleave(lora_indices_tensor, seq_len_tensor)
bgmv_expand_slice(
inputs,
lora_b_weights,
output_tensor,
exploded_indices,
slice_offset,
slice_size,
add_inputs,
)
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
):
selected_loras = lora_b_weights[lora_indices_tensor].to(dtype=output_tensor.dtype)
inputs = inputs.to(dtype=output_tensor.dtype)
if len(selected_loras.shape) == 4:
selected_loras = selected_loras.squeeze(dim=1)
outputs = torch.einsum("bi, boi -> bo", inputs, selected_loras)
if add_inputs:
output_tensor[:, slice_offset : slice_offset + slice_size] += outputs[:]
else:
output_tensor[:, slice_offset : slice_offset + slice_size] = outputs[:]
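
These reference ops gather each token's adapter via `lora_indices_tensor` and contract over the hidden or rank dimension with `einsum`; the sgmv variants just expand per-sequence indices into per-token ones and fall through to the bgmv path. A minimal end-to-end sketch of a shrink-then-expand LoRA pass with these functions (shapes follow the `[max_loras, 1, out, in]` stacking convention used here; sizes are illustrative):

```python
import torch

# Apply the reference torch ops above to a mixed batch of two adapters.
num_tokens, hidden, rank, out_dim, max_loras = 4, 8, 4, 6, 2
torch.manual_seed(0)

x = torch.randn(num_tokens, hidden)
lora_a = torch.randn(max_loras, 1, rank, hidden)
lora_b = torch.randn(max_loras, 1, out_dim, rank)
token_lora = torch.tensor([0, 0, 1, 1])          # adapter id per token

buffer = torch.zeros(num_tokens, rank)
bgmv_shrink(x, lora_a, buffer, token_lora, scaling=0.5)

output = torch.zeros(num_tokens, out_dim)
bgmv_expand(buffer, lora_b, output, token_lora, add_inputs=False)

# Cross-check against a direct per-token computation.
ref = torch.stack(
    [0.5 * (lora_b[i, 0] @ (lora_a[i, 0] @ x[t]))
     for t, i in enumerate(token_lora.tolist())]
)
assert torch.allclose(output, ref, atol=1e-5)
```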

View File

@@ -0,0 +1,60 @@
# Multi-LoRA Tuning
**Note**: The LoRA configuration folder should be specified by exporting `VLLM_TUNED_CONFIG_FOLDER=/path/to/configs`.
Without this, the shrink/expand kernels will use default configurations.
## Tuning Process
Multi-LoRA shrink/expand Triton kernel tuning follows a methodology similar to
[Triton MoE tuning](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_moe.py).
1. Define the search space. Here is an example search space:
```python
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [32, 64, 128, 256]
block_k_range = [32, 64, 128, 256]
num_warps_range = [4, 8]
num_stage_range = [2, 3, 4, 5]
num_ctas_range = [1]
split_k_range = [4, 8, 16, 32, 64]
```
2. Get all hidden_state sizes and num_slices that the target model uses for a specific TP size.
For example, you can obtain this information by adding a few print statements to
[add_lora_linear](https://github.com/vllm-project/vllm/blob/main/vllm/lora/punica_wrapper/punica_gpu.py#L181):
```python
print(f"x_shape: {x.view(-1, x.shape[-1]).shape}")
print(f"num_slices: {len(output_slices)}")
for i in range(len(output_slices)):
print(f"a{i} shape: {lora_a_stacked[i].shape}")
print(f"b{i} shape: {lora_b_stacked[i].shape}")
print("y_shape", y.shape)
```
3. Benchmark the shrink/expand kernel runtime for each configuration generated from the pre-defined search space,
performing a grid search to find the optimal kernel configuration.
vLLM's [benchmark_lora.py](https://github.com/vllm-project/vllm/blob/main/benchmarks/kernels/benchmark_lora.py)
can be used to search for configurations for different shapes.
## Config Files
### File Naming
| Kernel Type | File Name Template | Example |
|---------------------------|--------------------------------------------|---------------------------------------------|
| shrink | `{gpu_name}_SHRINK.json` | `NVIDIA_H200_SHRINK.json` |
| expand | `{gpu_name}_EXPAND_{add_input}.json` | `NVIDIA_H200_EXPAND_TRUE.json` |
| fused_moe_lora_w13_shrink | `{gpu_name}_FUSED_MOE_LORA_W13_SHRINK.json` | `NVIDIA_H200_FUSED_MOE_LORA_W13_SHRINK.json` |
| fused_moe_lora_w13_expand | `{gpu_name}_FUSED_MOE_LORA_W13_EXPAND.json` | `NVIDIA_H200_FUSED_MOE_LORA_W13_EXPAND.json` |
| fused_moe_lora_w2_shrink | `{gpu_name}_FUSED_MOE_LORA_W2_SHRINK.json` | `NVIDIA_H200_FUSED_MOE_LORA_W2_SHRINK.json` |
| fused_moe_lora_w2_expand | `{gpu_name}_FUSED_MOE_LORA_W2_EXPAND.json` | `NVIDIA_H200_FUSED_MOE_LORA_W2_EXPAND.json` |
The `gpu_name` can be detected automatically by calling `torch.cuda.get_device_name()`; when building the file name, spaces and hyphens in the device name are replaced with underscores.
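As a quick sketch (mirroring the normalization in `load_lora_op_config` from `vllm/lora/ops/triton_ops/utils.py`), the expected file names can be derived as:
```python
import torch

# Spaces and hyphens in the device name are replaced with underscores,
# matching the normalization used when the config file is looked up.
gpu_name = torch.cuda.get_device_name().replace(" ", "_").replace("-", "_")
shrink_fname = f"{gpu_name}_SHRINK.json"       # e.g. NVIDIA_H200_SHRINK.json
expand_fname = f"{gpu_name}_EXPAND_TRUE.json"  # expand kernel with add_inputs=True
```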
### JSON Structure
Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n][i]`,
where `i` is an optional dimension in the `fused_moe_lora` configuration, representing the intermediate size of the MoE layer.
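For illustration only, here is a sketch of a shrink config file with a single entry; the shape keys (`max_loras=1`, `num_slices=3`, `m=16`, `k=4096`, `n=16`) are made up, and the leaf values are simply copied from the default shrink configuration rather than tuned results:
```python
import json

# Nested as config_data[max_loras][num_slices][m][k][n]; all keys are strings.
config_data = {
    "1": {                      # max_loras
        "3": {                  # num_slices
            "16": {             # m (number of tokens)
                "4096": {       # k (hidden size for shrink)
                    "16": {     # n (rank for shrink)
                        "block_m": 32,
                        "block_n": 16,
                        "block_k": 256,
                        "split_k": 64,
                        "num_warps": 4,
                        "num_ctas": 1,
                        "num_stages": 2,
                        "group_size_m": 8,
                        "max_nreg": None,
                    }
                }
            }
        }
    }
}

with open("NVIDIA_H200_SHRINK.json", "w") as f:
    json.dump(config_data, f, indent=2)
```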

View File

@@ -0,0 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.ops.triton_ops.fused_moe_lora_fp8_op import (
fused_moe_lora_expand_fp8,
fused_moe_lora_fp8,
fused_moe_lora_shrink_fp8,
)
from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
fused_moe_lora,
fused_moe_lora_expand,
fused_moe_lora_shrink,
)
from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand
from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink
__all__ = [
"lora_expand",
"lora_shrink",
"LoRAKernelMeta",
"fused_moe_lora",
"fused_moe_lora_shrink",
"fused_moe_lora_expand",
"fused_moe_lora_fp8",
"fused_moe_lora_shrink_fp8",
"fused_moe_lora_expand_fp8",
]

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,845 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
from .utils import supports_pdl
@triton.jit
def _get_lora_id(
lora_ids,
token_lora_mapping_ptr,
lora_idx,
pid_m,
top_k_num,
naive_block_assignment: tl.constexpr,
):
"""Returns lora_id"""
if naive_block_assignment:
token_idx = pid_m // top_k_num
return tl.load(token_lora_mapping_ptr + token_idx)
else:
return tl.load(lora_ids + lora_idx)
@triton.jit
def _get_expert_id(
expert_ids_ptr,
lora_id,
pid_m,
stride_el,
max_loras,
naive_block_assignment: tl.constexpr,
):
"""Returns expert_id"""
if naive_block_assignment:
return tl.load(expert_ids_ptr + pid_m)
else:
ind = lora_id * stride_el + pid_m
return tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
@triton.jit
def _get_token_offs(
sorted_token_ids_ptr,
lora_id,
pid_m,
offs,
stride_tl,
max_loras,
num_valid_tokens,
naive_block_assignment: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
):
"""Returns token offsets"""
if naive_block_assignment:
return tl.where(offs == 0, pid_m, num_valid_tokens)
else:
offs_token_id = pid_m * BLOCK_SIZE_M + offs
token_ind = stride_tl * lora_id + offs_token_id
return tl.load(
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
)
_LORA_PTR_DICT: dict[tuple[int, ...], torch.Tensor] = {}
def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):
"""
    `_LORA_PTR_DICT` collects the required information during `profile_run`.
    After this, it remains constant and subsequent usage is through a LUT.
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
if (ptr_tensor := _LORA_PTR_DICT.get(key)) is not None:
return ptr_tensor
tensor_ptrs = []
for lora_weight in lora_weights:
tensor_ptrs.append(lora_weight.data_ptr())
ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
_LORA_PTR_DICT[key] = ptr_tensor
return _LORA_PTR_DICT.get(key)
def _adjust_kernel_inputs(
num_active_loras: int,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
):
"""
helper function to adjust kernel inputs when sorted_token_ids is None
"""
if sorted_token_ids is None:
stride_tl = 0
stride_el = 0
grid_lora_dim = 1
else:
stride_tl = sorted_token_ids.stride(0)
stride_el = expert_ids.stride(0)
grid_lora_dim = num_active_loras
return grid_lora_dim, stride_tl, stride_el
@triton.jit(
do_not_specialize=[
"num_valid_tokens",
"EM",
"stride_tl",
"stride_el",
"slice_a_size",
"slice_c_size",
]
)
def _fused_moe_lora_kernel(
a_ptr,
b_ptr,
c_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
token_lora_mapping_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
num_experts,
top_k_num,
lora_ids,
adapter_enabled,
    max_loras,  # used for masks when the grid's lora axis (axis 2) != max_loras
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_bl,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_tl,
stride_el,
slice_a_size,
slice_c_size,
# Meta-parameters
num_slice_a: tl.constexpr,
num_slice_c: tl.constexpr,
# top_k_num or 1 depending on input token
# is expanded by top_k or not
token_mapping_factor: tl.constexpr,
# whether use naive block assignment
naive_block_assignment: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
ADD_INPUTS: tl.constexpr,
    USE_B_L2_CACHE: tl.constexpr,  # enable .ca (L2-cached) loads for B
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SPLIT_K: tl.constexpr,
USE_GDC: tl.constexpr,
launch_pdl: tl.constexpr,
IS_PRIMARY: tl.constexpr,
):
pid = tl.program_id(axis=0)
slice_id = tl.program_id(axis=1)
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
# calculate pid_m,pid_n
lora_idx = tl.program_id(axis=2)
pid_sk = pid % SPLIT_K
pid_m_n = pid // SPLIT_K
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid_m_n // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
pid_n = (pid_m_n % num_pid_in_group) // group_size_m
offs = tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
# Get lora_id
lora_id = _get_lora_id(
lora_ids,
token_lora_mapping_ptr,
lora_idx,
pid_m,
top_k_num,
naive_block_assignment,
)
if lora_id == -1:
return
moe_enabled = tl.load(adapter_enabled + lora_id)
if moe_enabled == 0:
return
if lora_id >= max_loras:
return
# Non-naive only: check num_tokens_post_padded
if not naive_block_assignment:
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
# Get expert_id
expert_id = _get_expert_id(
expert_ids_ptr,
lora_id,
pid_m,
stride_el,
max_loras,
naive_block_assignment,
)
if expert_id == -1:
return
# Get token offsets
offs_token = _get_token_offs(
sorted_token_ids_ptr,
lora_id,
pid_m,
offs,
stride_tl,
max_loras,
num_valid_tokens,
naive_block_assignment,
BLOCK_SIZE_M,
)
# get a_ptr,b_ptr,c_ptr
cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size
    # Column offsets are not wrapped with modulo; out-of-range columns are masked below.
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
token_mask = offs_token < num_valid_tokens
# get a_ptrs,b_ptrs
a_ptrs = cur_a_ptr + (
offs_token[:, None] // token_mapping_factor * stride_am
+ offs_k[None, :] * stride_ak
)
b_ptrs = (
cur_b_ptr
+ lora_id * stride_bl
+ expert_id * stride_be
+ offs_k[:, None] * stride_bk
+ offs_bn[None, :] * stride_bn
)
if USE_GDC and IS_PRIMARY:
# GDC launch dependents hints the runtime system to launch dependent kernels.
tl.extra.cuda.gdc_launch_dependents()
# accumulator
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait()
for k in range(0, grid_k):
k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K)
# GDC wait waits for ALL programs in the prior kernel to complete
# before continuing.
        # Pre-fetch the LoRA weight block.
        # Mask out-of-range K and N elements; optionally load B through the L2 cache (.ca).
b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
if USE_B_L2_CACHE:
b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
else:
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
if USE_GDC and not IS_PRIMARY:
tl.extra.cuda.gdc_wait()
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < k_remaining),
other=0.0,
)
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(c_ptr.dtype.element_ty)
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
if SPLIT_K == 1:
if ADD_INPUTS:
prev = tl.load(c_ptrs, mask=c_mask, other=0.0)
tl.store(c_ptrs, prev + accumulator, mask=c_mask)
else:
tl.store(c_ptrs, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed")
@torch.inference_mode()
def _fused_moe_lora_shrink(
a_intermediate_cache1: torch.Tensor,
# (num_slices, num_tokens, top_k_num, max_lora_rank)
qcurr_hidden_states: torch.Tensor, # (num_tokens, K,)
lora_a_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, max_lora_rank, K,),...]
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
sorted_token_ids: torch.Tensor | None, # (max_loras, _)
expert_ids: torch.Tensor, # (max_loras, _ ,) or (num_tokens * top_k,)
num_tokens_post_padded: torch.Tensor | None, # (max_loras, )
token_lora_mapping: torch.Tensor,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
## adding for kernel
device: torch.device,
N: int,
M: int,
EM: int,
K: int,
num_tokens: int,
num_experts: int,
num_slices: int,
block_size_m: int,
block_size_n: int,
block_size_k: int,
group_size_m: int,
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
use_gdc: bool = False,
) -> None:
w1_lora_a_stacked = lora_a_stacked[0]
shrink_config = {
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": num_stages,
"SPLIT_K": split_k,
"USE_GDC": use_gdc,
"launch_pdl": use_gdc, # triton kernel metadata
}
b_ptr = _get_ptr(lora_a_stacked, device)
grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
num_active_loras, sorted_token_ids, expert_ids
)
grid = lambda META: (
split_k
* triton.cdiv(EM, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_a_stacked),
grid_lora_dim,
)
_fused_moe_lora_kernel[grid](
qcurr_hidden_states,
b_ptr,
a_intermediate_cache1,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
token_lora_mapping,
N,
K,
EM,
num_tokens,
num_experts,
top_k_num,
lora_ids,
adapter_enabled,
lora_a_stacked[0].shape[0],
qcurr_hidden_states.stride(0),
qcurr_hidden_states.stride(1),
w1_lora_a_stacked.stride(0),
w1_lora_a_stacked.stride(1),
w1_lora_a_stacked.stride(3),
w1_lora_a_stacked.stride(2),
a_intermediate_cache1.stride(2),
a_intermediate_cache1.stride(3),
stride_tl,
stride_el,
slice_a_size=qcurr_hidden_states.numel(),
slice_c_size=a_intermediate_cache1.numel() // num_slices,
num_slice_a=1,
num_slice_c=num_slices,
token_mapping_factor=1 if mul_routed_weight else top_k_num,
naive_block_assignment=sorted_token_ids is None,
MUL_ROUTED_WEIGHT=False,
ADD_INPUTS=False,
        USE_B_L2_CACHE=True,  # load B through the L2 cache
IS_PRIMARY=True,
**shrink_config,
)
@torch.inference_mode()
def _fused_moe_lora_expand(
output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),)
a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank)
lora_b_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, max_lora_rank, K,),...]
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
sorted_token_ids: torch.Tensor | None, # (max_loras, _)
expert_ids: torch.Tensor, # (max_loras, _ ,) or (num_tokens * top_k,)
num_tokens_post_padded: torch.Tensor | None, # (max_loras, )
token_lora_mapping: torch.Tensor,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
## adding for kernel
device: torch.device,
N: int,
M: int,
EM: int,
K: int,
num_tokens: int,
num_experts: int,
num_slices: int,
max_lora_rank: int,
w1_output_dim_size: int,
block_size_m: int,
block_size_n: int,
block_size_k: int,
group_size_m: int,
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
offset: int = 0,
use_gdc: bool = False,
) -> None:
b_ptr = _get_ptr(lora_b_stacked, device)
K = max_lora_rank
N = w1_output_dim_size
w1_lora_b_stacked = lora_b_stacked[0]
a_intermediate_cache1 = a_intermediate_cache1.view(
-1, a_intermediate_cache1.shape[3]
)
expand_config = {
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
"num_warps": num_warps,
"num_stages": num_stages,
"SPLIT_K": 1, # Set split_k = 1 for expand calls
"USE_GDC": use_gdc,
"launch_pdl": use_gdc, # triton kernel metadata
}
grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
num_active_loras, sorted_token_ids, expert_ids
)
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
len(lora_b_stacked),
grid_lora_dim,
)
# Fast path: directly accumulate into the corresponding slice interval of output.
out_view = output[:, :, offset : offset + num_slices * N]
slice_c_size = N * out_view.stride(2)
_fused_moe_lora_kernel[grid](
a_intermediate_cache1,
b_ptr,
out_view,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
token_lora_mapping,
N,
K,
EM,
num_tokens,
num_experts,
top_k_num,
lora_ids,
adapter_enabled,
lora_b_stacked[0].shape[0],
a_intermediate_cache1.stride(0),
a_intermediate_cache1.stride(1),
w1_lora_b_stacked.stride(0),
w1_lora_b_stacked.stride(1),
w1_lora_b_stacked.stride(3),
w1_lora_b_stacked.stride(2),
out_view.stride(1),
out_view.stride(2),
stride_tl,
stride_el,
slice_a_size=a_intermediate_cache1.numel() // num_slices,
slice_c_size=slice_c_size,
num_slice_a=num_slices,
num_slice_c=num_slices,
token_mapping_factor=1,
naive_block_assignment=sorted_token_ids is None,
MUL_ROUTED_WEIGHT=mul_routed_weight,
ADD_INPUTS=True,
        USE_B_L2_CACHE=True,  # load B through the L2 cache
IS_PRIMARY=False,
**expand_config,
)
@torch.inference_mode()
def _fused_moe_lora(
output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),)
qcurr_hidden_states: torch.Tensor, # (num_tokens, K,)
lora_a_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, max_lora_rank, K,),...]
lora_b_stacked: list[
torch.Tensor
], # [(max_loras, num_experts, N, max_lora_rank,),...]
topk_weights: torch.Tensor, # (num_tokens, top_k_num)
sorted_token_ids: torch.Tensor | None, # (max_loras, _)
expert_ids: torch.Tensor, # (max_loras, _ ,) or (num_tokens * top_k,)
num_tokens_post_padded: torch.Tensor | None, # (max_loras, )
token_lora_mapping: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
num_active_loras: int,
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
shrink_block_size_k: int,
shrink_group_size_m: int,
shrink_num_warps: int,
shrink_num_stages: int,
shrink_split_k: int,
expand_block_size_m: int,
expand_block_size_n: int,
expand_block_size_k: int,
expand_group_size_m: int,
expand_num_warps: int,
expand_num_stages: int,
expand_split_k: int,
mul_routed_weight: bool = False,
fully_sharded: bool = False,
offset: int = 0,
) -> None:
assert len(lora_a_stacked) == len(lora_b_stacked) > 0
assert topk_weights.dim() == qcurr_hidden_states.dim() == 2
if sorted_token_ids is None:
assert expert_ids.dim() == 1
else:
assert sorted_token_ids is not None
assert num_tokens_post_padded is not None
assert (
sorted_token_ids.dim()
== expert_ids.dim()
== topk_weights.dim()
== qcurr_hidden_states.dim()
== 2
)
assert (
sorted_token_ids.shape[0]
== expert_ids.shape[0]
== num_tokens_post_padded.shape[0]
)
assert output.shape[0] == topk_weights.shape[0]
assert top_k_num == topk_weights.shape[1]
device = qcurr_hidden_states.device
num_slices = len(lora_a_stacked)
w1_lora_b_stacked = lora_b_stacked[0]
num_experts = lora_a_stacked[0].shape[1]
N = max_lora_rank
M = topk_weights.shape[0]
K = qcurr_hidden_states.shape[1]
num_tokens = M * top_k_num
w1_output_dim_size = w1_lora_b_stacked.shape[2]
assert shrink_block_size_m == expand_block_size_m
EM = (
sorted_token_ids.shape[1]
if sorted_token_ids is not None
else num_tokens * shrink_block_size_m
)
a_intermediate_cache1 = torch.zeros(
(num_slices, M, top_k_num, max_lora_rank),
dtype=output.dtype,
device=device,
)
use_gdc = supports_pdl(device) and not fully_sharded
_fused_moe_lora_shrink(
a_intermediate_cache1,
qcurr_hidden_states,
lora_a_stacked,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
token_lora_mapping,
top_k_num,
lora_ids,
adapter_enabled,
## adding for kernel
device,
N,
M,
EM,
K,
num_tokens,
num_experts,
num_slices,
shrink_block_size_m,
shrink_block_size_n,
shrink_block_size_k,
shrink_group_size_m,
shrink_num_warps,
shrink_num_stages,
shrink_split_k,
num_active_loras,
mul_routed_weight,
use_gdc=use_gdc,
)
if fully_sharded:
if max_lora_rank == w1_lora_b_stacked.shape[-1]:
a_intermediate_cache1 = tensor_model_parallel_all_reduce(
a_intermediate_cache1
)
else:
a_intermediate_cache1 = tensor_model_parallel_all_gather(
a_intermediate_cache1
)
# reset max_lora_rank to the full rank after allgather
max_lora_rank = a_intermediate_cache1.shape[-1]
_fused_moe_lora_expand(
output,
a_intermediate_cache1,
lora_b_stacked,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
token_lora_mapping,
top_k_num,
lora_ids,
adapter_enabled,
## adding for kernel
device,
N,
M,
EM,
K,
num_tokens,
num_experts,
num_slices,
max_lora_rank,
w1_output_dim_size,
expand_block_size_m,
expand_block_size_n,
expand_block_size_k,
expand_group_size_m,
expand_num_warps,
expand_num_stages,
expand_split_k,
num_active_loras,
mul_routed_weight,
offset,
use_gdc=use_gdc,
)
def _fused_moe_lora_fake(
output: torch.Tensor,
qcurr_hidden_states: torch.Tensor,
lora_a_stacked: list[torch.Tensor],
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor | None,
token_lora_mapping: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
num_active_loras: int,
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
shrink_block_size_k: int,
shrink_group_size_m: int,
shrink_num_warps: int,
shrink_num_stages: int,
shrink_split_k: int,
expand_block_size_m: int,
expand_block_size_n: int,
expand_block_size_k: int,
expand_group_size_m: int,
expand_num_warps: int,
expand_num_stages: int,
expand_split_k: int,
mul_routed_weight: bool = False,
fully_sharded: bool = False,
offset: int = 0,
) -> None:
return
def _fused_moe_lora_shrink_fake(
a_intermediate_cache1: torch.Tensor,
qcurr_hidden_states: torch.Tensor,
lora_a_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor | None,
token_lora_mapping: torch.Tensor,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
device: torch.device,
N: int,
M: int,
EM: int,
K: int,
num_tokens: int,
num_experts: int,
num_slices: int,
block_size_m: int,
block_size_n: int,
block_size_k: int,
group_size_m: int,
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
use_gdc: bool = False,
) -> None:
return
def _fused_moe_lora_expand_fake(
output: torch.Tensor,
a_intermediate_cache1: torch.Tensor,
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor | None,
token_lora_mapping: torch.Tensor,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
device: torch.device,
N: int,
M: int,
EM: int,
K: int,
num_tokens: int,
num_experts: int,
num_slices: int,
max_lora_rank: int,
w1_output_dim_size: int,
block_size_m: int,
block_size_n: int,
block_size_k: int,
group_size_m: int,
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
offset: int = 0,
use_gdc: bool = False,
) -> None:
return
try:
direct_register_custom_op(
op_name="fused_moe_lora",
op_func=_fused_moe_lora,
mutates_args=["output"],
fake_impl=_fused_moe_lora_fake,
)
direct_register_custom_op(
op_name="fused_moe_lora_shrink",
op_func=_fused_moe_lora_shrink,
mutates_args=["a_intermediate_cache1"],
fake_impl=_fused_moe_lora_shrink_fake,
)
direct_register_custom_op(
op_name="fused_moe_lora_expand",
op_func=_fused_moe_lora_expand,
mutates_args=["output"],
fake_impl=_fused_moe_lora_expand_fake,
)
fused_moe_lora = torch.ops.vllm.fused_moe_lora
fused_moe_lora_shrink = torch.ops.vllm.fused_moe_lora_shrink
fused_moe_lora_expand = torch.ops.vllm.fused_moe_lora_expand
except AttributeError:
fused_moe_lora = _fused_moe_lora
fused_moe_lora_shrink = _fused_moe_lora_shrink
fused_moe_lora_expand = _fused_moe_lora_expand

View File

@@ -0,0 +1,340 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Utilities for Punica kernel construction.
"""
from vllm.triton_utils import tl, triton
@triton.jit
def mm_k(
a_ptr,
b_ptr,
ak_stride,
bk_stride,
offset_k,
K: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
CAST_TYPE: tl.constexpr,
b_dtype: tl.constexpr,
USE_GDC: tl.constexpr,
base_k,
):
"""
    Given a_ptr and b_ptr, which identify the rows of A (m x k) and columns of
    B (k x n), iterate through the K dimension to compute the partial/complete
    matrix block product.
If SPLIT_K == 1, the output m x n product is complete.
If SPLIT_K > 1, the thread block computes partial outputs. The partial
outputs are then atomically summed in the caller code.
Args:
a_ptr: Array of pointers, identifying rows of A
b_ptr: Array of pointers, identifying columns of B
ak_stride: K dimension stride of the A matrix
bk_stride: K dimension stride of the B matrix
K: Length of the K dimension
BLOCK_M: M dimension of the output block m x n
BLOCK_N: N dimension of the output block m x n
BLOCK_K: K dimension atom
EVEN_K: True if the blocks of A and B can be loaded without any
masking.
SPLIT_K: Parameter signifying parallelism in the K dimension.
CAST_TYPE: if True, cast the values from the A matrix to the B
matrix dtype.
b_dtype: datatype of the B matrix
USE_GDC: Whether to use PDL. True indicates use.
base_k: Base offset along K dimension for current SPLIT_K group
"""
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# Step size along K for each iteration
STEP_K = BLOCK_K * SPLIT_K
# Total number of iterations (compile-time constant)
num_iters = tl.cdiv(K, STEP_K)
for k in range(num_iters):
# Current iteration's global K offset
iter_k = k * STEP_K + base_k
# Check if this iteration is completely valid (no masking needed)
block_end = iter_k + BLOCK_K
if EVEN_K:
# K is divisible by BLOCK_K, no masking ever needed
# pre-fetch lora weight
tiled_b = tl.load(b_ptr)
if USE_GDC:
tl.extra.cuda.gdc_wait()
tiled_a = tl.load(a_ptr)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)
else:
# Check if we need element-wise masking
if iter_k >= K:
# Entire block out of range, skip
pass
elif block_end <= K:
# Entire block in range, no masking needed (fast path)
tiled_b = tl.load(b_ptr)
if USE_GDC:
tl.extra.cuda.gdc_wait()
tiled_a = tl.load(a_ptr)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)
else:
# Partial block, need masking (only last iteration)
k_offsets = tl.arange(0, BLOCK_K)
mask = iter_k + k_offsets < K
tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
if USE_GDC:
tl.extra.cuda.gdc_wait()
tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)
a_ptr += STEP_K * ak_stride
b_ptr += STEP_K * bk_stride
return accumulator
@triton.jit
def do_expand_kernel(
pid_n,
lora_index,
slice_id,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
M_LEN,
ram, # array identifying the rows of Input ptr to operate on
slice_start_loc,
# input ptr strides
input_d0_stride,
input_d1_stride,
input_d2_stride,
# lora ptr strides
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr,
# out ptr strides
output_d0_stride,
output_d1_stride,
# constants
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
SAME_STRIDE: tl.constexpr,
SLICE_NUM: tl.constexpr,
EVEN_K: tl.constexpr,
CAST_TYPE: tl.constexpr,
ADD_INPUTS: tl.constexpr,
USE_GDC: tl.constexpr,
):
"""
Given an array of integers that identifies the rows of A, ram,
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
a slice_id that identifies the input/output slice,
compute the matrix product and store in the appropriate output location.
Given that this is an expand kernel, we don't perform any split-K reduction
as the K dimension is assumed to be small.
"""
# ls_d*_ptr can be either an integer or a pointer
if SAME_STRIDE:
# integer
cur_lora_d0_stride = ls_d0_ptr
cur_lora_d1_stride = ls_d1_ptr
cur_lora_d2_stride = ls_d2_ptr
else:
# pointer
cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id)
cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id)
cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id)
# Identify the input_ptr and lora_ptr from slice_id.
if SLICE_NUM == 1:
cur_input_ptr = input_ptr
cur_lora_ptr = lora_ptr
else:
cur_input_ptr = input_ptr + slice_id * input_d0_stride
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
tl.pointer_type(out_ptr.dtype.element_ty)
)
# Identify the column indices of B to process.
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
# Identify A and B block pointers
offset_k = tl.arange(0, BLOCK_K)
a_ptr = (
cur_input_ptr
+ ram[:, None] * input_d1_stride
+ offset_k[None, :] * input_d2_stride
)
b_ptr = (
cur_lora_ptr
+ cur_lora_d0_stride * lora_index
+ offset_k[:, None] * cur_lora_d2_stride
+ rbn[None, :] * cur_lora_d1_stride
)
# Compute the block matrix product.
SPLIT_K = 1
accumulator = mm_k(
a_ptr,
b_ptr,
input_d2_stride,
cur_lora_d2_stride,
offset_k,
K,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
CAST_TYPE,
cur_lora_ptr.dtype.element_ty,
USE_GDC,
base_k=0,
)
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
if SLICE_NUM == 1:
cur_slice_start = slice_start_loc
else:
cur_slice_start = tl.load(slice_start_loc + slice_id)
# Identify the C output pointers to store the results of the accumulator.
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start
offset_cm = tl.arange(0, BLOCK_M)
c_ptr = (
out_ptr
+ ram[:, None] * output_d0_stride
+ offset_cn[None, :] * output_d1_stride
)
c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < (cur_slice_start + N))
if ADD_INPUTS:
tiled_out = tl.load(c_ptr, mask=c_mask)
tiled_c += tiled_out
tl.store(c_ptr, tiled_c, mask=c_mask)
@triton.jit
def do_shrink_kernel(
pid_n,
pid_sk,
slice_id,
lora_index,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
M_LEN,
ram,
# input strides
input_d0_stride,
input_d1_stride,
# lora strides
lora_d0_stride,
lora_d1_stride,
lora_d2_stride,
# output strides
output_d0_stride,
output_d1_stride,
output_d2_stride,
scaling,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
SLICE_NUM: tl.constexpr,
USE_GDC: tl.constexpr,
):
"""
Given an array of integers that identifies the rows of A, ram,
a lora index that identifies which LoRA to use from lora_ptr, lora_index,
a slice_id that identifies the input/output slice, compute the
matrix product and store in the appropriate output location.
"""
# Identify the lora_ptr from slice_id.
if SLICE_NUM == 1:
# current lora ptr
cur_lora_ptr = lora_ptr
else:
# current lora ptr
cur_lora_ptr = tl.load(lora_ptr + slice_id).to(
tl.pointer_type(input_ptr.dtype.element_ty)
)
# Identify the column indices of B to process.
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
# Identify A and B block pointers
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
a_ptr = (
input_ptr + ram[:, None] * input_d0_stride + offset_k[None, :] * input_d1_stride
)
b_ptr = (
cur_lora_ptr
+ lora_d0_stride * lora_index
+ rbn[None, :] * lora_d1_stride
+ offset_k[:, None] * lora_d2_stride
)
# Compute partial/complete block matrix product.
accumulator = mm_k(
a_ptr,
b_ptr,
input_d1_stride,
lora_d2_stride,
offset_k,
K,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
False,
cur_lora_ptr.dtype.element_ty,
False, # USE_GDC is always False in shrink kernel
base_k=pid_sk * BLOCK_K,
)
# GDC launch dependents hints the runtime system to launch dependent kernels.
if USE_GDC:
tl.extra.cuda.gdc_launch_dependents()
# Identify the C output pointers to store the results of the accumulator.
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
offset_cm = tl.arange(0, BLOCK_M)
cur_out_ptr = out_ptr if SLICE_NUM == 1 else out_ptr + slice_id * output_d0_stride
c_ptr = (
cur_out_ptr
+ ram[:, None] * output_d1_stride
+ offset_cn[None, :] * output_d2_stride
)
c_mask = (offset_cm[:, None] < M_LEN) & (offset_cn[None, :] < N)
accumulator *= scaling
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(c_ptr, accumulator, mask=c_mask)
else:
tl.atomic_add(c_ptr, accumulator, mask=c_mask, sem="relaxed")

View File

@@ -0,0 +1,309 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
@triton.jit
def _lora_expand_kernel(
input_ptr,
lora_ptr,
out_ptr,
M,
N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
slice_start_loc,
input_d0_stride,
input_d1_stride,
input_d2_stride, # 1
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr, # 1
output_d0_stride,
output_d1_stride, # 1
output_hs_ptr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
ADD_INPUTS: tl.constexpr,
CAST_TYPE: tl.constexpr,
SLICE_NUM: tl.constexpr,
SAME_STRIDE: tl.constexpr,
USE_GDC: tl.constexpr,
launch_pdl: tl.constexpr,
):
cta_n_num = tl.cdiv(N, BLOCK_N)
cta_m_num = tl.cdiv(M, BLOCK_M)
pid_mn = tl.program_id(axis=0)
pid_m = pid_mn % cta_m_num
pid_n = (pid_mn // cta_m_num) % cta_n_num
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
lora_id = tl.load(lora_ids + lora_idx)
if lora_id == -1:
# Early exit for the no-lora case.
return
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
cta_m_offset = pid_m * BLOCK_M
if cta_m_offset >= lora_m_size:
# Early exit CTA.
return
    # When the output dimensions of each slice are the same, curr_N = N; otherwise
    # curr_N = tl.load(output_hs_ptr + slice_id). This situation arises in GQA's
    # qkv linear.
curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id)
if pid_n * BLOCK_N >= curr_N:
# Early exit CTA.
return
# num rows this CTA should process.
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
# Identify all rows that this CTA should process.
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
cta_lora_seq_indices = (
token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
)
# Load all relevant row indices.
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
ram = tl.load(cta_lora_seq_indices + offset_m)
do_expand_kernel(
pid_n,
lora_id,
slice_id,
input_ptr,
lora_ptr,
out_ptr,
curr_N,
K,
cta_m_len,
ram, # array identifying the rows of Input ptr to operate on
slice_start_loc,
# input ptr strides
input_d0_stride,
input_d1_stride,
input_d2_stride,
# lora ptr strides
ls_d0_ptr,
ls_d1_ptr,
ls_d2_ptr,
# out ptr strides
output_d0_stride,
output_d1_stride,
# constants
BLOCK_M,
BLOCK_N,
BLOCK_K,
SAME_STRIDE,
SLICE_NUM,
EVEN_K,
CAST_TYPE,
ADD_INPUTS,
USE_GDC,
)
@torch.inference_mode()
def _lora_expand(
inputs: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
lora_b_weights: list[torch.Tensor], # shape [num_lora, hidden_size, lora_rank]
output_tensor: torch.Tensor, # shape [num_tokens, hidden_size * num_slices]
token_lora_mapping: torch.Tensor, # shape [num_tokens]
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
    num_active_loras: int,  # number of active LoRAs; sets the lora axis of the launch grid
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
"""
Args:
inputs (torch.Tensor): input tensor
lora_b_weights (list[torch.Tensor]): lora'b weight
output_tensor (torch.Tensor): output tensor
token_lora_mapping (torch.Tensor): A tensor mapping each input token
to the lora-id related to that token. A value of -1 indicates that
LoRA doesn't apply to that token.
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
the A matrix grouped by LoRA IDs.
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
of tokens that are to be processed by LoRA ID lora_ids[i]
lora_token_start_loc (torch.Tensor): A cumulative sum of
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
lora_token_start_loc[i], along with num_tokens_per_lora[i]
identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
offset_start (int, optional): Offset start for output_tensor.
Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False.
"""
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
for weight in lora_b_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(0) == len(lora_b_weights)
assert output_tensor.is_contiguous()
# metadata sanity check.
M = inputs.size(1)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
(
slice_start_tensor,
lora_ptr_tensor,
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
hidden_sizes_tensor,
same_stride,
MAX_N,
) = _get_lora_b_ptr(lora_b_weights, offset_start, inputs.device)
K = lora_b_weights[0].shape[-1] # K= rank
ADD_INPUTS = add_inputs
MAX_LORAS = lora_ids.size(0)
CAST_TYPE = False
NUM_SLICES = len(lora_b_weights)
# Triton kernel configs.
kernel_config = get_lora_op_configs(
op_type="expand",
max_loras=MAX_LORAS,
batch=M,
hidden_size=MAX_N,
rank=K,
num_slices=NUM_SLICES,
add_inputs=add_inputs,
)
BLOCK_M = kernel_config["block_m"]
BLOCK_N = kernel_config["block_n"]
BLOCK_K = kernel_config["block_k"]
NUM_WARPS = kernel_config["num_warps"]
NUM_CTAS = kernel_config["num_ctas"]
NUM_STAGES = kernel_config["num_stages"]
EVEN_K = K % BLOCK_K == 0 # type: ignore
if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [
torch.float16,
torch.bfloat16,
]:
CAST_TYPE = True
# TODO (varun): This grid formulation maximizes parallelization at the
# cost of wasteful thread block launch when only a few input tokens require
# LoRA. This might not be the best in all cases.
grid = (
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
NUM_SLICES,
num_active_loras,
)
    # We disable PDL for now because the LoRA kernels do not launch back-to-back,
    # so PDL provides no benefit and can hurt kernel performance.
use_gdc = False # supports_pdl(inputs.device)
_lora_expand_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
M,
MAX_N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
slice_start_tensor,
inputs.stride(0),
inputs.stride(1),
inputs.stride(2),
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
output_tensor.stride(0),
output_tensor.stride(1),
hidden_sizes_tensor,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
ADD_INPUTS,
CAST_TYPE,
NUM_SLICES,
same_stride,
use_gdc,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
num_stages=NUM_STAGES,
launch_pdl=use_gdc,
)
return
def _lora_expand_fake(
inputs: torch.Tensor,
lora_b_weights: list[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
return
try:
direct_register_custom_op(
op_name="lora_expand",
op_func=_lora_expand,
mutates_args=["output_tensor"],
fake_impl=_lora_expand_fake,
)
lora_expand = torch.ops.vllm.lora_expand
except AttributeError:
lora_expand = _lora_expand

View File

@@ -0,0 +1,187 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
LoRA kernels metadata preparation utilities.
"""
import bisect
from dataclasses import dataclass, field
import torch
@dataclass
class LoRAKernelMeta:
token_lora_mapping: torch.Tensor
token_indices_sorted_by_lora_ids: torch.Tensor
active_lora_ids: torch.Tensor
num_tokens_per_lora: torch.Tensor
lora_token_start_loc: torch.Tensor
    # The V1 architecture uses traced torch.compile graphs to execute a
    # forward pass. Things to note about this process:
    # 1. Tracing specializes all Python scalar objects into constant values.
    # 2. Tracing cannot handle dynamic control flow (dynamic control flow is
    #    an experimental feature in PyTorch).
    # 3. The internals of torch.ops functions are not traced.
    # We disguise the "no_lora" flag as a CPU tensor and leverage point 3 to
    # exit early from inside the lora_expand / lora_shrink torch operations.
no_lora_flag_cpu: torch.Tensor
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
# Stored as a Python int to avoid GPU->CPU sync during forward pass
num_active_loras: int = 0
# Captured LoRA counts for cudagraph specialization (sorted list).
# When specialize_active_lora is enabled, num_active_loras is rounded up
# to the nearest value in this list to match cudagraph capture keys.
# Empty list means no specialization (use actual count).
captured_lora_counts: list[int] = field(default_factory=list)
@staticmethod
def make(
max_loras: int,
max_num_tokens: int,
device: torch.device | str,
captured_lora_counts: list[int] | None = None,
) -> "LoRAKernelMeta":
token_lora_mapping = torch.empty(
max_num_tokens, dtype=torch.int32, device=device
)
token_indices_sorted_by_lora_ids = torch.empty(
max_num_tokens, dtype=torch.int32, device=device
)
# +1 because "no-lora" is also a possibility
# example: let max_loras be 3, active_lora_ids of [-1, 0, 2, 1]
# is a possibility.
active_lora_ids = torch.empty(max_loras + 1, dtype=torch.int32, device=device)
# using running example, [3, 10, 5, 2] is a possibility.
num_tokens_per_lora = torch.zeros(
max_loras + 1, dtype=torch.int32, device=device
)
# +2 for this because, the first index is always 0.
# using running example, lora_token_start_loc
# is [0, 3, 13, 18, 20].
lora_token_start_loc = torch.zeros(
max_loras + 2, dtype=torch.int32, device=device
)
no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu")
return LoRAKernelMeta(
token_lora_mapping=token_lora_mapping,
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
active_lora_ids=active_lora_ids,
num_tokens_per_lora=num_tokens_per_lora,
lora_token_start_loc=lora_token_start_loc,
no_lora_flag_cpu=no_lora_flag_cpu,
captured_lora_counts=sorted(captured_lora_counts)
if captured_lora_counts
else [],
)
def _reset(self):
self.active_lora_ids.fill_(-1)
self.num_tokens_per_lora.fill_(0)
self.lora_token_start_loc.fill_(0)
self.no_lora_flag_cpu.fill_(False)
        self.num_active_loras = 0
        # captured_lora_counts is static configuration set in make(); it is not
        # cleared here, otherwise the cudagraph rounding in prepare_tensors
        # would never trigger.
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
"""
Prepare kernel metadata tensors for the current forward pass.
Args:
token_lora_mapping (torch.Tensor): Tensor containing lora indices
for each input token.
"""
self._reset()
# Check and record no-lora case.
no_lora = torch.all(token_lora_mapping == -1)
self.no_lora_flag_cpu[0] = no_lora
if no_lora:
# Early exit. LoRA kernels will not be run.
return
num_tokens = token_lora_mapping.size(0)
# copy token lora mapping
self.token_lora_mapping[:num_tokens].copy_(
token_lora_mapping, non_blocking=True
)
# token_indices_sorted_by_lora_ids
_, token_indices_sorted_by_lora_ids = torch.sort(
token_lora_mapping, stable=True
)
# start gpu transfer
self.token_indices_sorted_by_lora_ids[:num_tokens].copy_(
token_indices_sorted_by_lora_ids, non_blocking=True
)
# active_lora_ids, num_tokens_per_lora
lora_ids, num_tokens_per_lora = torch.unique(
token_lora_mapping, sorted=True, return_counts=True
)
self.active_lora_ids[: lora_ids.size(0)].copy_(lora_ids, non_blocking=True)
self.num_tokens_per_lora[: num_tokens_per_lora.size(0)].copy_(
num_tokens_per_lora, non_blocking=True
)
self.num_active_loras = lora_ids.size(0)
# Round up num_active_loras to match cudagraph capture keys.
# This ensures the kernel grid dimension matches the captured graph.
if self.captured_lora_counts and self.num_active_loras > 0:
idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras)
if idx < len(self.captured_lora_counts):
self.num_active_loras = self.captured_lora_counts[idx]
# lora_token_start_loc
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
self.lora_token_start_loc[1 : 1 + lora_token_start_loc.size(0)].copy_(
lora_token_start_loc, non_blocking=True
)
def meta_args(
self,
token_nums: int,
specialize_active_lora: bool,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
int,
]:
"""
This function returns the kernel metadata required for the current
forward pass execution of the kernel. The function returns all the
metadata required by the kernel, in order, as a tuple, so it can be
unpacked directly during the lora_shrink/lora_expand function call.
        Args:
            token_nums (int): Number of input tokens in the current forward
                pass of the kernel.
            specialize_active_lora (bool): If True, return num_active_loras as
                the final element so callers can size the lora grid axis by the
                active LoRA count; otherwise return max_loras + 1.
        """
max_loras = self.active_lora_ids.size(0) - 1
return (
self.token_lora_mapping[:token_nums],
self.token_indices_sorted_by_lora_ids[:token_nums],
self.num_tokens_per_lora,
self.lora_token_start_loc,
self.active_lora_ids,
self.no_lora_flag_cpu,
self.num_active_loras if specialize_active_lora else max_loras + 1,
)

View File

@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
import torch
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
@triton.jit
def _lora_shrink_kernel(
input_ptr,
lora_ptr,
out_ptr,
M,
N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
scaling,
input_d0_stride,
input_d1_stride,
lora_d0_stride,
lora_d1_stride,
lora_d2_stride,
output_d0_stride,
output_d1_stride,
output_d2_stride,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SLICE_NUM: tl.constexpr,
USE_GDC: tl.constexpr,
launch_pdl: tl.constexpr,
):
cta_n_num = tl.cdiv(N, BLOCK_N)
cta_m_num = tl.cdiv(M, BLOCK_M)
pid_sk_m_n = tl.program_id(axis=0)
pid_sk = pid_sk_m_n % SPLIT_K
pid_m_n = pid_sk_m_n // SPLIT_K
num_pid_in_group = GROUP_SIZE_M * cta_n_num
group_id = pid_m_n // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(cta_m_num - first_pid_m, GROUP_SIZE_M)
# Column-major ordering within groups for better cache reuse
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
pid_n = (pid_m_n % num_pid_in_group) // group_size_m
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
lora_id = tl.load(lora_ids + lora_idx)
if lora_id == -1:
# Early exit for the no-lora case.
return
lora_m_size = tl.load(num_tokens_per_lora + lora_idx)
cta_m_offset = pid_m * BLOCK_M
if cta_m_offset >= lora_m_size:
# Early exit CTA.
return
# num rows this CTA should process.
cta_m_len = min(BLOCK_M, lora_m_size - cta_m_offset)
# Identify all rows that this CTA should process.
lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx)
cta_lora_seq_indices = (
token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset
)
# Load all relevant row indices.
offset_m = tl.arange(0, BLOCK_M) % cta_m_len
ram = tl.load(cta_lora_seq_indices + offset_m)
do_shrink_kernel(
pid_n,
pid_sk,
slice_id,
lora_id,
input_ptr,
lora_ptr,
out_ptr,
N,
K,
cta_m_len,
ram, # array identifying the rows of Input ptr to operate on
# input strides
input_d0_stride,
input_d1_stride,
# lora strides
lora_d0_stride,
lora_d1_stride,
lora_d2_stride,
# output strides
output_d0_stride,
output_d1_stride,
output_d2_stride,
scaling,
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
SLICE_NUM,
USE_GDC,
)
@torch.inference_mode()
def _lora_shrink(
inputs: torch.Tensor, # shape [num_tokens, hidden_size]
lora_a_weights: list[torch.Tensor], # shape [num_loras, lora_rank, hidden_size]
output_tensor: torch.Tensor, # shape [num_slices, num_tokens, lora_rank]
token_lora_mapping: torch.Tensor, # shape [num_tokens]
token_indices_sorted_by_lora_ids: torch.Tensor, # shape [num_tokens]
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
    num_active_loras: int,  # number of active LoRAs; sets the lora axis of the launch grid
scaling: float,
) -> None:
"""
Args:
inputs (torch.Tensor): Input tensor
lora_a_weights (list[torch.Tensor]): LoRA weights
output_tensor (torch.Tensor): output tensor
token_lora_mapping (torch.Tensor): A tensor mapping each input token
to the lora-id related to that token. A value of -1 indicates that
LoRA doesn't apply to that token.
token_indices_sorted_by_lora_ids (torch.Tensor): Row/Token indices from
the A matrix grouped by LoRA IDs.
num_tokens_per_lora (torch.Tensor): num_tokens_per_lora[i] is the number
of tokens that are to be processed by LoRA ID lora_ids[i]
lora_token_start_loc (torch.Tensor): A cumulative sum of
num_tokens_per_lora. lora_token_start_loc[0] is always 0 so that
lora_token_start_loc[i], along with num_tokens_per_lora[i]
identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
scaling (float): Scaling factor.
"""
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype == lora_a_weights[0].dtype
assert inputs.dtype in [torch.float16, torch.bfloat16]
for weight in lora_a_weights:
assert weight.dtype in [torch.float16, torch.bfloat16]
assert inputs.size(1) == lora_a_weights[0].size(-1)
assert inputs.is_contiguous()
assert output_tensor.is_contiguous()
# metadata sanity check
M = inputs.size(0)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0)
assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1
output_tensor.zero_()
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1, lora_strides_d2) = (
_get_lora_a_ptr(lora_a_weights, inputs.device)
)
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
NUM_SLICES = len(lora_a_weights)
MAX_LORAS = lora_ids.size(0)
# Triton kernel configs
kernel_config = get_lora_op_configs(
"shrink",
max_loras=MAX_LORAS,
batch=M,
hidden_size=K,
rank=N,
num_slices=NUM_SLICES,
)
BLOCK_M = kernel_config["block_m"]
BLOCK_N = kernel_config["block_n"]
BLOCK_K = kernel_config["block_k"]
SPLIT_K = kernel_config["split_k"]
NUM_WARPS = kernel_config["num_warps"]
NUM_STAGES = kernel_config["num_stages"]
NUM_CTAS = kernel_config["num_ctas"]
GROUP_SIZE_M = kernel_config.get("group_size_m", 8)
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore
# TODO (varun): This grid formulation maximizes parallelization at the
    # cost of wasteful thread block launch when only a few of the input tokens
# require LoRA. This might not be the best in all cases.
grid = (
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
NUM_SLICES,
num_active_loras,
)
    # We disable PDL for now because the LoRA kernels do not launch back-to-back,
    # so PDL provides no benefit and can hurt kernel performance.
use_gdc = False # supports_pdl(inputs.device)
_lora_shrink_kernel[grid](
inputs,
lora_ptr_tensor,
output_tensor,
M,
N,
K,
token_indices_sorted_by_lora_ids,
num_tokens_per_lora,
lora_token_start_loc,
lora_ids,
scaling,
inputs.stride(0),
inputs.stride(1),
lora_strides_d0,
lora_strides_d1,
lora_strides_d2,
output_tensor.stride(0),
output_tensor.stride(1),
output_tensor.stride(2),
BLOCK_M,
BLOCK_N,
BLOCK_K,
EVEN_K,
SPLIT_K,
GROUP_SIZE_M,
NUM_SLICES,
use_gdc,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
num_stages=NUM_STAGES,
launch_pdl=use_gdc,
)
return
def _lora_shrink_fake(
inputs: torch.Tensor,
lora_a_weights: list[torch.Tensor],
output_tensor: torch.Tensor,
token_lora_mapping: torch.Tensor,
token_indices_sorted_by_lora_ids: torch.Tensor,
num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
scaling: float,
) -> None:
return
try:
direct_register_custom_op(
op_name="lora_shrink",
op_func=_lora_shrink,
mutates_args=["output_tensor"],
fake_impl=_lora_shrink_fake,
)
lora_shrink = torch.ops.vllm.lora_shrink
except AttributeError:
lora_shrink = _lora_shrink

View File

@@ -0,0 +1,318 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import json
from functools import lru_cache
from pathlib import Path
from typing import Any
import torch
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2
logger = init_logger(__name__)
is_batch_invariant = vllm_is_batch_invariant()
_LORA_A_PTR_DICT: dict[tuple[int, ...], tuple[torch.Tensor, ...]] = {}
_LORA_B_PTR_DICT: dict[tuple[int, ...], tuple[torch.Tensor, ...]] = {}
def _get_lora_a_ptr(lora_a_weights: list[torch.Tensor], device: torch.device):
"""
    `_LORA_A_PTR_DICT` collects the required information during `profile_run`.
    After this, it remains constant and subsequent usage is through a LUT.
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights)
if values := _LORA_A_PTR_DICT.get(key):
return values
lora_strides_d0 = []
lora_strides_d1 = []
lora_strides_d2 = []
tensor_ptrs = []
for lora_a_weight in lora_a_weights:
if lora_a_weight.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_a_weight.size(1) == 1
lora_a_weight = lora_a_weight.squeeze(dim=1)
else:
assert lora_a_weight.ndim == 3 # shape:(lora_num,size,rank)
assert lora_a_weight.is_contiguous()
tensor_ptrs.append(lora_a_weight.data_ptr())
lora_strides_d0.append(lora_a_weight.stride(0))
lora_strides_d1.append(lora_a_weight.stride(1))
lora_strides_d2.append(lora_a_weight.stride(2))
if len(lora_a_weights) > 1:
lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
else:
lora_ptr_tensor = lora_a_weights[0]
if (
len(set(lora_strides_d0)) > 1
or len(set(lora_strides_d1)) > 1
or len(set(lora_strides_d2)) > 1
):
raise ValueError("All LoRA weights must have the same stride.")
_LORA_A_PTR_DICT[key] = (
lora_ptr_tensor,
lora_strides_d0[0],
lora_strides_d1[0],
lora_strides_d2[0],
)
return _LORA_A_PTR_DICT.get(key)
def _get_lora_b_ptr(
lora_weights: list[torch.Tensor], offset_start: int, device: torch.device
):
"""
    `_LORA_B_PTR_DICT` collects the required information during `profile_run`.
    After this, it remains constant and subsequent usage is through a LUT.
Refer to:
https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py
"""
key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights)
if values := _LORA_B_PTR_DICT.get(key):
return values
slice_offset_lst = []
tensor_ptrs = []
lora_strides_d0 = []
lora_strides_d1 = []
lora_strides_d2 = []
hidden_sizes = []
slice_offset = offset_start
for lora_b_weight in lora_weights:
if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank)
assert lora_b_weight.size(1) == 1
lora_b_weight = lora_b_weight.squeeze(dim=1)
else:
assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank)
assert lora_b_weight.is_contiguous()
tensor_ptrs.append(lora_b_weight.data_ptr())
lora_strides_d0.append(lora_b_weight.stride(0))
lora_strides_d1.append(lora_b_weight.stride(1))
lora_strides_d2.append(lora_b_weight.stride(2))
slice_offset_lst.append(slice_offset)
slice_offset += lora_b_weight.size(1)
hidden_sizes.append(lora_b_weight.size(1))
if len(lora_weights) > 1:
# note these are device tensors
lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64)
slice_start_tensor = torch.tensor(
slice_offset_lst, device=device, dtype=torch.uint64
)
else:
slice_start_tensor = slice_offset_lst[0]
lora_ptr_tensor = lora_b_weight[0]
# If each lora has the same stride, there's no need to use a
# tensor for storage.
if (
len(set(lora_strides_d0)) == 1
and len(set(lora_strides_d1)) == 1
and len(set(lora_strides_d2)) == 1
) and len(set(hidden_sizes)) == 1:
lora_strides_d0_tensor = lora_strides_d0[0]
lora_strides_d1_tensor = lora_strides_d1[0]
lora_strides_d2_tensor = lora_strides_d2[0]
hidden_sizes_tensor = hidden_sizes[0]
same_stride = True
else:
lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device)
lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device)
lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device)
hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device)
same_stride = False
# MAX_N is the maximum hidden size among all the lora_b weights
MAX_N = max(hidden_sizes)
_LORA_B_PTR_DICT[key] = (
slice_start_tensor,
lora_ptr_tensor,
lora_strides_d0_tensor,
lora_strides_d1_tensor,
lora_strides_d2_tensor,
hidden_sizes_tensor,
same_stride,
MAX_N,
)
return _LORA_B_PTR_DICT.get(key)
@functools.lru_cache
def load_lora_op_config(op_type: str, add_inputs: bool | None) -> dict | None:
user_defined_config_folder = envs.VLLM_TUNED_CONFIG_FOLDER
# Avoid optimizing for the batch invariant case. Use default config
if user_defined_config_folder is not None and not is_batch_invariant:
gpu_name = torch.cuda.get_device_name()
gpu_name = gpu_name.replace(" ", "_")
gpu_name = gpu_name.replace("-", "_")
config_fname = None
# only expand op needs to consider add_inputs
if op_type == "expand":
config_fname = (
f"{gpu_name}_{op_type.upper()}_{str(add_inputs).upper()}.json"
)
else:
config_fname = f"{gpu_name}_{op_type.upper()}.json"
config_path = Path(f"{user_defined_config_folder}/{config_fname}")
if not config_path.exists():
logger.warning_once(f"No LoRA kernel configs found in {config_path}")
return None
# Load json
logger.info_once(f"Using tuned LoRA kernel configs from {config_path}.")
with open(str(config_path)) as f:
config_data = json.load(f)
else:
config_data = None
return config_data
@functools.lru_cache
def get_lora_op_configs(
op_type: str,
max_loras: int,
batch: int,
hidden_size: int,
rank: int,
num_slices: int,
add_inputs: bool | None = None,
moe_intermediate_size: int | None = None,
) -> dict[str, int | None]:
# Add support for fused_moe_lora ops
assert op_type in [
"shrink",
"expand",
"fused_moe_lora_w13_shrink",
"fused_moe_lora_w13_expand",
"fused_moe_lora_w2_shrink",
"fused_moe_lora_w2_expand",
]
# default config
default = {}
if op_type == "shrink":
split_k = 64 if batch < 128 else 8
if is_batch_invariant:
split_k = 1
default = {
"block_m": 32,
"block_n": 16,
"block_k": 256 if batch < 128 else 32,
"split_k": split_k,
"num_warps": 4,
"num_ctas": 1,
"group_size_m": 8,
"num_stages": 2,
"max_nreg": None,
}
# The default config for fused_moe_lora ops
elif op_type in [
"fused_moe_lora_w13_shrink",
"fused_moe_lora_w2_shrink",
]:
default = {
"block_m": 64,
"block_n": min(64, next_power_of_2(rank)),
"block_k": 32,
"num_warps": 4,
"num_stages": 3,
"group_size_m": 8,
"split_k": 1,
}
elif op_type in [
"fused_moe_lora_w13_expand",
"fused_moe_lora_w2_expand",
]:
default = {
"block_m": 64,
"block_n": 64,
"block_k": max(16, min(32, next_power_of_2(rank))),
"num_warps": 4,
"num_stages": 3,
"group_size_m": 8,
"split_k": 1,
}
else:
default = {
"block_m": 64,
"block_n": 64 if num_slices > 1 else 128,
"block_k": 16,
"num_warps": 4,
"num_ctas": 1,
"num_stages": 2,
"max_nreg": None,
}
m = batch
k, n = (hidden_size, rank) if op_type == "shrink" else (rank, hidden_size)
config_data: Any
config_data = load_lora_op_config(op_type, add_inputs)
if not config_data:
logger.warning_once("Using default LoRA kernel configs")
return default
# config is structured as config_data[max_loras][num_slices][m][k][n] = {}
# slice by max_loras
config_data = (
config_data.get(str(max_loras))
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - max_loras))]
)
# slice by num_slices
config_data = config_data[str(num_slices)]
# slice by m
config_data = (
config_data.get(str(m))
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - m))]
)
# slice by k
config_data = (
config_data.get(str(k))
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - k))]
)
# slice by n
config_data = (
config_data.get(str(n))
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - n))]
)
# slice by moe-intermediate-size if applicable
if moe_intermediate_size is not None:
i = moe_intermediate_size
config_data = (
config_data.get(str(i))
or config_data[min(config_data.keys(), key=lambda x: abs(int(x) - i))]
)
assert config_data is not None
return config_data
@lru_cache
def supports_pdl(device: torch.device | None = None) -> bool:
"""
Refer to: https://github.com/triton-lang/triton/blob/v3.5.0/python/tutorials/11-programmatic-dependent-launch.py
"""
# PDL requires compute capability SM90 or above
return (
current_platform.is_cuda()
and current_platform.has_device_capability(90)
and not envs.VLLM_LORA_DISABLE_PDL
)
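For reference, load_lora_op_config above reads tuned kernel parameters from JSON files named `{gpu_name}_{OP_TYPE}.json` (with an extra `_{ADD_INPUTS}` suffix for the expand op) under VLLM_TUNED_CONFIG_FOLDER, and get_lora_op_configs walks the nesting config[max_loras][num_slices][m][k][n]. The sketch below is a hypothetical shrink config with made-up values; only the nesting and the key names follow the code above.
# Hypothetical contents of e.g. NVIDIA_H100_80GB_HBM3_SHRINK.json, shown as a
# Python dict. Nesting: config[max_loras][num_slices][m][k][n] -> kernel params.
example_shrink_config = {
    "4": {                      # max_loras
        "1": {                  # num_slices
            "256": {            # m (number of tokens)
                "4096": {       # k (hidden_size for the shrink op)
                    "16": {     # n (LoRA rank for the shrink op)
                        "block_m": 32,
                        "block_n": 16,
                        "block_k": 256,
                        "split_k": 64,
                        "num_warps": 4,
                        "num_ctas": 1,
                        "group_size_m": 8,
                        "num_stages": 2,
                        "max_nreg": None,
                    }
                }
            }
        }
    }
}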

View File

@@ -0,0 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.ops.xpu_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]

View File

@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
def bgmv_shrink(
inputs: torch.Tensor,
lora_a_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
) -> None:
torch.ops._xpu_C.bgmv_shrink(
output_tensor, inputs, lora_a_weights, lora_indices_tensor, scaling
)
def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
) -> None:
torch.ops._xpu_C.bgmv_expand(
output_tensor, inputs, lora_b_weights, lora_indices_tensor, add_inputs
)
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
) -> None:
assert slice_size == lora_b_weights.size(-2)
assert slice_offset + slice_size <= output_tensor.size(1)
torch.ops._xpu_C.bgmv_expand_slice(
output_tensor,
inputs,
lora_b_weights,
lora_indices_tensor,
slice_offset,
slice_size,
add_inputs,
)

128
vllm/lora/peft_helper.py Normal file
View File

@@ -0,0 +1,128 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from: https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora/config.py
import json
import math
import os
from dataclasses import MISSING, dataclass, field, fields
from typing import Literal
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
logger = init_logger(__name__)
@dataclass
class PEFTHelper:
"""
A helper class for PEFT configurations, specifically designed for LoRA.
This class handles configuration validation and compatibility checks for
various LoRA implementations.
"""
# Required fields
r: int
lora_alpha: int
target_modules: list[str] | str
bias: Literal["none"] = field(default="none")
modules_to_save: list[str] | None = field(default=None)
# True to use Rank-Stabilized LoRA (rsLoRA, see: https://arxiv.org/abs/2312.03732)
use_rslora: bool = field(default=False)
# True to use Weight-Decomposed Low-Rank Adaptation (DoRA, see: https://arxiv.org/abs/2402.09353)
use_dora: bool = field(default=False)
# Extra vLLM fields, prefixed with 'vllm_' to avoid conflicts
vllm_lora_scaling_factor: float = field(default=1.0)
vllm_max_position_embeddings: int | None = field(default=None)
def _validate_features(self) -> list[str]:
"""
Check if there are any unsupported LoRA features.
"""
error_msg = []
if self.modules_to_save:
error_msg.append("vLLM only supports modules_to_save being None.")
if self.use_dora:
error_msg.append("vLLM does not yet support DoRA.")
return error_msg
def __post_init__(self):
if self.use_rslora:
logger.info_once("Loading LoRA weights trained with rsLoRA.")
self.vllm_lora_scaling_factor = self.lora_alpha / math.sqrt(self.r)
else:
self.vllm_lora_scaling_factor = self.lora_alpha / self.r
@classmethod
def from_dict(cls, config_dict: dict) -> "PEFTHelper":
# Get all field information from the class
class_fields = {f.name: f for f in fields(cls)}
# Check for required fields
required_fields = {
name
for name, f in class_fields.items()
if f.default is MISSING and f.default_factory is MISSING
}
# Identify any missing required fields
missing_fields = required_fields - set(config_dict.keys())
if missing_fields:
raise ValueError(f"Missing required configuration fields: {missing_fields}")
# Filter out fields that aren't defined in the class
filtered_dict = {k: v for k, v in config_dict.items() if k in class_fields}
return cls(**filtered_dict)
@classmethod
def from_local_dir(
cls,
lora_path: str,
max_position_embeddings: int | None,
tensorizer_config_dict: dict | None = None,
) -> "PEFTHelper":
lora_config_path = os.path.join(lora_path, "adapter_config.json")
if tensorizer_config_dict:
tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
tensorizer_args = tensorizer_config._construct_tensorizer_args()
from tensorizer.stream_io import open_stream
lora_config_path = os.path.join(
tensorizer_config.tensorizer_dir, "adapter_config.json"
)
with open_stream(
lora_config_path, mode="rb", **tensorizer_args.stream_kwargs
) as f:
config = json.load(f)
logger.info(
"Successfully deserialized LoRA config from %s",
tensorizer_config.tensorizer_dir,
)
else:
with open(lora_config_path) as f:
config = json.load(f)
config["vllm_max_position_embeddings"] = max_position_embeddings
return cls.from_dict(config)
def validate_legal(self, lora_config: LoRAConfig) -> None:
"""
Validates the LoRA configuration settings against application
constraints and requirements.
"""
error_msg = self._validate_features()
if self.r > lora_config.max_lora_rank:
error_msg.append(
f"LoRA rank {self.r} is greater than max_lora_rank"
f" {lora_config.max_lora_rank}."
)
if self.bias != "none":
error_msg.append("Adapter bias is not supported.")
if error_msg:
raise ValueError(f"{' '.join(error_msg)}")
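A minimal usage sketch of PEFTHelper, assuming a local adapter directory that contains an adapter_config.json; the path, max_position_embeddings value, and max_lora_rank below are placeholders.
from vllm.config.lora import LoRAConfig
from vllm.lora.peft_helper import PEFTHelper

# Hypothetical adapter directory containing adapter_config.json.
peft_helper = PEFTHelper.from_local_dir(
    "/path/to/adapter", max_position_embeddings=4096
)
# Raises ValueError if, e.g., the adapter rank exceeds max_lora_rank,
# DoRA is used, or a non-"none" bias is requested.
peft_helper.validate_legal(LoRAConfig(max_lora_rank=16))
print(peft_helper.r, peft_helper.vllm_lora_scaling_factor)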

View File

@@ -0,0 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
from vllm.lora.punica_wrapper.punica_selector import get_punica_wrapper
__all__ = [
"PunicaWrapperBase",
"get_punica_wrapper",
]

View File

@@ -0,0 +1,495 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
import torch
from .utils import compute_meta, convert_mapping
if TYPE_CHECKING:
# avoid circular import
from vllm.lora.layers import LoRAMapping
class PunicaWrapperABC(ABC):
"""
PunicaWrapper ABC.
"""
@abstractmethod
def update_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
**kwargs,
) -> None:
"""
Update the lora-related metadata
"""
raise NotImplementedError
@abstractmethod
def add_shrink(
self,
y: tuple[torch.Tensor, ...] | torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
) -> torch.Tensor | None:
"""
Performs GEMM for multiple slices of lora_a.
"""
raise NotImplementedError
@abstractmethod
def add_expand(
self,
y: torch.Tensor,
x: tuple[torch.Tensor, ...] | torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> torch.Tensor | None:
"""
Performs GEMM for multiple slices of lora_b.
"""
raise NotImplementedError
@abstractmethod
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> torch.Tensor | None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA,
and this layer only requires the expand operation.
"""
raise NotImplementedError
@abstractmethod
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: tuple[torch.Tensor, ...] | None = None,
**kwargs,
) -> torch.Tensor | None:
"""
Applicable to linear-related lora.
"""
raise NotImplementedError
@abstractmethod
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
"""
raise NotImplementedError
class PunicaWrapperBase(PunicaWrapperABC):
"""
PunicaWrapperBase is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica kernels.
"""
def __init__(
self,
max_num_batched_tokens: int,
max_batches: int,
device: torch.device | str,
**kwargs,
):
self._token_lora_indices = torch.empty(
max_num_batched_tokens, dtype=torch.long, device=device
)
self._sampler_indices = torch.empty(
max_num_batched_tokens, dtype=torch.long, device=device
)
self._sampler_indices_padded = torch.empty(
max_num_batched_tokens, dtype=torch.long, device=device
)
self._embeddings_indices = torch.empty(
2, max_num_batched_tokens, dtype=torch.long, device=device
)
# 4 is the number of indices tensors.
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
self.indices_len: list[int | None] = [None] * 4
# these attributes are the information required for the sgmv kernel
self._seq_start_locs = torch.empty(max_batches, dtype=torch.long, device=device)
self._seq_lengths = torch.empty(max_batches, dtype=torch.long, device=device)
self._lora_indices_per_batch = torch.empty(
max_batches, dtype=torch.long, device=device
)
self.device: torch.device = device
self.max_length: int = 0
self.token_nums: int = 0
self.batch_size: int = -1
self.is_prefill = False
self.no_lora = False
def _update_base_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
):
# NOTE: We have removed LoRA extra vocab support for now, so we always set
# extra_vocab_size to 0; this parameter will eventually be removed.
extra_vocab_size = 0
(
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
indices_len,
) = convert_mapping(
mapping,
lora_index_to_id,
max_loras,
vocab_size,
extra_vocab_size,
self.device,
)
self._token_lora_indices[: base_indices.shape[0]].copy_(base_indices)
self._sampler_indices[: sampler_indices.shape[0]].copy_(sampler_indices)
self._sampler_indices_padded[: sampler_indices_padded.shape[0]].copy_(
sampler_indices_padded
)
self._embeddings_indices[
: embeddings_indices.shape[0], : embeddings_indices.shape[1]
].copy_(embeddings_indices)
self.indices_len[:] = indices_len
def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None:
(
b_seq_start_tensor,
seq_length_tensor,
lora_indices_tensor,
batch_size,
max_length,
token_nums,
no_lora,
) = compute_meta(token_lora_tensor)
self._seq_start_locs[: b_seq_start_tensor.shape[0]].copy_(b_seq_start_tensor)
self._seq_lengths[: seq_length_tensor.shape[0]].copy_(seq_length_tensor)
self._lora_indices_per_batch[: lora_indices_tensor.shape[0]].copy_(
lora_indices_tensor
)
self.batch_size = batch_size
self.max_length = max_length
self.token_nums = token_nums
self.no_lora = no_lora
@property
def prefill_metadata(
self,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
"""
This property provides a convenient way to access the necessary
metadata for prefill-related kernel computations.
1. seq_start_locs: Tensor of sequence start positions.
2. seq_lengths: Tensor of sequence lengths.
3. lora_indices_per_batch: Tensor of lora indices, and an index of
-1 means no lora should be applied.
4. batch_size: Batch size after clustering identical lora indices.
5. max_length: The maximum sequence length in the batch.
6. token_nums: The token numbers in the batch.
"""
return (
self._seq_start_locs[: self.batch_size],
self._seq_lengths[: self.batch_size],
self._lora_indices_per_batch[: self.batch_size],
self.batch_size,
self.max_length,
self.token_nums,
)
@property
def token_lora_indices(self) -> torch.Tensor:
"""
This property provides the lora indices corresponding to each token
in the batch. An index of -1 means no lora should be applied.
"""
token_lora_len = self.indices_len[0]
return self._token_lora_indices[:token_lora_len]
@property
def sampler_indices(self) -> torch.Tensor:
"""
This property is used to access the lora indices specifically for
LogitsProcessorWithLoRA.
"""
sampler_indices_len = self.indices_len[1]
return self._sampler_indices[:sampler_indices_len]
@property
def sampler_indices_padded(self) -> torch.Tensor:
"""
This property provides access to padded sampler indices.
"""
indices_padded_len = self.indices_len[2]
return self._sampler_indices_padded[:indices_padded_len]
@property
def embeddings_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA.
"""
embeddings_indices_len = self.indices_len[3]
return self._embeddings_indices[:, :embeddings_indices_len]
def update_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
**kwargs,
):
self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
if mapping.is_prefill:
# Update metadata required for prefill-related operators.
self._update_prefill_metadata(self.token_lora_indices)
self.is_prefill = True
else:
self.is_prefill = False
@abstractmethod
def add_shrink(
self,
y: tuple[torch.Tensor, ...] | torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
) -> torch.Tensor | None:
"""
Performs GEMM for multiple slices of lora_a.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
# TODO: implement it based on torch ops
raise NotImplementedError
@abstractmethod
def add_expand(
self,
y: torch.Tensor,
x: tuple[torch.Tensor, ...] | torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> torch.Tensor | None:
"""
Performs GEMM for multiple slices of lora_b.
Semantics:
offset = offset_start
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
output_slices (tuple[int, ...]): Every slice's size
offset_start (int): The starting position of y, defaults to 0
add_inputs (bool): Defaults to True.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
@abstractmethod
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> torch.Tensor | None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA,
and this layer only requires the expand operation.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
@abstractmethod
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: tuple[torch.Tensor, ...] | None = None,
**kwargs,
) -> torch.Tensor | None:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
@abstractmethod
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor | None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor): lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]): Defaults to None.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
def moe_lora_align_block_size(
self,
topk_ids: torch.Tensor,
num_tokens: int,
block_size: int,
num_experts: int,
max_loras: int,
adapter_enabled: torch.Tensor,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
naive_block_assignment: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
def add_lora_fused_moe(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor | None,
max_lora_rank: int,
top_k_num: int,
shrink_config,
expand_config,
adapter_enabled: torch.Tensor,
mul_routed_weight=False,
fully_sharded: bool = False,
offset: int = 0,
token_lora_mapping: torch.Tensor | None = None,
):
"""
Performs a fused forward computation for LoRA of
Mixture-of-Experts (MoE) layer.
"""
# TODO: implement it based on torch ops
raise NotImplementedError
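The shrink/expand decomposition documented in the methods above can be illustrated with a naive, single-device reference in plain PyTorch. The sketch below uses unstacked 2-D weights per slice and ignores per-token adapter indexing, so it only mirrors the documented semantics (shrink into a rank-sized buffer, then expand into the output slices), not the real stacked-tensor kernels.
import torch

def naive_lora_linear(
    y: torch.Tensor,                # (num_tokens, sum(output_slices))
    x: torch.Tensor,                # (num_tokens, hidden_size)
    lora_a: list[torch.Tensor],     # each (hidden_size, rank)
    lora_b: list[torch.Tensor],     # each (rank, slice_size)
    scale: float,
    output_slices: tuple[int, ...],
) -> torch.Tensor:
    offset = 0
    for a, b, slice_size in zip(lora_a, lora_b, output_slices):
        buffer = (x @ a) * scale                           # shrink to rank
        y[:, offset:offset + slice_size] += buffer @ b     # expand into slice
        offset += slice_size
    return y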

View File

@@ -0,0 +1,351 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.lora.ops.torch_ops import (
bgmv_expand,
bgmv_expand_slice,
bgmv_shrink,
sgmv_expand,
sgmv_expand_slice,
sgmv_shrink,
)
from .punica_base import PunicaWrapperBase
# The platforms that are compatible with the PyTorch-native implementation can
# inherit this class
class PunicaWrapperCPU(PunicaWrapperBase):
"""
PunicaWrapperCPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the pytorch punica ops.
"""
def __init__(
self,
max_num_batched_tokens: int,
max_batches: int,
device: torch.device | str,
**kwargs,
):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
def _shrink_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
# No LoRA request, so return directly
if self.no_lora:
return
sgmv_shrink(
x,
w_t_all,
y,
*self.prefill_metadata,
scale,
)
def _shrink_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)
def _expand_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_inputs: bool,
):
# No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand(
x,
w_t_all,
y,
*self.prefill_metadata,
add_inputs,
)
def _expand_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
add_inputs: bool,
):
bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs)
def _expand_slice_prefill(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
):
# No LoRA request, so return directly
if self.no_lora:
return
sgmv_expand_slice(
x,
w_t_all,
y,
*self.prefill_metadata,
y_offset,
y_slice_size,
add_inputs,
)
def _expand_slice_decode(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
):
bgmv_expand_slice(
x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs
)
def _apply_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool = True,
):
"""
Perform the `y[:, y_offset:y_offset+y_slice_size] += x @ w_t_all`
computation, which is suitable for the GEMM of lora_b.
"""
expand_slice_fun: Callable = (
self._expand_slice_prefill if self.is_prefill else self._expand_slice_decode
)
expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs)
def _apply_shrink(
self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, scale: float
):
"""
Perform the `y += x @ w_t_all` computation, which is suitable for the
GEMM of lora_a.
When `is_prefill` is True, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the `_shrink_decode` function
should be called.
"""
y_org = y
y = y.view(-1, y.shape[-1])
shrink_fun: Callable = (
self._shrink_prefill if self.is_prefill else self._shrink_decode
)
shrink_fun(y, x, w_t_all, scale)
y = y.view_as(y_org)
def add_shrink(
self,
y: tuple[torch.Tensor, ...] | torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
):
"""
Performs GEMM for multiple slices of lora_a.
When `is_prefill` is True, it indicates that it is currently the
prefill stage, and the `_shrink_prefill` function should be called.
Otherwise, it is the decode stage, and the `_shrink_decode` function
should be called.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
x = x.view(-1, x.shape[-1])
# TODO fuse these kernels
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale)
def add_expand(
self,
y: torch.Tensor,
x: tuple[torch.Tensor, ...] | torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> None:
"""
Performs GEMM for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = offset_start
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_left += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
# The embedding layer only needs the expand op
expand_fun: Callable = (
self._expand_prefill if self.is_prefill else self._expand_decode
)
expand_fun(y, x, lora_b_stacked, add_inputs)
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: tuple[torch.Tensor, ...] | None = None,
**kwargs,
) -> None:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if buffer is None:
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default, consistent with the
# triton op
buffer = tuple(
torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
for _ in range(len(output_slices))
)
self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
self.add_expand(
y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
)
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor): lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]): Defaults to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if buffer is None:
# We set the buffer to be float32 by default, consistent with the
# triton op
buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device)
# LogitsProcessorWithLoRA always using bgmv.
bgmv_shrink(x, lora_a_stacked, buffer, self.sampler_indices, scale)
bgmv_expand(buffer, lora_b_stacked, y, self.sampler_indices, add_inputs=True)
y = y.view_as(y_org)

View File

@@ -0,0 +1,457 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import final
import torch
from vllm.lora.layers import LoRAMapping
from vllm.lora.utils import get_captured_lora_counts
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.math_utils import round_up
if HAS_TRITON:
from vllm.lora.ops.triton_ops import (
LoRAKernelMeta,
fused_moe_lora,
lora_expand,
lora_shrink,
)
from vllm import _custom_ops as ops
from .punica_base import PunicaWrapperBase
@final
class PunicaWrapperGPU(PunicaWrapperBase):
"""
PunicaWrapperGPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica triton kernel.
"""
def __init__(
self,
max_num_batched_tokens: int,
max_batches: int,
device: torch.device | str,
**kwargs,
):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
self.lora_config = kwargs["lora_config"]
self.max_loras = self.lora_config.max_loras
# Compute captured LoRA counts for cudagraph specialization.
captured_lora_counts = get_captured_lora_counts(
self.max_loras, self.lora_config.specialize_active_lora
)
self.token_mapping_meta = LoRAKernelMeta.make(
self.max_loras,
max_num_batched_tokens,
device=device,
captured_lora_counts=captured_lora_counts,
)
# When speculative decoding is enabled, max_num_samples is
# max_batches * (num_speculative_decoding_tokens + 1).
# This line can be optimized by replacing max_num_batched_tokens
# with max_batches * (num_speculative_decoding_tokens + 1).
self.prompt_mapping_meta = LoRAKernelMeta.make(
self.max_loras,
max_num_batched_tokens,
device=device,
captured_lora_counts=captured_lora_counts,
)
def update_metadata(
self,
mapping: LoRAMapping,
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
**kwargs,
):
self.is_prefill = mapping.is_prefill
self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
# Prepare cuda kernel metadata tensors
self.token_mapping_meta.prepare_tensors(self.token_lora_indices)
self.prompt_mapping_meta.prepare_tensors(self.sampler_indices)
def add_shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
):
"""
Performs GEMM for multiple slices of lora_a.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (torch.Tensor): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
x = x.view(-1, x.shape[-1])
lora_shrink(
x,
lora_a_stacked,
y,
*self.token_mapping_meta.meta_args(
x.size(0), self.lora_config.specialize_active_lora
),
scale,
)
def add_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> None:
"""
Performs GEMM for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
assert x.ndim == 3
assert x.size(0) == len(output_slices)
num_tokens = x.size(1) # first dimension is the num slices
lora_expand(
x,
lora_b_stacked,
y,
*self.token_mapping_meta.meta_args(
num_tokens, self.lora_config.specialize_active_lora
),
offset_start=offset_start,
add_inputs=True,
)
y = y.view_as(y_org)
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
lora_expand(
x.unsqueeze(dim=0),
(lora_b_stacked,),
y,
*self.token_mapping_meta.meta_args(
x.size(0), self.lora_config.specialize_active_lora
),
offset_start=0,
add_inputs=add_inputs,
)
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> None:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[torch.Tensor]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
assert buffer is None, (
"To minimize overhead, the buffer should be created by "
".add_lora_linear() instead of being passed in."
)
r = lora_b_stacked[0].size(-1)
# We set the buffer to be float32 by default, refer to:
# https://github.com/triton-lang/triton/issues/1387
# Note: buffer is zeroed inside the shrink op
buffer = torch.empty(
(len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device
)
self.add_shrink(
buffer, # type: ignore
x,
lora_a_stacked,
scale,
**kwargs,
)
self.add_expand(
y,
buffer, # type: ignore
lora_b_stacked,
output_slices,
add_inputs=True,
**kwargs,
)
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor): lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]): Default to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
assert buffer is None, (
"To minimize overhead, the buffer should be created by "
".add_lora_logits() instead of being passed in."
)
# We set the buffer to be float32 by default, refer to:
# https://github.com/triton-lang/triton/issues/1387
# Note: buffer is zeroed inside the shrink op
buffer = torch.empty((x.size(0), r), dtype=torch.float32, device=x.device)
lora_shrink(
x,
[lora_a_stacked],
buffer.unsqueeze(dim=0),
*self.prompt_mapping_meta.meta_args(
x.size(0), self.lora_config.specialize_active_lora
),
scale,
)
lora_expand(
buffer.unsqueeze(dim=0),
[lora_b_stacked],
y,
*self.prompt_mapping_meta.meta_args(
buffer.size(0), self.lora_config.specialize_active_lora
),
add_inputs=True,
)
y = y.view_as(y_org)
def moe_lora_align_block_size(
self,
topk_ids: torch.Tensor,
num_tokens: int,
block_size: int,
num_experts: int,
max_loras: int,
adapter_enabled: torch.Tensor,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
naive_block_assignment: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
(token_lora_mapping, _, _, _, lora_ids, _, _) = (
self.token_mapping_meta.meta_args(
num_tokens, self.lora_config.specialize_active_lora
)
)
if naive_block_assignment:
expert_ids = topk_ids.reshape(-1)
sorted_ids = None
num_tokens_post_pad = None
else:
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty(
(max_loras * max_num_tokens_padded,),
dtype=torch.int32,
device=topk_ids.device,
)
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must default to -1 to prevent a blank block
expert_ids = torch.empty(
(max_loras * max_num_m_blocks,),
dtype=torch.int32,
device=topk_ids.device,
)
num_tokens_post_pad = torch.empty(
(max_loras), dtype=torch.int32, device=topk_ids.device
)
ops.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_experts,
block_size,
max_loras,
max_num_tokens_padded,
max_num_m_blocks,
sorted_ids,
expert_ids,
num_tokens_post_pad,
adapter_enabled,
lora_ids,
)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return None, sorted_ids, expert_ids, num_tokens_post_pad
def add_lora_fused_moe(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor | None,
max_lora_rank: int,
top_k_num: int,
shrink_config,
expand_config,
adapter_enabled: torch.Tensor,
mul_routed_weight=False,
fully_sharded: bool = False,
offset: int = 0,
token_lora_mapping: torch.Tensor | None = None,
):
"""
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
"""
(
token_lora_mapping_meta,
_,
_,
_,
lora_ids,
_,
num_active_loras,
) = self.token_mapping_meta.meta_args(
x.size(0), self.lora_config.specialize_active_lora
)
if token_lora_mapping is None:
token_lora_mapping = token_lora_mapping_meta
fused_moe_lora(
y,
x,
lora_a_stacked,
lora_b_stacked,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
token_lora_mapping,
max_lora_rank,
top_k_num,
lora_ids,
num_active_loras,
adapter_enabled,
shrink_config.get("BLOCK_SIZE_M", 64),
shrink_config.get("BLOCK_SIZE_N", 64),
shrink_config.get("BLOCK_SIZE_K", 32),
shrink_config.get("GROUP_SIZE_M", 8),
shrink_config.get("NUM_WARPS", 4),
shrink_config.get("NUM_STAGES", 3),
shrink_config.get("SPLIT_K", 1),
expand_config.get("BLOCK_SIZE_M", 64),
expand_config.get("BLOCK_SIZE_N", 64),
expand_config.get("BLOCK_SIZE_K", 32),
expand_config.get("GROUP_SIZE_M", 8),
expand_config.get("NUM_WARPS", 4),
expand_config.get("NUM_STAGES", 3),
expand_config.get("SPLIT_K", 1),
mul_routed_weight,
fully_sharded,
offset,
)
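For clarity, the shrink_config and expand_config arguments consumed by add_lora_fused_moe above are plain dictionaries read via .get(...); the sketch below simply spells out the fallback values used when a key is absent, so the numbers are just the defaults from the code above.
# Illustrative configs for add_lora_fused_moe; these mirror the .get() defaults above.
shrink_config = {
    "BLOCK_SIZE_M": 64,
    "BLOCK_SIZE_N": 64,
    "BLOCK_SIZE_K": 32,
    "GROUP_SIZE_M": 8,
    "NUM_WARPS": 4,
    "NUM_STAGES": 3,
    "SPLIT_K": 1,
}
expand_config = dict(shrink_config)  # the expand pass reads the same keys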

View File

@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
from .punica_base import PunicaWrapperBase
logger = init_logger(__name__)
def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
punica_wrapper_qualname = current_platform.get_punica_wrapper()
punica_wrapper_cls = resolve_obj_by_qualname(punica_wrapper_qualname)
punica_wrapper = punica_wrapper_cls(*args, **kwargs)
assert punica_wrapper is not None, (
f"Invalid punica wrapper qualname: {punica_wrapper_qualname}"
)
logger.info_once("Using %s.", punica_wrapper_qualname.rsplit(".", 1)[1])
return punica_wrapper
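As an illustration, the qualname returned by the platform is a dotted path to one of the wrapper classes above, presumably something like "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU" on CUDA; resolve_obj_by_qualname is then roughly equivalent to the sketch below.
import importlib

# Hypothetical qualname; the exact string comes from current_platform.get_punica_wrapper().
qualname = "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"
module_name, class_name = qualname.rsplit(".", 1)
wrapper_cls = getattr(importlib.import_module(module_name), class_name)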

View File

@@ -0,0 +1,421 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Based on:
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
Punica: Multi-Tenant LoRA Serving.
https://arxiv.org/abs/2310.18547
"""
from typing import final
import torch
from vllm import _custom_ops as ops
from vllm.lora.layers import LoRAMapping
from vllm.lora.ops.xpu_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.math_utils import round_up
if HAS_TRITON:
from vllm.lora.ops.triton_ops import (
LoRAKernelMeta,
fused_moe_lora,
)
from .punica_base import PunicaWrapperBase
@final
class PunicaWrapperXPU(PunicaWrapperBase):
"""
PunicaWrapperXPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the punica ipex kernel.
"""
def __init__(
self,
max_num_batched_tokens: int,
max_batches: int,
device: torch.device | str,
**kwargs,
):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
self.lora_config = kwargs["lora_config"]
self.max_loras = self.lora_config.max_loras
self.token_mapping_meta = LoRAKernelMeta.make(
self.max_loras, max_num_batched_tokens, device=device
)
def update_metadata(
self,
mapping: LoRAMapping,
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
**kwargs,
):
self.is_prefill = mapping.is_prefill
self._update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size)
def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
def _apply_shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
bgmv_shrink(x, w_t_all, y, self._get_token_lora_indices(x), scale)
def _apply_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
):
token_lora_indices = self._get_token_lora_indices(x)
bgmv_expand_slice(
x, w_t_all, y, token_lora_indices, y_offset, y_slice_size, add_inputs
)
def add_shrink(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
):
"""
Performs GEMM for multiple slices of lora_a.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (torch.Tensor): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
x = x.view(-1, x.shape[-1])
for slice_idx in range(len(lora_a_stacked)):
self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], scale)
def add_expand(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> None:
"""
Performs GEMM for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
assert x.ndim == 3
assert x.size(0) == len(output_slices)
# TODO fuse these kernels
for slice_idx in range(len(lora_b_stacked)):
self._apply_expand(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_start,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_start += output_slices[slice_idx]
y = y.view_as(y_org)
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> None:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
token_lora_indices = self._get_token_lora_indices(x)
bgmv_expand(x, lora_b_stacked, y, token_lora_indices, add_inputs)
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> None:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)
Args:
y (torch.Tensor): Output tensor. Will be changed in-place.
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[torch.Tensor]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if buffer is None:
r = lora_b_stacked[0].size(-1)
buffer = torch.zeros( # type: ignore
(len(output_slices), x.size(0), r),
dtype=x.dtype,
device=x.device,
)
self.add_shrink(
buffer, # type: ignore
x,
lora_a_stacked,
scale,
**kwargs,
)
self.add_expand(
y,
buffer, # type: ignore
lora_b_stacked,
output_slices,
add_inputs=True,
**kwargs,
)
@property
def sampler_indices_padded(self) -> torch.Tensor:
"""
This property provides access to padded sampler indices.
"""
return self._sampler_indices_padded[:]
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> None:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor): lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]): Default to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
r = lora_b_stacked.size(-1)
if buffer is None:
buffer = torch.zeros((x.size(0), r), dtype=x.dtype, device=x.device)
sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
bgmv_shrink(x, lora_a_stacked, buffer, sampler_indices, scale)
bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
return y.view_as(y_org)
def moe_lora_align_block_size(
self,
topk_ids: torch.Tensor,
num_tokens: int,
block_size: int,
num_experts: int,
max_loras: int,
adapter_enabled: torch.Tensor,
expert_map: torch.Tensor | None = None,
pad_sorted_ids: bool = False,
naive_block_assignment: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns tokens and experts into block-sized chunks for LoRA-based
mixture-of-experts (MoE) execution.
"""
(token_lora_mapping, _, _, _, lora_ids, _, _) = (
self.token_mapping_meta.meta_args(
num_tokens, self.lora_config.specialize_active_lora
)
)
if naive_block_assignment:
expert_ids = topk_ids.reshape(-1)
sorted_ids = None
num_tokens_post_pad = None
else:
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
if pad_sorted_ids:
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
sorted_ids = torch.empty(
(max_loras * max_num_tokens_padded,),
dtype=torch.int32,
device=topk_ids.device,
)
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must default to -1 to prevent a blank block
expert_ids = torch.empty(
(max_loras * max_num_m_blocks,),
dtype=torch.int32,
device=topk_ids.device,
)
num_tokens_post_pad = torch.empty(
(max_loras), dtype=torch.int32, device=topk_ids.device
)
ops.moe_lora_align_block_size(
topk_ids,
token_lora_mapping,
num_experts,
block_size,
max_loras,
max_num_tokens_padded,
max_num_m_blocks,
sorted_ids,
expert_ids,
num_tokens_post_pad,
adapter_enabled,
lora_ids,
)
if expert_map is not None:
expert_ids = expert_map[expert_ids]
return None, sorted_ids, expert_ids, num_tokens_post_pad
def add_lora_fused_moe(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor | None,
max_lora_rank: int,
top_k_num: int,
shrink_config,
expand_config,
adapter_enabled: torch.Tensor,
mul_routed_weight=False,
fully_sharded: bool = False,
offset: int = 0,
token_lora_mapping: torch.Tensor | None = None,
):
"""
Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
"""
(
token_lora_mapping_meta,
_,
_,
_,
lora_ids,
_,
num_active_loras,
) = self.token_mapping_meta.meta_args(
x.size(0), self.lora_config.specialize_active_lora
)
if token_lora_mapping is None:
token_lora_mapping = token_lora_mapping_meta
fused_moe_lora(
y,
x,
lora_a_stacked,
lora_b_stacked,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
token_lora_mapping,
max_lora_rank,
top_k_num,
lora_ids,
num_active_loras,
adapter_enabled,
shrink_config.get("BLOCK_SIZE_M", 64),
shrink_config.get("BLOCK_SIZE_N", 64),
shrink_config.get("BLOCK_SIZE_K", 32),
shrink_config.get("GROUP_SIZE_M", 8),
shrink_config.get("NUM_WARPS", 4),
shrink_config.get("NUM_STAGES", 3),
shrink_config.get("SPLIT_K", 1),
expand_config.get("BLOCK_SIZE_M", 64),
expand_config.get("BLOCK_SIZE_N", 64),
expand_config.get("BLOCK_SIZE_K", 32),
expand_config.get("GROUP_SIZE_M", 8),
expand_config.get("NUM_WARPS", 4),
expand_config.get("NUM_STAGES", 3),
expand_config.get("SPLIT_K", 1),
mul_routed_weight,
fully_sharded,
offset,
)

View File

@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
import torch
if TYPE_CHECKING:
# avoid circular import
from vllm.lora.layers import LoRAMapping
def compute_meta(
token_lora_tensor: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
"""
Get the information required for the sgmv kernel. Two properties:
1. If consecutive requests in the batch use the same LoRA, this function
combines them into a single request, improving sgmv kernel inference
performance.
2. The metadata is recomputed from the input once at the beginning of
each prefill stage.
"""
lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
token_lora_tensor, return_counts=True
)
cum_result = torch.cumsum(seq_length_tensor, dim=0)
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
b_seq_start_tensor[1:].copy_(cum_result[:-1])
max_length = seq_length_tensor.max().item()
token_nums = seq_length_tensor.sum().item()
batch_size = lora_indices_tensor.size(0)
no_lora = False
# -1 means no lora should be applied. Use `no_lora` to determine whether
# the current step requires LoRA. If LoRA is not needed, the prefill stage
# does not need to launch the triton kernel, which can improve performance
if batch_size == 1 and lora_indices_tensor == -1:
no_lora = True
return (
b_seq_start_tensor,
seq_length_tensor,
lora_indices_tensor,
batch_size,
max_length,
token_nums,
no_lora,
)
# TODO see if this can be vectorized
def convert_mapping(
mapping: "LoRAMapping",
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
extra_vocab_size: int,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[int]]:
"""Converts LoRAMapping to index tensors.
Args:
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
lora_index_to_id: List mapping LoRA ids to LoRA indices.
max_loras: Maximum number of LoRAs.
vocab_size: Model vocab size.
extra_vocab_size: Extra vocab size each LoRA can have.
Returns:
A tuple of tensors:
base_indices: Tensor of shape [batch_size] mapping batch rows to
LoRA indices.
sampler_indices: Tensor of shape [batch_size] mapping requests to
LoRA indices for sampler. For generation, this will be the
same as base_indices. For prefill, this will map requests
to LoRA indices.
sampler_indices_padded: Tensor of shape [batch_size] mapping
requests to LoRA indices for sampler with padding.
Same as sampler_indices, but -1 is replaced with
max_loras - 1.
embeddings_indices: Tensor of shape [2, batch_size] mapping
requests to embedding indices. First row is for embeddings
added by the LoRAs, second row is for the LoRA.lora_a
embeddings.
indices_len: List of lengths of the above tensors. It contains
(base_indices, sampler_indices, sampler_indices_padded,
embeddings_indices).
"""
index_mapping_indices: list[int] = list(mapping.index_mapping).copy()
embedding_indices = index_mapping_indices.copy()
lora_indices = index_mapping_indices.copy()
prompt_mapping: list[int] = [
lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping
]
lora_idx = None
for i in range(len(index_mapping_indices)):
# TODO index can be slow. optimize
lora_idx = (
lora_index_to_id.index(index_mapping_indices[i])
if index_mapping_indices[i] > 0
else -1
)
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
lora_indices[i] = lora_idx
indices_list: list[list[int] | torch.Tensor] = [
index_mapping_indices,
lora_indices,
embedding_indices,
]
indices = torch.tensor(indices_list, dtype=torch.long, device=device)
prompt_mapping_tensor = torch.tensor(
prompt_mapping, dtype=torch.long, device=device
)
embeddings_indices = torch.stack(
[
indices[2] * extra_vocab_size,
indices[2] * (vocab_size + extra_vocab_size),
]
)
embeddings_indices = torch.where(
embeddings_indices == -1, max_loras - 1, embeddings_indices
)
base_indices = indices[1]
sampler_indices = prompt_mapping_tensor
sampler_indices_padded = sampler_indices.clone()
sampler_indices_padded = torch.where(
sampler_indices_padded == -1, max_loras - 1, sampler_indices_padded
)
sampler_indices_padded = torch.arange(
0, len(sampler_indices_padded), device=device, dtype=torch.long
) + (sampler_indices_padded * len(sampler_indices_padded))
# Contain length of indices tensors. Used to index into each tensor.
indices_len = [
base_indices.shape[-1],
sampler_indices.shape[-1],
sampler_indices_padded.shape[-1],
embeddings_indices.shape[-1],
]
return (
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
indices_len,
)
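As a concrete illustration of compute_meta above (assuming the module path vllm.lora.punica_wrapper.utils): consecutive tokens that share a LoRA index are collapsed into one "request", and -1 marks tokens with no LoRA applied.
import torch
from vllm.lora.punica_wrapper.utils import compute_meta

token_lora = torch.tensor([0, 0, 0, 1, 1, -1])
(starts, lengths, lora_ids, batch_size, max_len, token_nums, no_lora) = compute_meta(token_lora)
# starts   -> tensor([0, 3, 5])    start position of each run of identical ids
# lengths  -> tensor([3, 2, 1])    length of each run
# lora_ids -> tensor([ 0, 1, -1])  -1 means no LoRA for that run
# batch_size == 3, max_len == 3, token_nums == 6, no_lora is False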

66
vllm/lora/request.py Normal file
View File

@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import msgspec
class LoRARequest(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True,
): # type: ignore[call-arg]
"""
Request for a LoRA adapter.
lora_int_id must be globally unique for a given adapter.
This is currently not enforced in vLLM.
load_inplace: If True, forces reloading the adapter even if one
with the same lora_int_id already exists in the cache. This replaces
the existing adapter in-place. If False (default), only loads if the
adapter is not already loaded.
"""
lora_name: str
lora_int_id: int
lora_path: str = ""
base_model_name: str | None = msgspec.field(default=None)
tensorizer_config_dict: dict | None = None
load_inplace: bool = False
def __post_init__(self):
if self.lora_int_id < 1:
raise ValueError(f"id must be > 0, got {self.lora_int_id}")
# Ensure lora_path is not empty
assert self.lora_path, "lora_path cannot be empty"
@property
def adapter_id(self):
return self.lora_int_id
@property
def name(self):
return self.lora_name
@property
def path(self):
return self.lora_path
def __eq__(self, value: object) -> bool:
"""
Overrides the equality method to compare LoRARequest
instances based on lora_name. This allows for identification
and comparison of LoRA adapters across engines.
"""
return isinstance(value, self.__class__) and self.lora_name == value.lora_name
def __hash__(self) -> int:
"""
Overrides the hash method to hash LoRARequest instances
based on lora_name. This ensures that LoRARequest instances
can be used in hash-based collections such as sets and dictionaries,
identified by their names across engines.
"""
return hash(self.lora_name)
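A small construction sketch for LoRARequest; the adapter name, id, and path are placeholders. Note that lora_int_id must be >= 1, lora_path must be non-empty, and equality and hashing are based on lora_name only.
from vllm.lora.request import LoRARequest

req = LoRARequest(lora_name="sql-adapter", lora_int_id=1, lora_path="/path/to/adapter")
assert req.adapter_id == 1
# Two requests with the same name compare equal even if their ids differ.
assert req == LoRARequest(lora_name="sql-adapter", lora_int_id=2, lora_path="/other/path")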

88
vllm/lora/resolver.py Normal file
View File

@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Set
from dataclasses import dataclass, field
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
logger = init_logger(__name__)
class LoRAResolver(ABC):
"""Base class for LoRA adapter resolvers.
This class defines the interface for resolving and fetching LoRA adapters.
Implementations of this class should handle the logic for locating and
downloading LoRA adapters from various sources (e.g. S3, cloud storage,
etc.).
"""
@abstractmethod
async def resolve_lora(
self, base_model_name: str, lora_name: str
) -> LoRARequest | None:
"""Abstract method to resolve and fetch a LoRA model adapter.
Implements logic to locate and download a LoRA adapter based on its name.
Implementations might fetch from blob storage or other sources.
Args:
base_model_name: The name/identifier of the base model to resolve.
lora_name: The name/identifier of the LoRA model to resolve.
Returns:
LoRARequest | None: The resolved LoRA model information, or None
if the LoRA model cannot be found.
"""
pass
@dataclass
class _LoRAResolverRegistry:
resolvers: dict[str, LoRAResolver] = field(default_factory=dict)
def get_supported_resolvers(self) -> Set[str]:
"""Get all registered resolver names."""
return self.resolvers.keys()
def register_resolver(
self,
resolver_name: str,
resolver: LoRAResolver,
) -> None:
"""Register a LoRA resolver.
Args:
resolver_name: Name to register the resolver under.
resolver: The LoRA resolver instance to register.
"""
if resolver_name in self.resolvers:
logger.warning(
"LoRA resolver %s is already registered, and will be "
"overwritten by the new resolver instance %s.",
resolver_name,
resolver,
)
self.resolvers[resolver_name] = resolver
def get_resolver(self, resolver_name: str) -> LoRAResolver:
"""Get a registered resolver instance by name.
Args:
resolver_name: Name of the resolver to get.
Returns:
The resolver instance.
Raises:
KeyError: If the resolver is not found in the registry.
"""
if resolver_name not in self.resolvers:
raise KeyError(
f"LoRA resolver '{resolver_name}' not found. "
f"Available resolvers: {list(self.resolvers.keys())}"
)
return self.resolvers[resolver_name]
LoRAResolverRegistry = _LoRAResolverRegistry()
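# Illustrative sketch (not part of the upstream module): a minimal resolver that
# looks for adapters under a local directory and registers itself. The directory
# layout, the "local-dir" name, and the id scheme are assumptions for the example.
if __name__ == "__main__":
    import os

    class LocalDirLoRAResolver(LoRAResolver):
        def __init__(self, root: str):
            self.root = root

        async def resolve_lora(
            self, base_model_name: str, lora_name: str
        ) -> LoRARequest | None:
            path = os.path.join(self.root, lora_name)
            if not os.path.isdir(path):
                return None
            # Derive a positive, (mostly) stable id from the adapter name.
            return LoRARequest(
                lora_name=lora_name,
                lora_int_id=(abs(hash(lora_name)) % 10_000) + 1,
                lora_path=path,
            )

    LoRAResolverRegistry.register_resolver("local-dir", LocalDirLoRAResolver("/adapters"))
    assert "local-dir" in LoRAResolverRegistry.get_supported_resolvers()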

311
vllm/lora/utils.py Normal file
View File

@@ -0,0 +1,311 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from typing import TYPE_CHECKING
import huggingface_hub
from huggingface_hub.utils import HfHubHTTPError, HFValidationError
from torch import nn
from transformers import PretrainedConfig
from vllm import envs
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
# being imported for _all_lora_classes below
from vllm.lora.layers import (
BaseLayerWithLoRA,
ColumnParallelLinearWithLoRA,
ColumnParallelLinearWithShardedLoRA,
FusedMoE3DWithLoRA,
FusedMoEWithLoRA,
LogitsProcessorWithLoRA,
MergedColumnParallelLinearVariableSliceWithLoRA,
MergedColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedQKVParallelLinearWithLoRA,
MergedQKVParallelLinearWithShardedLoRA,
QKVParallelLinearWithLoRA,
QKVParallelLinearWithShardedLoRA,
ReplicatedLinearWithLoRA,
RowParallelLinearWithLoRA,
RowParallelLinearWithShardedLoRA,
VocabParallelEmbeddingWithLoRA,
)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.utils import get_moe_expert_mapping, get_packed_modules_mapping
if TYPE_CHECKING:
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__)
def get_captured_lora_counts(max_loras: int, specialize: bool) -> list[int]:
"""
Returns num_active_loras values for cudagraph capture.
When specialize=True: powers of 2 up to max_loras, plus max_loras + 1.
When specialize=False: just [max_loras + 1].
This is the single source of truth for LoRA capture cases, used by both
CudagraphDispatcher and PunicaWrapperGPU.
"""
if not specialize:
return [max_loras + 1]
return [
n for n in range(1, max_loras + 2) if (n & (n - 1)) == 0 or n == max_loras + 1
]
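# Example (illustrative): with max_loras=8 the specialized capture list is the
# powers of two up to 8 plus the max_loras + 1 "overflow" bucket:
#     get_captured_lora_counts(8, specialize=True)  -> [1, 2, 4, 8, 9]
#     get_captured_lora_counts(8, specialize=False) -> [9]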
_GLOBAL_LORA_ID = 0
def get_lora_id():
global _GLOBAL_LORA_ID
_GLOBAL_LORA_ID += 1
return _GLOBAL_LORA_ID
_all_lora_classes: set[type[BaseLayerWithLoRA]] = {
VocabParallelEmbeddingWithLoRA,
ColumnParallelLinearWithLoRA,
MergedColumnParallelLinearWithLoRA,
QKVParallelLinearWithLoRA,
MergedQKVParallelLinearWithLoRA,
RowParallelLinearWithLoRA,
ReplicatedLinearWithLoRA,
LogitsProcessorWithLoRA,
ColumnParallelLinearWithShardedLoRA,
QKVParallelLinearWithShardedLoRA,
MergedColumnParallelLinearWithShardedLoRA,
MergedColumnParallelLinearVariableSliceWithLoRA,
MergedQKVParallelLinearWithShardedLoRA,
RowParallelLinearWithShardedLoRA,
FusedMoEWithLoRA,
FusedMoE3DWithLoRA,
}
def is_moe_model(model: nn.Module) -> bool:
"""Checks if the model contains FusedMoE layers and warns the user."""
if any(isinstance(module, FusedMoE) for module in model.modules()):
logger.info_once("MoE model detected. Using fused MoE LoRA implementation.")
return True
return False
def from_layer(
layer: nn.Module,
max_loras: int,
lora_config: LoRAConfig,
packed_modules_list: list,
model_config: PretrainedConfig | None = None,
) -> nn.Module:
for lora_cls in _all_lora_classes:
# specifying kwargs so they can be easily accessed in decorator
if lora_cls.can_replace_layer(
source_layer=layer,
lora_config=lora_config,
packed_modules_list=packed_modules_list,
model_config=model_config,
):
instance_layer = lora_cls(layer)
instance_layer.create_lora_weights(max_loras, lora_config, model_config)
return instance_layer
return layer
def from_layer_logits_processor(
layer: "LogitsProcessor",
lm_head: "ParallelLMHead",
max_loras: int,
lora_config: LoRAConfig,
model_config: PretrainedConfig | None = None,
) -> LogitsProcessorWithLoRA:
ret = LogitsProcessorWithLoRA(
layer,
lm_head.embedding_dim,
lm_head.weight.dtype,
lm_head.weight.device,
lm_head.get_sharded_to_full_mapping(),
)
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret
def replace_submodule(
model: nn.Module, module_name: str, new_module: nn.Module
) -> nn.Module:
"""Replace a submodule in a model with a new module."""
parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
target_name = module_name.split(".")[-1]
setattr(parent, target_name, new_module)
return new_module
def parse_fine_tuned_lora_name(
name: str, weights_mapper: "WeightsMapper | None" = None
) -> tuple[str, bool]:
"""Parse the name of lora weights.
args:
name: the name of the fine-tuned LoRA, e.g.
base_model.model.dense1.weight
weights_mapper: maps the name of weight, e.g.
`model.` -> `language_model.model.`,
return:
tuple(module_name, is_lora_a):
module_name: the name of the module, e.g. model.dense1,
is_lora_a whether the tensor is lora_a or lora_b.
"""
# LoRA weight qualified name usually starts with `base_model.model.`,
# so we remove the prefix `base_model.model.` so that the following
# mapping is applied correctly.
if name.startswith("base_model.model."):
name = name.replace("base_model.model.", "")
name = weights_mapper._map_name(name) if weights_mapper else name
# recover the prefix `base_model.model.`
name = "base_model.model." + name
else:
name = weights_mapper._map_name(name) if weights_mapper else name
# Some checkpoints (e.g., ibm-granite/granite-speech-3.3-8b) do not use
# the `base_model.model.` prefix; in that case the name is left as-is.
start_index = 2 if name.startswith("base_model.model.") else 0
parts = name.split(".")
if parts[-1] == "weight" and (parts[-2] == "lora_A" or parts[-2] == "lora_B"):
new_name = ".".join(parts[start_index:-2])
return new_name, parts[-2] == "lora_A"
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
new_name = ".".join(parts[start_index:-1])
return new_name, parts[-1] == "lora_embedding_A"
raise ValueError(f"{name} is unsupported LoRA weight")
def is_base_embeddding_weights(name: str) -> bool:
# hardcoded suffixes for input & output embedding weights
embedding_suffixes = (
".embed_tokens.base_layer.weight",
".lm_head.base_layer.weight",
)
return name.endswith(embedding_suffixes)
def get_supported_lora_modules(model: nn.Module) -> list[str]:
"""
In vLLM, all linear layers support LoRA.
"""
supported_lora_modules: set[str] = set()
for name, module in model.named_modules():
# get the embedding modules if the module's embedding_modules
# is not empty.
embedding_modules = getattr(module, "embedding_modules", None)
if embedding_modules is not None:
for embedding_module_name in embedding_modules:
supported_lora_modules.add(embedding_module_name)
# get all the linear suffixes.
if isinstance(module, (LinearBase,)):
supported_lora_modules.add(name.split(".")[-1])
if isinstance(module, (FusedMoE,)):
supported_lora_modules.add(name.split(".")[-1])
return list(supported_lora_modules)
def get_adapter_absolute_path(lora_path: str) -> str:
"""
Resolves the given lora_path to an absolute local path.
If the lora_path is identified as a Hugging Face model identifier,
it will download the model and return the local snapshot path.
Otherwise, it treats the lora_path as a local file path and
converts it to an absolute path.
Parameters:
lora_path (str): The path to the lora model, which can be an absolute path,
a relative path, or a Hugging Face model identifier.
Returns:
str: The resolved absolute local path to the lora model.
"""
# Check if the path is an absolute path. Return it whether or not it exists.
if os.path.isabs(lora_path):
return lora_path
# If the path starts with ~, expand the user home directory.
if lora_path.startswith("~"):
return os.path.expanduser(lora_path)
# Check if the expanded relative path exists locally.
if os.path.exists(lora_path):
return os.path.abspath(lora_path)
# If the path does not exist locally.
if envs.VLLM_USE_MODELSCOPE:
# If using ModelScope, we assume the path is a ModelScope repo.
from modelscope.hub.snapshot_download import InvalidParameter, snapshot_download
from requests import HTTPError
download_fn = lambda: snapshot_download(model_id=lora_path)
download_exceptions = (HTTPError, InvalidParameter)
error_log = "Error downloading the ModelScope model"
else:
# Otherwise, we assume the path is a Hugging Face Hub repo.
download_fn = lambda: huggingface_hub.snapshot_download(repo_id=lora_path)
download_exceptions = (HfHubHTTPError, HFValidationError)
error_log = "Error downloading the HuggingFace model"
try:
local_snapshot_path = download_fn()
except download_exceptions:
# Handle errors that may occur during the download.
# Return original path instead of throwing error here.
logger.exception(error_log)
return lora_path
return local_snapshot_path
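# Example (illustrative): the three local cases are resolved without touching
# the network; anything else is treated as a hub repo id:
#     get_adapter_absolute_path("/abs/path/adapter")  -> "/abs/path/adapter"
#     get_adapter_absolute_path("~/adapters/x")       -> expanded home path
#     get_adapter_absolute_path("./adapter")          -> absolute path if it exists
#     get_adapter_absolute_path("org/lora-repo")      -> local snapshot path after download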
def process_packed_modules_mapping(model: nn.Module) -> dict[str, list[str]]:
if is_moe_model(model):
if moe_packed_mapping := get_moe_expert_mapping(model):
# This method generates and returns a dictionary mapping packed module
# names to lists of their corresponding submodule names. It includes
# both static mappings and dynamic mappings for expert layers, where
# the expert indices are expanded based on the configured number
# of routed experts.
packed_modules_mapping = get_packed_modules_mapping(model)
if not model.is_3d_moe_weight:
# 3D MoE LoRA does not need `packed_modules_mapping`
# Filter out malformed entries: non-gated MoE has empty
# ckpt_up_proj_name which results in weight_name containing ".."
# (e.g., "experts.0.." instead of "experts.0.layer_name.")
packed_modules_mapping["experts"] = [
weight_name.rstrip(".")
for _, weight_name, _, _ in moe_packed_mapping
if ".." not in weight_name
]
return packed_modules_mapping
else:
raise AttributeError(
"To support LoRA for MoE model, "
"'get_expert_mapping' must be implemented"
)
else:
return get_packed_modules_mapping(model)

289
vllm/lora/worker_manager.py Normal file
View File

@@ -0,0 +1,289 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import Any, Literal
import torch
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.lora_model import LoRAModel
from vllm.lora.model_manager import (
LoRAModelManager,
LRUCacheLoRAModelManager,
create_lora_manager,
)
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
logger = init_logger(__name__)
class WorkerLoRAManager:
"""WorkerLoRAManager that manages LoRA models on the worker side.
On every request, the requested LoRAs will be loaded (unless they are already
loaded), and every other LoRA will be unloaded."""
_manager_cls: type[LoRAModelManager] = LoRAModelManager
def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
embedding_modules: dict[str, str],
lora_model_cls: type[LoRAModel] = LoRAModel,
):
self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules
self._cached_dummy_lora: None | Literal[False] | LoRAModel = False
self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
self.max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens
)
self.vocab_size = vllm_config.model_config.get_vocab_size()
self.lora_config = vllm_config.lora_config
# Use get_text_config() in case of multimodal models
text_config = vllm_config.model_config.hf_config.get_text_config()
self.max_position_embeddings = text_config.max_position_embeddings
self.device = device
# Lazily initialized by create_lora_manager.
self._adapter_manager: LoRAModelManager
@contextmanager
def dummy_lora_cache(self):
"""Use this context manager to reuse the dummy lora model
to avoid creating it repeatedly."""
self._cached_dummy_lora = None
yield
self._cached_dummy_lora = False
@property
def is_enabled(self) -> bool:
return True
def create_lora_manager(
self,
model: torch.nn.Module,
vllm_config: VllmConfig | None = None,
) -> Any:
lora_manager = create_lora_manager(
model,
max_num_seqs=self.max_num_seqs,
max_num_batched_tokens=self.max_num_batched_tokens,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
device=self.device,
lora_manager_cls=self._manager_cls,
vllm_config=vllm_config,
)
self._adapter_manager = lora_manager
return lora_manager.model
def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
try:
supported_lora_modules = self._adapter_manager.supported_lora_modules
packed_modules_mapping = self._adapter_manager.packed_modules_mapping
expected_lora_lst: list[str] = []
for module in supported_lora_modules:
if module in packed_modules_mapping:
expected_lora_lst.extend(packed_modules_mapping[module])
else:
expected_lora_lst.append(module)
if module == "experts":
expected_lora_lst.append(module)
expected_lora_modules = set(expected_lora_lst)
lora_path = get_adapter_absolute_path(lora_request.lora_path)
peft_helper = PEFTHelper.from_local_dir(
lora_path,
self.max_position_embeddings,
lora_request.tensorizer_config_dict,
)
# Validates the LoRA configuration against requirements before
# loading weights, throwing an exception if validation fails.
peft_helper.validate_legal(self.lora_config)
# For some models like Qwen2VL, we need to use hf_to_vllm_mapper
# to ensure correct loading of lora weights.
model = self._adapter_manager.model
hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)
# Get model-defined prefixes to skip during LoRA loading.
lora_skip_prefixes = getattr(model, "lora_skip_prefixes", None)
lora = self._lora_model_cls.from_local_checkpoint(
lora_path,
expected_lora_modules,
peft_helper=peft_helper,
lora_model_id=lora_request.lora_int_id,
device="cpu",
dtype=self.lora_config.lora_dtype,
model_vocab_size=self.vocab_size,
tensorizer_config_dict=lora_request.tensorizer_config_dict,
weights_mapper=hf_to_vllm_mapper,
skip_prefixes=lora_skip_prefixes,
)
except FileNotFoundError as e:
# FileNotFoundError is raised if both:
# - No adapter found to download from huggingface (or in
# offline mode)
# - No local adapter files found at `lora_request.lora_path`
# For NotFoundError
raise ValueError(
f"Loading lora {lora_request.lora_name} failed: No adapter "
f"found for {lora_request.lora_path}"
) from e
except Exception as e:
# For BadRequestError
raise e
return lora
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
if lora_request.lora_int_id in self.list_adapters():
return False
if isinstance(self._cached_dummy_lora, LoRAModel):
dummy_lora = self._cached_dummy_lora.clone(lora_request.lora_int_id)
else:
dummy_lora = self._adapter_manager.create_dummy_lora(
lora_request.lora_int_id, rank, self.embedding_modules
)
if self._cached_dummy_lora is None:
self._cached_dummy_lora = dummy_lora
return self._adapter_manager.add_adapter(dummy_lora)
def pin_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.pin_adapter(adapter_id)
def set_active_adapters(self, requests: set[Any], mapping: Any | None) -> None:
self._apply_adapters(requests)
if mapping is not None:
self._adapter_manager.set_adapter_mapping(mapping)
def supports_tower_connector_lora(self) -> bool:
return (
self._adapter_manager.supports_mm
and self._adapter_manager.supports_tower_connector_lora
)
def _apply_adapters(self, adapter_requests: set[Any]) -> None:
existing_adapters = self.list_adapters()
models_map = {
adapter_request.adapter_id: adapter_request
for adapter_request in adapter_requests
if adapter_request
}
if len(models_map) > self._adapter_manager.adapter_slots:
raise RuntimeError(
f"Number of requested models ({len(models_map)}) is greater "
"than the number of GPU model slots "
f"({self._adapter_manager.adapter_slots})."
)
requested_ids = set(models_map)
for adapter_id in existing_adapters - requested_ids:
self.remove_adapter(adapter_id)
for adapter_id in requested_ids - existing_adapters:
self.add_adapter(models_map[adapter_id])
def add_adapter(self, adapter_request: Any) -> bool:
if adapter_request.adapter_id in self.list_adapters():
return False
loaded_adapter = self._load_adapter(adapter_request)
loaded = self._adapter_manager.add_adapter(loaded_adapter)
self._adapter_manager.activate_adapter(loaded_adapter.id)
return loaded
def remove_adapter(self, adapter_id: int) -> bool:
return self._adapter_manager.remove_adapter(adapter_id)
def remove_all_adapters(self):
self._adapter_manager.remove_all_adapters()
def list_adapters(self) -> set[int]:
return set(self._adapter_manager.list_adapters())
class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
"""WorkerLoRAManager that manages LoRA models on the worker side.
Uses an LRU cache. On every request, the requested LoRAs will be loaded
(unless they are already loaded) and the least recently used LoRAs will
be unloaded if the cache is above capacity."""
_manager_cls: type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
def create_lora_manager(
self,
model: torch.nn.Module,
vllm_config: VllmConfig | None = None,
) -> Any:
lora_manager = create_lora_manager(
model,
lora_manager_cls=self._manager_cls,
max_num_seqs=self.max_num_seqs,
vocab_size=self.vocab_size,
lora_config=self.lora_config,
device=self.device,
max_num_batched_tokens=self.max_num_batched_tokens,
vllm_config=vllm_config,
)
self._adapter_manager = lora_manager
return lora_manager.model
def _apply_adapters(self, lora_requests: set[LoRARequest]) -> None:
loras_map = {
lora_request.lora_int_id: lora_request
for lora_request in lora_requests
if lora_request
}
if len(loras_map) > self._adapter_manager.lora_slots:
raise RuntimeError(
f"Number of requested LoRAs ({len(loras_map)}) is greater "
"than the number of GPU LoRA slots "
f"({self._adapter_manager.lora_slots})."
)
for lora in loras_map.values():
self.add_adapter(lora)
def add_adapter(self, lora_request: LoRARequest) -> bool:
# Note that this method is not thread-safe. It may be invoked multiple
# times for the same adapter when using multiple API servers.
# This is ok because it's currently only called from
# the single-threaded core engine loop.
if (
lora_request.lora_int_id not in self.list_adapters()
or lora_request.load_inplace
):
# Load the new adapter first to ensure it is actually valid, before
# evicting any existing adapters.
# This may cause the # of loaded lora adapters to very temporarily
# exceed `--max-cpu-loras`.
lora = self._load_adapter(lora_request)
# Remove the existing adapter if it exists
# Use case for LoRA inplace
self._adapter_manager.remove_adapter(lora.id)
# Loading succeeded; now check whether we will exceed the cache
# capacity and evict the oldest adapter if so.
if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
assert isinstance(self._adapter_manager, LRUCacheLoRAModelManager)
self._adapter_manager.remove_oldest_adapter()
# Then add the new adapter to the cache
loaded = self._adapter_manager.add_adapter(lora)
else:
# If the lora is already loaded, just touch it to
# update its position in the caches
loaded = (
self._adapter_manager.get_adapter(lora_request.lora_int_id) is not None
)
self._adapter_manager.activate_adapter(lora_request.lora_int_id)
return loaded
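# Illustrative trace (not part of the upstream module), assuming a cache
# capacity of 2 and made-up request names:
#     add_adapter(req_A)  -> loads A, cache = [A]
#     add_adapter(req_B)  -> loads B, cache = [A, B]
#     add_adapter(req_A)  -> already resident, touched, LRU order = [B, A]
#     add_adapter(req_C)  -> loads C first, then evicts B (LRU), cache = [A, C]
# The new adapter is always loaded before any eviction, so a failed load never
# costs an already-cached adapter.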