forked from EngineX-Cambricon/enginex-mlu370-vllm
add qwen3
This commit is contained in:
0
vllm-v0.6.2/vllm/lora/__init__.py
Normal file
0
vllm-v0.6.2/vllm/lora/__init__.py
Normal file
BIN
vllm-v0.6.2/vllm/lora/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm-v0.6.2/vllm/lora/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm-v0.6.2/vllm/lora/__pycache__/layers.cpython-310.pyc
Normal file
BIN
vllm-v0.6.2/vllm/lora/__pycache__/layers.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm-v0.6.2/vllm/lora/__pycache__/lora.cpython-310.pyc
Normal file
BIN
vllm-v0.6.2/vllm/lora/__pycache__/lora.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm-v0.6.2/vllm/lora/__pycache__/models.cpython-310.pyc
Normal file
BIN
vllm-v0.6.2/vllm/lora/__pycache__/models.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm-v0.6.2/vllm/lora/__pycache__/punica.cpython-310.pyc
Normal file
BIN
vllm-v0.6.2/vllm/lora/__pycache__/punica.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm-v0.6.2/vllm/lora/__pycache__/request.cpython-310.pyc
Normal file
BIN
vllm-v0.6.2/vllm/lora/__pycache__/request.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm-v0.6.2/vllm/lora/__pycache__/utils.cpython-310.pyc
Normal file
BIN
vllm-v0.6.2/vllm/lora/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm-v0.6.2/vllm/lora/__pycache__/worker_manager.cpython-310.pyc
Normal file
BIN
vllm-v0.6.2/vllm/lora/__pycache__/worker_manager.cpython-310.pyc
Normal file
Binary file not shown.
377
vllm-v0.6.2/vllm/lora/fully_sharded_layers.py
Normal file
377
vllm-v0.6.2/vllm/lora/fully_sharded_layers.py
Normal file
@@ -0,0 +1,377 @@
|
||||
# pylint: disable=unused-argument
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.distributed.communication_op import (
|
||||
tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
||||
MergedColumnParallelLinearWithLoRA,
|
||||
MergedQKVParallelLinearWithLora,
|
||||
QKVParallelLinearWithLora,
|
||||
RowParallelLinearWithLoRA)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
def _fully_sharded_can_replace(can_replace):
|
||||
"""
|
||||
decorator which adds the condition of fully sharded loras
|
||||
intended to wrap can_replace_layer()
|
||||
"""
|
||||
|
||||
def dec(*args, **kwargs):
|
||||
return (can_replace(*args, **kwargs)
|
||||
and kwargs["lora_config"].fully_sharded_loras)
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
# these layers are based on the tensor parallelism strategy given in
|
||||
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
|
||||
# https://arxiv.org/abs/2311.03285.
|
||||
|
||||
|
||||
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return the columns of ``lora_a`` owned by this tensor-parallel rank.

        The shard width is dim 2 of ``self.lora_a_stacked`` and the slice
        starts at ``tp_rank * shard_size``.
        """
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked.shape[2]
        start_idx = tp_rank * shard_size
        lora_a = lora_a[:, start_idx:start_idx + shard_size]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base layer and add the sharded LoRA contribution.

        Args:
            x: Input activations; flattened to 2-D (keeping the last dim)
                before the LoRA kernels are called.
            bias: Bias forwarded unchanged to the base layer's quant method.

        Returns:
            The base layer output with the LoRA update (and the per-token
            LoRA bias when ``self.bias_stacked`` is set) added in place,
            viewed back to the base output's original shape.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        buffer = torch.zeros(
            (x.shape[0], self.lora_a_stacked.shape[2]),
            dtype=torch.float32,
            device=x.device,
        )
        self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
        buffer = tensor_model_parallel_all_gather(buffer)
        self.punica_wrapper.add_expand(output,
                                       buffer,
                                       self.lora_b_stacked,
                                       add_input=True)
        # now have column partitioned output

        if self.bias_stacked is not None:
            # Work on a local: assigning the gathered per-token rows back to
            # ``self.bias_stacked`` would overwrite the stacked parameter
            # and break every subsequent forward pass.
            lora_bias = self.bias_stacked.view(-1,
                                               self.bias_stacked.shape[-1])
            token_indices = self.punica_wrapper.token_lora_indices
            # Indexing with a tensor yields a copy, so the in-place zeroing
            # below does not touch ``self.bias_stacked``.
            lora_bias = lora_bias[token_indices]
            # Index -1 would otherwise pick up the last row; zero those
            # rows, as the other sharded layers in this file do.
            lora_bias[token_indices == -1] = 0
            output += lora_bias

        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Defer to the parent check, gated on ``fully_sharded_loras``.

        The ``_fully_sharded_can_replace`` decorator reads ``lora_config``
        from the keyword arguments, so callers must pass it by keyword.
        """
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
|
||||
|
||||
|
||||
def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
    """
    Shared LoRA application for MergedColumnParallelLinearWithShardedLoRA
    and MergedQKVParallelLinearWithShardedLora.

    The two callers differ only in the per-slice width of lora_b: it can
    vary between slices for MergedQKVParallelLinearWithShardedLora but is
    constant for MergedColumnParallelLinearWithShardedLoRA.
    """
    # expecting 2 for column parallel and 3 for qkv
    num_slices = len(layer.lora_a_stacked)
    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    out_orig_shape = output.shape
    x = x.view(-1, x.shape[-1])
    output = output.view(-1, output.shape[-1])

    # One shrink buffer per packed slice, filled before the gather.
    buffers = torch.zeros(
        (num_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
        dtype=torch.float32,
        device=x.device,
    )
    for slice_idx in range(num_slices):
        layer.punica_wrapper.add_shrink(buffers[slice_idx], x,
                                        layer.lora_a_stacked[slice_idx], 1.0)

    buffers = tensor_model_parallel_all_gather(buffers)

    # Walk the packed output left to right, one lora_b slice at a time.
    left_offset = 0
    for slice_idx in range(num_slices):
        lora_b = layer.lora_b_stacked[slice_idx]
        shard_size = lora_b.shape[2]
        right_offset = left_offset + shard_size

        if layer.bias_stacked is not None:
            slice_bias = layer.bias_stacked[slice_idx]
            if slice_bias is not None:
                token_indices = layer.punica_wrapper.token_lora_indices
                slice_bias = slice_bias.view(-1, slice_bias.shape[-1])
                slice_bias = slice_bias[token_indices]
                # rows for index -1 contribute nothing
                slice_bias[token_indices == -1] = 0
                output[:, left_offset:right_offset] += slice_bias

        layer.punica_wrapper.add_expand_slice(
            output,
            buffers[slice_idx],
            lora_b,
            left_offset,
            shard_size,
            add_input=True,
        )
        left_offset = right_offset

    # now have column partitioned and packed output
    return output.view(*out_orig_shape)
|
||||
|
||||
|
||||
class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Variant of MergedColumnParallelLinearWithLoRA in which the LoRA A
    matrices are sliced as well.

    Following S-LoRA, the slice is taken along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Slice both sub-LoRA A matrices for this rank; ``None`` stays."""
        # NOTE: lora_a contains 2 subloras, and each sublora could be None.
        width = self.lora_a_stacked[0].shape[2]
        begin = self.tp_rank * width
        end = begin + width
        sliced: List[Union[torch.Tensor, None]] = []
        for sub_idx in range(2):
            sub_lora = lora_a[sub_idx]
            if sub_lora is None:
                sliced.append(None)
            else:
                sliced.append(sub_lora[:, begin:end])
        return sliced

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Delegate to the shared packed-layer routine ``_mcp_apply``."""
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Parent check, additionally gated by the decorator."""
        # keyword arguments only: the decorator reads them from kwargs
        parent_kwargs = dict(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
        return super().can_replace_layer(**parent_kwargs)
|
||||
|
||||
|
||||
class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
    """
    Variant of QKVParallelLinearWithLora in which LoRA A is sliced as well.

    Following S-LoRA, the slice is taken along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return the columns of ``lora_a`` belonging to this TP rank."""
        width = self.lora_a_stacked.shape[2]
        begin = get_tensor_model_parallel_rank() * width
        return lora_a[:, begin:begin + width]

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Base layer output plus the shrink / all-gather / expand update."""
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        out_orig_shape = output.shape
        x = x.view(-1, x.shape[-1])
        output = output.view(-1, output.shape[-1])

        shrunk = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)
        self.punica_wrapper.add_shrink(shrunk, x, self.lora_a_stacked, 1.0)
        gathered = tensor_model_parallel_all_gather(shrunk)
        self.punica_wrapper.add_expand(output,
                                       gathered,
                                       self.lora_b_stacked,
                                       add_input=True)
        # now have column partitioned output
        return output.view(*out_orig_shape)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Parent check, additionally gated by the decorator."""
        # keyword arguments only: the decorator reads them from kwargs
        parent_kwargs = dict(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
        return super().can_replace_layer(**parent_kwargs)
|
||||
|
||||
|
||||
class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Variant of MergedQKVParallelLinearWithLora in which the LoRA A
    matrices are sliced as well.

    Following S-LoRA, the slice is taken along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: List[Union[torch.Tensor, None]]
    ) -> List[Union[torch.Tensor, None]]:
        """Slice each of the three sub-LoRA A matrices; ``None`` stays."""
        # NOTE: lora_a contains 3 subloras, and each sublora could be None.
        widths = [self.lora_a_stacked[i].shape[2] for i in range(3)]
        sliced: List[Union[torch.Tensor, None]] = []
        for sub_idx, width in enumerate(widths):
            begin = self.tp_rank * width
            sub_lora = lora_a[sub_idx]
            if sub_lora is None:
                sliced.append(None)
            else:
                sliced.append(sub_lora[:, begin:begin + width])
        return sliced

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Delegate to the shared packed-layer routine ``_mcp_apply``."""
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Parent check, additionally gated by the decorator."""
        # keyword arguments only: the decorator reads them from kwargs
        parent_kwargs = dict(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
        return super().can_replace_layer(**parent_kwargs)
|
||||
|
||||
|
||||
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Variant of RowParallelLinearWithLoRA in which LoRA B is sliced as well.

    Following S-LoRA, the slice is taken along the output dim. The result
    combines the partial sum of the row parallel base layer with the
    column partitioned output of the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return the columns of ``lora_b`` belonging to this TP rank."""
        width = self.lora_b_stacked.shape[2]
        begin = self.tp_rank * width
        return lora_b[:, begin:begin + width]

    def slice_bias(self, bias: torch.Tensor) -> torch.Tensor:
        """Return this rank's span of ``bias``; ``None`` passes through."""
        if bias is not None:
            width = self.bias_stacked.shape[2]
            begin = self.tp_rank * width
            bias = bias[begin:begin + width]
        return bias

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        """Base layer output plus this rank's slice of the LoRA update.

        The returned tensor has the base output's original shape.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x)

        out_orig_shape = output.shape
        x = x.view(-1, x.shape[-1])
        output = output.view(-1, output.shape[-1])

        shrunk = torch.zeros(
            (x.shape[0], self.lora_a_stacked.shape[2]),
            dtype=torch.float32,
            device=x.device,
        )
        self.punica_wrapper.add_shrink(shrunk, x, self.lora_a_stacked, 1.0)
        reduced = tensor_model_parallel_all_reduce(shrunk)

        # following S-LoRA, allows the fusing of all_gather and all_reduce
        # by adding the column partitioned lora output to a slice of output
        # tensor, which is a partial sum due to row parallel. All that
        # remains is a standard all_reduce. User should be aware though that
        # the output is not the same as a normal row_parallel, it should be
        # reduced before being used
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size

        if self.bias_stacked is not None:
            token_indices = self.punica_wrapper.token_lora_indices
            lora_bias = self.bias_stacked.view(-1,
                                               self.bias_stacked.shape[-1])
            lora_bias = lora_bias[token_indices]
            # rows for index -1 contribute nothing
            lora_bias[token_indices == -1] = 0
            # NOTE(review): the bias is added to the whole of ``output``
            # while the LoRA update below only targets the slice starting
            # at ``start_idx`` - confirm the bias width matches output's
            # last dim.
            output += lora_bias

        self.punica_wrapper.add_expand_slice(output, reduced,
                                             self.lora_b_stacked, start_idx,
                                             shard_size)
        return output.view(*out_orig_shape)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: List,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Parent check, additionally gated by the decorator."""
        # keyword arguments only: the decorator reads them from kwargs
        parent_kwargs = dict(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
        return super().can_replace_layer(**parent_kwargs)
|
||||
1607
vllm-v0.6.2/vllm/lora/layers.py
Normal file
1607
vllm-v0.6.2/vllm/lora/layers.py
Normal file
File diff suppressed because it is too large
Load Diff
184
vllm-v0.6.2/vllm/lora/lora.py
Normal file
184
vllm-v0.6.2/vllm/lora/lora.py
Normal file
@@ -0,0 +1,184 @@
|
||||
from typing import List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
|
||||
import torch
|
||||
import torch.types
|
||||
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
class LoRALayerWeights:
    """Holds the two low rank matrices (plus extras) of one LoRA layer."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alpha: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        embeddings_tensor: Optional[torch.Tensor] = None,
        scaling: Optional[float] = None,
    ) -> None:
        """Store the weights; ``scaling`` defaults to ``lora_alpha / rank``."""
        self.module_name = module_name
        self.rank = rank
        self.lora_alpha = lora_alpha
        self.lora_a = lora_a
        self.lora_b = lora_b
        self.bias = bias
        self.embeddings_tensor = embeddings_tensor

        self.scaling = (self.lora_alpha /
                        self.rank if scaling is None else scaling)

    def optimize(self) -> "LoRALayerWeights":
        """Fold the scaling into lora_b (in place) and reset it to 1."""
        if self.scaling != 1:
            self.lora_b *= self.scaling
            self.scaling = 1
        return self

    @property
    def input_dim(self) -> int:
        """Size of dim 0 of ``lora_a``."""
        return self.lora_a.shape[0]

    @property
    def output_dim(self) -> int:
        """Size of dim 1 of ``lora_b``."""
        return self.lora_b.shape[1]

    @property
    def is_packed(self) -> bool:
        """Always ``False`` for this class."""
        return False

    @property
    def extra_vocab_size(self) -> int:
        """Rows of ``embeddings_tensor``, or 0 when there is none."""
        if self.embeddings_tensor is None:
            return 0
        return self.embeddings_tensor.shape[0]

    @classmethod
    def create_dummy_lora_weights(
            cls,
            module_name: str,
            input_dim: int,
            output_dim: int,
            rank: int,
            dtype: torch.dtype,
            device: torch.types.Device,
            embeddings_tensor_dim: Optional[int] = None,
            bias_enabled: Optional[bool] = False) -> "LoRALayerWeights":
        """Build an instance with zero-filled lora_a / lora_b.

        A zero bias is added when ``bias_enabled`` is truthy, and a random
        ``(10, embeddings_tensor_dim)`` embeddings tensor when
        ``embeddings_tensor_dim`` is truthy. ``lora_alpha`` is set to 1.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        tensor_kwargs = dict(dtype=dtype, device=device, pin_memory=pin_memory)

        lora_a = torch.zeros([input_dim, rank], **tensor_kwargs)
        lora_b = torch.zeros([rank, output_dim], **tensor_kwargs)

        bias = None
        if bias_enabled:
            bias = torch.zeros([output_dim], **tensor_kwargs)

        embeddings_tensor = None
        if embeddings_tensor_dim:
            embeddings_tensor = torch.rand(10, embeddings_tensor_dim,
                                           **tensor_kwargs)

        return cls(
            module_name,
            rank=rank,
            lora_alpha=1,
            lora_a=lora_a,
            lora_b=lora_b,
            bias=bias,
            embeddings_tensor=embeddings_tensor,
        )
|
||||
|
||||
|
||||
class PackedLoRALayerWeights(LoRALayerWeights):
    """LoRA weights for packed layers (eg. qkv_proj)."""

    def __init__(
        self,
        module_name: str,
        rank: int,
        lora_alphas: List[Optional[int]],
        lora_a: List[Optional[torch.Tensor]],
        lora_b: List[Optional[torch.Tensor]],
        bias: Optional[List[Optional[torch.Tensor]]] = None,
        scaling: Optional[List[float]] = None,
    ) -> None:
        """Store per-slice weights; scaling defaults to ``alpha / rank``."""
        super().__init__(
            module_name=module_name,
            rank=rank,
            lora_alpha=0,
            lora_a=lora_a,
            lora_b=lora_b,
            bias=bias,
            scaling=scaling,  # type: ignore
            embeddings_tensor=None,
        )
        self.lora_alphas = lora_alphas
        if scaling is None:
            per_slice_scaling = []
            for alpha in self.lora_alphas:
                per_slice_scaling.append(alpha / self.rank)  # type: ignore
            self.scaling = per_slice_scaling  # type: ignore

    @classmethod
    def pack(
        cls, loras: GenericSequence[Optional["LoRALayerWeights"]]
    ) -> "PackedLoRALayerWeights":
        """Combine several LoRAs into one packed LoRA.

        A ``None`` entry means the corresponding submodule has no LoRA.
        Every present LoRA is optimized first, so the packed scaling is 1
        for present entries and ``None`` for missing ones.
        """
        first_lora = next(entry for entry in loras if entry is not None)
        for entry in loras:
            if entry is not None:
                entry.optimize()

        def _collect(attr):
            # one value per sub-LoRA, None where the sub-LoRA is missing
            return [
                getattr(entry, attr) if entry is not None else None
                for entry in loras
            ]

        packed_scaling = [
            1 if entry is not None else None  # type: ignore
            for entry in loras
        ]
        return cls(first_lora.module_name,
                   first_lora.rank,
                   _collect("lora_alpha"),
                   _collect("lora_a"),
                   _collect("lora_b"),
                   _collect("bias"),
                   scaling=packed_scaling)

    def optimize(self) -> "PackedLoRALayerWeights":
        """Fold each slice's scaling into its lora_b and reset it to 1."""
        for idx in range(len(self.lora_b)):
            already_folded = self.scaling[idx] == 1  # type: ignore
            if already_folded or self.lora_b[idx] is None:
                continue
            self.lora_b[idx] *= self.scaling[idx]  # type: ignore
            self.scaling[idx] = 1  # type: ignore
        return self

    @property
    def input_dim(self) -> int:
        """Not defined for packed weights."""
        raise NotImplementedError()

    @property
    def output_dim(self) -> int:
        """Not defined for packed weights."""
        raise NotImplementedError()

    @property
    def is_packed(self) -> bool:
        """Always ``True`` for this class."""
        return True
|
||||
770
vllm-v0.6.2/vllm/lora/models.py
Normal file
770
vllm-v0.6.2/vllm/lora/models.py
Normal file
@@ -0,0 +1,770 @@
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Type
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel,
|
||||
AdapterModelManager)
|
||||
from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter,
|
||||
get_adapter, list_adapters,
|
||||
remove_adapter, set_adapter_mapping)
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import (BaseLayerWithLoRA,
|
||||
LinearScalingRotaryEmbeddingWithLora,
|
||||
LoRAMapping)
|
||||
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||
from vllm.lora.punica import PunicaWrapper
|
||||
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
||||
is_regex_target_modules,
|
||||
parse_fine_tuned_lora_name, replace_submodule)
|
||||
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.utils import PPMissingLayer
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_GLOBAL_LORA_ID = 0
|
||||
|
||||
|
||||
@dataclass
class LongContextLoRAContext:
    """State shared by LoRA adapters that support long context."""
    # Scaling factors used by long context LoRA fine-tuned models.
    scaling_factors: List[float]
    # Dimension the rotary embedding is applied to.
    rot_dim: int
    # Offset into the sin_cos_cache for every loaded lora_id.
    # Mutated at runtime.
    offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
def get_lora_id():
    """Advance the module-level LoRA id counter and return its new value."""
    global _GLOBAL_LORA_ID
    next_id = _GLOBAL_LORA_ID + 1
    _GLOBAL_LORA_ID = next_id
    return next_id
|
||||
|
||||
|
||||
class LoRAModel(AdapterModel):
    """A LoRA fine-tuned model."""

    def __init__(
        self,
        lora_model_id: int,
        rank: int,
        loras: Dict[str, LoRALayerWeights],
        scaling_factor: Optional[float] = None,
    ) -> None:
        """
        Args:
            lora_model_id: The integer id for the lora model.
            rank: lora rank.
            loras: module name -> weights for lora-replaced layers.
            scaling_factor: Scaling factor to support long context lora model.
                None if the lora is not tuned for long context support.
        """
        self.id = lora_model_id
        # Scaling factor for long context lora model. None if it is not
        # fine tuned for the long context.
        self.scaling_factor = scaling_factor
        assert (lora_model_id >
                0), f"a valid lora id should be greater than 0, got {self.id}"
        self.rank = rank
        self.loras: Dict[str, LoRALayerWeights] = loras

    def clone(self, lora_model_id: int) -> "LoRAModel":
        """Return a copy of the object with different ids.

        Will share the underlying tensors."""
        # NOTE(review): ``scaling_factor`` is not forwarded, so the clone
        # always gets ``None`` - confirm this is intended for long context
        # adapters.
        return self.__class__(
            lora_model_id,
            rank=self.rank,
            loras=self.loras.copy(),
        )

    @property
    def extra_vocab_size(self) -> int:
        """Largest ``extra_vocab_size`` among the layers, 0 when empty."""
        return max(lora.extra_vocab_size
                   for lora in self.loras.values()) if self.loras else 0

    def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
        """Get LoRA for a given module by name"""
        return self.loras.get(module_name, None)

    # (yard1): TODO see if we can derive target_embedding_padding automatically
    @classmethod
    def from_lora_tensors(
        cls,
        lora_model_id: int,
        rank: int,
        lora_alpha: int,
        tensors: Dict[str, torch.Tensor],
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        embeddings: Optional[Dict[str, torch.Tensor]] = None,
        target_embedding_padding: Optional[int] = None,
        scaling_factor: Optional[float] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a dictionary of tensors.

        Each tensor name is parsed with ``parse_fine_tuned_lora_name`` into
        a module name plus flags saying whether it is a bias, a LoRA A or
        (otherwise) a LoRA B tensor. Every tensor is moved to
        ``device``/``dtype`` and transposed before being stored, and pinned
        when the target is the CPU and pinning is available. All layers are
        optimized (scaling folded into lora_b) before the model is built.
        """
        pin_memory = str(device) == "cpu" and is_pin_memory_available()
        loras: Dict[str, LoRALayerWeights] = {}
        for tensor_name, tensor in tensors.items():
            module_name, is_lora_a, is_bias = parse_fine_tuned_lora_name(
                tensor_name)
            if module_name not in loras:
                # First tensor seen for this module: create an empty entry,
                # attaching the extra embeddings when one is configured.
                lora_embeddings_tensor = None
                if embeddings:
                    assert embedding_modules is not None
                    embeddings_module = next(
                        (k for k in embedding_modules if k in module_name),
                        None)
                    if embeddings_module:
                        lora_embeddings_tensor = embeddings[
                            embedding_modules[embeddings_module]].to(
                                device=device, dtype=dtype)
                        if pin_memory:
                            lora_embeddings_tensor = (
                                lora_embeddings_tensor.pin_memory())
                loras[module_name] = LoRALayerWeights(module_name, rank,
                                                      lora_alpha, None, None,
                                                      None,
                                                      lora_embeddings_tensor)
            if is_bias:
                # Convert once; the earlier code converted the tensor twice
                # and discarded the first result.
                bias = tensor.to(device=device, dtype=dtype).t()
                if pin_memory:
                    bias = bias.pin_memory()
                loras[module_name].bias = bias
            elif is_lora_a:
                loras[module_name].lora_a = tensor.to(device=device,
                                                      dtype=dtype).t()
                if pin_memory:
                    loras[module_name].lora_a = loras[
                        module_name].lora_a.pin_memory()
            else:
                loras[module_name].lora_b = tensor.to(device=device,
                                                      dtype=dtype).t()
                assert embedding_padding_modules is not None
                if any(name in module_name
                       for name in embedding_padding_modules
                       ) and target_embedding_padding is not None:
                    # Right-pad dim 1 of lora_b up to the target width.
                    lora_b = loras[module_name].lora_b
                    assert target_embedding_padding >= lora_b.shape[1]
                    addition = target_embedding_padding - lora_b.shape[1]
                    loras[module_name].lora_b = torch.nn.functional.pad(
                        lora_b, (0, addition))
                if pin_memory:
                    loras[module_name].lora_b = loras[
                        module_name].lora_b.pin_memory()

        for lora in loras.values():
            lora.optimize()
        return cls(lora_model_id, rank, loras, scaling_factor=scaling_factor)

    @classmethod
    def from_local_checkpoint(
        cls,
        lora_dir: str,
        expected_lora_modules: List[str],
        *,
        max_position_embeddings: Optional[int] = None,
        lora_model_id: Optional[int] = None,
        device: str = "cuda",
        dtype: Optional[torch.dtype] = None,
        target_embedding_padding: Optional[int] = None,
        embedding_modules: Optional[Dict[str, str]] = None,
        embedding_padding_modules: Optional[List[str]] = None,
    ) -> "LoRAModel":
        """Create a LoRAModel from a local checkpoint.

        Args:
            lora_dir: The local path that has lora data.
            expected_lora_modules: Name of modules that are expected to be
                replaced by lora.
            max_position_embeddings: Max position embedding length. Used to
                scaling the largest context length. If None, the lora model's
                context length is not scaled.
            lora_model_id: Lora model id. If not given, automatically set by
                a global counter.
            device: Device where the lora model is loaded.
            dtype: dtype of the lora model weights.
            target_embedding_padding: Forwarded to ``from_lora_tensors``.
            embedding_modules: Forwarded to ``from_lora_tensors``.
            embedding_padding_modules: Forwarded to ``from_lora_tensors``.

        Returns:
            Loaded LoRA Model.

        Raises:
            ValueError: If the checkpoint targets modules outside
                ``expected_lora_modules`` or ``lora_dir`` holds no tensors.
        """
        lora_config_path = os.path.join(lora_dir, "adapter_config.json")
        lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
        lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
        new_embeddings_tensor_path = os.path.join(
            lora_dir, "new_embeddings.safetensors")
        new_embeddings_bin_file_path = os.path.join(lora_dir,
                                                    "new_embeddings.bin")
        with open(lora_config_path) as f:
            config = json.load(f)
        if os.path.isfile(lora_tensor_path):
            tensors: Dict[str, torch.Tensor] = {}
            # Find unexpected modules.
            # Use safetensor key as a source of truth to find expected modules.
            # in peft if you have target_modules A, B, C and C does not exist
            # in the model it won't error and model will be trained with A, B
            # loraified. C won't exist in the safetensor but it will exist in
            # the target_modules of the adapter_config.json.
            unexpected_modules = []
            with safetensors.safe_open(lora_tensor_path,
                                       framework="pt") as f:  # type: ignore
                for lora_module in f.keys():  # noqa
                    module_name, _, _ = parse_fine_tuned_lora_name(lora_module)
                    part_name = module_name.split(".")[-1]
                    if part_name not in expected_lora_modules:
                        unexpected_modules.append(module_name)
                if unexpected_modules:
                    raise ValueError(
                        f"While loading {lora_dir}, expected"
                        f" target modules in {expected_lora_modules}"
                        f" but received {unexpected_modules}."
                        f" Please verify that the loaded LoRA module is correct"
                    )
                # Load tensors if there are only expected modules.
                for module in f.keys():  # noqa
                    tensors[module] = f.get_tensor(module)
        elif os.path.isfile(lora_bin_file_path):
            # When a bin file is provided, we rely on config to find unexpected
            # modules.
            unexpected_modules = []
            target_modules = config["target_modules"]
            if not isinstance(target_modules, list):
                target_modules = [target_modules]
            for module in target_modules:
                # Compatible with more modules,
                # such as:layers.11.self_attn.k_proj
                part_name = module.split(".")[-1]
                if part_name not in expected_lora_modules:
                    unexpected_modules.append(module)
            # loaded lora's target modules must be a subset of
            # expected_lora_modules. It is not reliable. See
            # https://github.com/vllm-project/vllm/pull/5909. But there's no
            # other better mechanism.
            if unexpected_modules and not is_regex_target_modules(
                    config["target_modules"], expected_lora_modules):
                raise ValueError(
                    f"While loading {lora_dir}, expected"
                    f" target modules in {expected_lora_modules}"
                    f" but received {unexpected_modules}."
                    f" Please verify that the loaded LoRA module is correct")
            # SECURITY NOTE(review): torch.load unpickles the file, which
            # can execute arbitrary code for an untrusted adapter. Consider
            # passing weights_only=True once confirmed compatible with the
            # torch version in use.
            tensors = torch.load(lora_bin_file_path, map_location=device)
        else:
            raise ValueError(f"{lora_dir} doesn't contain tensors")

        embeddings = None
        if os.path.isfile(new_embeddings_tensor_path):
            embeddings = safetensors.torch.load_file(
                new_embeddings_tensor_path)
        elif os.path.isfile(new_embeddings_bin_file_path):
            # SECURITY NOTE(review): same unpickling concern as above.
            embeddings = torch.load(new_embeddings_bin_file_path,
                                    map_location=device)

        rank = config["r"]
        lora_alpha = config["lora_alpha"]
        context_length = config.get("context_length", None)
        scaling_factor = None
        if context_length:
            if max_position_embeddings is None:
                max_position_embeddings = context_length
            # Round the ratio up to the next whole factor.
            scaling_factor = float(
                math.ceil(context_length / max_position_embeddings))

        return cls.from_lora_tensors(
            lora_model_id=get_lora_id()
            if lora_model_id is None else lora_model_id,
            rank=rank,
            lora_alpha=lora_alpha,
            tensors=tensors,
            device=device,
            dtype=dtype,
            embeddings=embeddings,
            target_embedding_padding=target_embedding_padding,
            scaling_factor=scaling_factor,
            embedding_modules=embedding_modules,
            embedding_padding_modules=embedding_padding_modules,
        )
|
||||
|
||||
|
||||
class LoRAModelManager(AdapterModelManager):
    """A manager that manages multiple LoRA-fine-tuned models."""

    def __init__(
        self,
        model: SupportsLoRA,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
    ):
        """Create a LoRAModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
            max_num_seqs: the maximum number of sequences model can run in a
                single batch.
            max_num_batched_tokens: the maximum number of tokens model can run
                in a single batch.
            vocab_size: the vocab size of the model.
            lora_config: the LoRA configuration.
            device: the device handed to the shared PunicaWrapper.
        """
        self.lora_config = lora_config
        self.device = device
        self.max_num_seqs = max_num_seqs
        assert self.capacity >= self.lora_slots
        # Stored value is rounded up to the next multiple of 8; the
        # PunicaWrapper below receives the unrounded argument.
        self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
        # One entry per slot; None marks a free slot.
        self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
        self.vocab_size = vocab_size
        self.long_lora_context: Optional[LongContextLoRAContext] = None
        self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
                                            max_batches=self.max_num_seqs,
                                            device=self.device)
        # Scaling factor -> offset to the sin_cos_cache to it.
        # Used for long context lora.
        self.scaling_factor_to_offset: Dict[float, int] = {}
        super().__init__(model)
        if hasattr(self.model, "supported_lora_modules"):
            self.supported_lora_modules = copy.deepcopy(
                self.model.supported_lora_modules)
            if lora_config.long_lora_scaling_factors:
                # We need to replace rotary emb layer to do batch computation
                # for long lora.
                self.supported_lora_modules.append("rotary_emb")
            self.packed_modules_mapping = copy.deepcopy(
                self.model.packed_modules_mapping)
        # Used to indicate whether the model is a multimodal model
        self.supports_mm: bool = (
            supports_multimodal(self.model)
            # In case the model only supports LoRA for
            # text modules (e.g. ChatGLM)
            and hasattr(self.model, "get_mm_mapping"))
        self.packed_modules: Dict[str, List[str]] = {}
        self.modules: Dict[str, BaseLayerWithLoRA] = {}
        self._last_mapping: Optional[LoRAMapping] = None
        self._create_lora_modules()
        self.model.lora_manager = self
        self.adapter_type = 'LoRa'

    @property
    def capacity(self) -> int:
        """Value of ``lora_config.max_cpu_loras``."""
        return self.lora_config.max_cpu_loras

    @property
    def lora_slots(self) -> int:
        """Value of ``lora_config.max_loras``."""
        return self.lora_config.max_loras

    @property
    def adapter_slots(self) -> int:
        """Alias of ``lora_slots``."""
        return self.lora_slots

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Move LoRA into a GPU buffer to be used in the forward pass.

        Returns:
            False when ``lora_id`` is already active, True once the adapter
            has been written into a free slot.

        Raises:
            ValueError: if no slot is free, or if the adapter carries a bias
                while ``lora_config.bias_enabled`` is off.
        """
        if lora_id in self._active_adapters:
            return False
        index = next((slot for slot, slot_id in enumerate(
            self.lora_index_to_id) if slot_id is None), None)
        if index is None:
            raise ValueError("No free lora slots")
        # Resolve the adapter before recording it as active, so a failed
        # lookup does not leave the id flagged as active.
        lora_model = self._registered_adapters[lora_id]
        self._active_adapters[lora_id] = None
        logger.debug("Activating LoRA. int id: %d, slot index: %d",
                     lora_model.id, index)
        self.lora_index_to_id[index] = lora_model.id
        for module_name, module in self.modules.items():
            module_lora = lora_model.get_lora(module_name)
            if not module_lora:
                # The adapter has no weights for this layer: clear the slot.
                module.reset_lora(index)
                continue
            module_lora.optimize()
            # Bias is not explicitly enabled with the flag enable_lora_bias.
            bias = module_lora.bias
            has_bias = torch.is_tensor(bias) or (
                isinstance(bias, Sequence)
                and any(b is not None for b in bias))
            if has_bias and not self.lora_config.bias_enabled:
                module_lora.bias = None
                raise ValueError(
                    f"Adapter bias cannot be used for {module_name}"
                    " without --enable-lora-bias.")
            module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
                            module_lora.embeddings_tensor,
                            module_lora.bias)
        return True

    def _deactivate_adapter(self, lora_id: int):
        """Free the slot held by ``lora_id``; do nothing if it holds none."""
        try:
            index = self.lora_index_to_id.index(lora_id)
        except ValueError:
            # The id does not occupy a slot; nothing to release.
            return
        self.lora_index_to_id[index] = None

    def _set_long_lora_context(self, lora: LoRAModel):
        """Record the long-context offset for ``lora`` when applicable.

        Raises:
            ValueError: if the adapter's scaling factor has no entry in
                ``scaling_factor_to_offset``.
        """
        if self.long_lora_context is None:
            return

        if lora.scaling_factor is None:
            return

        if (lora.scaling_factor not in self.scaling_factor_to_offset):
            raise ValueError(f"Long LoRA scaling factor {lora.scaling_factor}"
                             " has not been initialized.")

        offsets = self.scaling_factor_to_offset.get(lora.scaling_factor)
        # A falsy offset (e.g. 0) is deliberately not recorded here.
        if offsets:
            self.long_lora_context.offsets_by_lora_id[lora.id] = offsets

    def _add_adapter(self, lora: LoRAModel):
        """Merge packed sub-LoRAs, register ``lora`` and set its context."""
        self._create_merged_loras_inplace(lora)
        self._registered_adapters[lora.id] = lora
        self._set_long_lora_context(lora)

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        raise NotImplementedError(
            "Pinning is not supported in LoRAModelManager. "
            "Use LRUCacheLoRAModelManager for pinning")  # type: ignore

    def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Push ``mapping`` and the current slot table to the punica wrapper."""
        # update lora states
        self.punica_wrapper.update_metadata(
            mapping,
            self.lora_index_to_id,
            self.lora_slots + 1,
            self.vocab_size,
            self.lora_config.lora_extra_vocab_size,
            self.long_lora_context,
        )

    def remove_all_adapters(self):
        """Remove all LoRAModels from the manager."""
        self._registered_adapters.clear()
        self.lora_index_to_id = [None] * self.lora_slots
        self._active_adapters.clear()

    def _create_lora_modules(self):
        """Swap every supported submodule of the model for its LoRA wrapper."""
        for module_name, module in self.model.named_modules(
                remove_duplicate=False):
            if isinstance(module, PPMissingLayer):
                continue
            if not self._match_target_modules(module_name):
                continue
            # A temporary approach for multimodal models to support LoRA
            # TODO: Remove this restriction
            if self._filter_unsupported_mm_module(module_name):
                logger.warning(
                    "Regarding multimodal models, vLLM currently only supports "
                    "adding LoRA to language model, %s will be ignored.",
                    module_name,
                )
                continue
            parts = module_name.split(".")[-1]
            packed_moduled_lst = self.packed_modules_mapping.get(parts, [])
            new_module = replace_submodule(
                self.model, module_name,
                from_layer(module, self.lora_slots, self.lora_config,
                           packed_moduled_lst, self.model.config))

            # LinearScalingRotaryEmbeddingWithLora is used to handle
            # long context lora. Register relevant metadata.
            if isinstance(new_module, LinearScalingRotaryEmbeddingWithLora):
                self.long_lora_context = LongContextLoRAContext(
                    new_module.scaling_factors, new_module.rotary_dim)
                self.scaling_factor_to_offset = \
                    new_module.scaling_factor_to_offset
            # (yard1): TODO make this more robust
            if "lm_head" in module_name:
                logits_processor_module = self.model.get_submodule(
                    "logits_processor")
                new_module = replace_submodule(
                    self.model, "logits_processor",
                    from_layer_logits_processor(logits_processor_module,
                                                module, self.lora_slots,
                                                self.lora_config,
                                                self.model.config))

            # In some models, especially multimodal ones, layers with the same
            # name may have different types, such as nn.Linear and
            # ReplicatedLinear. The nn.Linear layers cannot be replaced with
            # LoRA layers, leading to assertion error. The following check
            # aims to prevent this error
            if self.supports_mm and not isinstance(new_module,
                                                   BaseLayerWithLoRA):
                continue
            self.register_module(module_name, new_module)
            self._register_packed_modules(module_name)
            # All lora layers share the same punica_wrapper based on reference.
            new_module.set_mapping(self.punica_wrapper)

    def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
        """Track ``module`` under ``module_name``."""
        assert isinstance(module, BaseLayerWithLoRA)
        self.modules[module_name] = module

    def create_dummy_lora(
            self,
            lora_id: int,
            rank: int,
            scaling_factor: Optional[float],
            embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
        """Create zero-initialized LoRAModel for warmup."""
        model = LoRAModel(lora_id, rank, {}, scaling_factor)
        # Read once: the flag is the same for every module.
        bias_enabled = self.lora_config.bias_enabled
        for module_name, module in self.model.named_modules():
            if (not self._match_target_modules(module_name)
                    or not isinstance(module, BaseLayerWithLoRA)
                    or isinstance(module, LinearScalingRotaryEmbeddingWithLora)
                    or self._filter_unsupported_mm_module(module_name)):
                continue
            parts = module_name.split(".")
            if module_name not in self.packed_modules:
                assert embedding_modules is not None
                if parts[-1] in embedding_modules:
                    base_layer = module.base_layer
                    input_dim = (base_layer.org_vocab_size +
                                 self.lora_config.lora_extra_vocab_size if
                                 hasattr(base_layer, "org_vocab_size")
                                 else base_layer.weight.shape[1])
                    output_dim = base_layer.embedding_dim if hasattr(
                        base_layer,
                        "embedding_dim") else base_layer.weight.shape[0]
                    embeddings_tensor_dim = (base_layer.embedding_dim if
                                             hasattr(base_layer,
                                                     "embedding_dim") else
                                             base_layer.weight.shape[1])
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        input_dim,
                        output_dim,
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                        embeddings_tensor_dim=embeddings_tensor_dim,
                        bias_enabled=bias_enabled)
                else:
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name,
                        module.lora_a_stacked.shape[-1],
                        module.lora_b_stacked.shape[-2],
                        rank,
                        module.lora_a_stacked.dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                lora.optimize()
            else:
                # Packed module: build one dummy per sub-module, then pack.
                replacements = self.packed_modules_mapping[parts[-1]]
                subloras: List[Optional[LoRALayerWeights]] = []
                for i, r in enumerate(replacements):
                    lora = LoRALayerWeights.create_dummy_lora_weights(
                        module_name + "." + r,
                        module.lora_a_stacked[i].shape[-1],
                        module.lora_b_stacked[i].shape[-2],
                        rank,
                        module.lora_a_stacked[i].dtype,
                        "cpu",
                        bias_enabled=bias_enabled,
                    )
                    lora.optimize()
                    subloras.append(lora)
                lora = PackedLoRALayerWeights.pack(subloras)
            model.loras[module_name] = lora
        return model

    def _match_target_modules(self, module_name: str):
        """Tell whether ``module_name`` is, or ends in, a supported module."""
        return any(
            re.match(
                r".*\.{target_module}$".format(target_module=target_module),
                module_name) or target_module == module_name
            for target_module in self.supported_lora_modules)

    def _filter_unsupported_mm_module(self, module_name: str) -> bool:
        """
        Regarding multimodal models, vLLM currently only supports adding LoRA to
        language model. LoRA for other modules, such as the vision tower, will
        be filtered out.
        """
        if self.supports_mm:
            module_mapping: MultiModelKeys = self.model.get_mm_mapping()
            prefix_lst = module_mapping.connector + module_mapping.tower_model
            return any(
                module_name.startswith(prefix) for prefix in prefix_lst)
        return False

    def _register_packed_modules(self, module_full_name: str) -> None:
        """Record the sub-module names that a packed module stands for."""
        parts = module_full_name.split(".")
        module_name = parts[-1]
        replacements = self.packed_modules_mapping.get(module_name, [])
        # When replacements is less than or equal to 1, it indicates that this
        # module is not a packed module.
        if len(replacements) <= 1:
            return
        prefix = ".".join(parts[:-1])
        self.packed_modules[module_full_name] = [
            prefix + "." + r if prefix else r for r in replacements
        ]

    def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
        """Pack per-sub-module LoRAs of ``lora_model`` into packed entries."""
        for module_name, new_module_names in self.packed_modules.items():
            replacement_loras: List[Optional[LoRALayerWeights]] = [
                lora_model.get_lora(r) for r in new_module_names
            ]
            # Skip packed modules for which the adapter has no weights at all.
            if not any(replacement_loras):
                continue
            # Missing (falsy) entries are passed to pack() as None.
            replacement_loras = [
                lora if lora else None for lora in replacement_loras
            ]
            lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                replacement_loras)

    def deactivate_adapter(self, adapter_id: int) -> bool:
        """Delegate to the module-level ``deactivate_adapter`` helper."""
        return deactivate_adapter(adapter_id, self._active_adapters,
                                  self._deactivate_adapter)

    def add_adapter(self, adapter: LoRAModel) -> bool:
        """Delegate to the module-level ``add_adapter`` helper."""
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", adapter.id, adapter.id,
            adapter.scaling_factor)
        return add_adapter(adapter, self._registered_adapters, self.capacity,
                           self._add_adapter)

    def set_adapter_mapping(self, mapping: LoRAMapping) -> None:
        """Apply ``mapping`` through the helper and remember its result."""
        self._last_mapping = set_adapter_mapping(mapping, self._last_mapping,
                                                 self._set_adapter_mapping)

    def remove_adapter(self, adapter_id: int) -> bool:
        """Delegate to the module-level ``remove_adapter`` helper."""
        return remove_adapter(adapter_id, self._registered_adapters,
                              self.deactivate_adapter)

    def list_adapters(self) -> Dict[int, Any]:
        """Delegate to the module-level ``list_adapters`` helper."""
        return list_adapters(self._registered_adapters)

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        """Delegate to the module-level ``get_adapter`` helper."""
        return get_adapter(adapter_id, self._registered_adapters)
|
||||
|
||||
|
||||
class LoRALRUCache(AdapterLRUCache[LoRAModel]):
    """AdapterLRUCache specialised to hold LoRAModel entries."""

    def __init__(self, capacity: int,
                 deactivate_lora_fn: Callable[[int], bool]):
        # Both arguments are forwarded unchanged to the base cache.
        super().__init__(capacity, deactivate_lora_fn)
|
||||
|
||||
|
||||
class LRUCacheLoRAModelManager(LoRAModelManager):
    """A model manager that manages multiple LoRAs with LRU cache."""

    def __init__(self, model: nn.Module, max_num_seqs: int,
                 max_num_batched_tokens: int, vocab_size: int,
                 lora_config: LoRAConfig, device: torch.device):
        super().__init__(model, max_num_seqs, max_num_batched_tokens,
                         vocab_size, lora_config, device)
        # Replace the base containers with LRU caches: registered adapters
        # are bounded by `capacity`, active adapters by `lora_slots`.
        self._registered_adapters: LoRALRUCache = LoRALRUCache(
            self.capacity, self.deactivate_adapter)
        self._active_adapters: LoRALRUCache = LoRALRUCache(
            self.lora_slots, self._deactivate_adapter)

    def list_adapters(self) -> Dict[int, LoRAModel]:
        """List all registered LoRAModels."""
        return dict(self._registered_adapters.cache)

    def add_adapter(self, lora: LoRAModel) -> bool:
        """Add a LoRAModel to the manager."""
        logger.debug(
            "Adding lora. Model id: %d, "
            "int id: %d, "
            "scaling factor: %s", lora.id, lora.id, lora.scaling_factor)
        if lora.id in self._registered_adapters:
            # We always touch to update the LRU cache order
            self._registered_adapters.touch(lora.id)
            return False
        self._add_adapter(lora)
        return True

    def activate_adapter(
        self,
        lora_id: int,
    ) -> bool:
        """Activate ``lora_id``, evicting the oldest active one if full."""
        if lora_id not in self._active_adapters:
            if len(self._active_adapters) >= self.lora_slots:
                self._active_adapters.remove_oldest()
        activated = super().activate_adapter(lora_id)
        # We always touch to update the LRU cache order
        self._active_adapters.touch(lora_id)
        return activated

    def remove_oldest_adapter(self) -> bool:
        """Drop the oldest registered adapter; False if none is registered."""
        if len(self._registered_adapters) <= 0:
            return False
        self._registered_adapters.remove_oldest()
        return True

    def pin_adapter(self, lora_id: int) -> bool:
        """Pin a LoRAModel in the manager cache."""
        self._pin_lora_in_cpu_cache(lora_id)
        self._pin_lora_in_gpu_cache(lora_id)
        return True

    def _pin_lora_in_cpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the registered-adapter cache."""
        try:
            self._registered_adapters.pin(lora_id)
        except ValueError as err:
            raise ValueError("Pinning failed. "
                             f"LoRA {lora_id} is not registered.") from err

    def _pin_lora_in_gpu_cache(self, lora_id: int):
        """Pin ``lora_id`` in the active-adapter cache, activating it first."""
        already_active = lora_id in self._active_adapters
        if not already_active:
            # move lora to gpu if not already active
            self.activate_adapter(lora_id)

        self._active_adapters.pin(lora_id)
|
||||
|
||||
|
||||
def create_lora_manager(
        model: nn.Module,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
        **kwargs) -> LoRAModelManager:
    """Create a LoRA adapter for a given model.

    Raises:
        ValueError: if ``model`` has no ``supported_lora_modules`` attribute.
    """
    model_supports_lora = hasattr(model, "supported_lora_modules")
    if not model_supports_lora:
        raise ValueError(f"Model {type(model)} is not supported for LoRA.")
    # Extra keyword arguments are passed straight to the manager class.
    return lora_manager_cls(model=model,
                            max_num_seqs=max_num_seqs,
                            max_num_batched_tokens=max_num_batched_tokens,
                            vocab_size=vocab_size,
                            lora_config=lora_config,
                            device=device,
                            **kwargs)
|
||||
0
vllm-v0.6.2/vllm/lora/ops/__init__.py
Normal file
0
vllm-v0.6.2/vllm/lora/ops/__init__.py
Normal file
BIN
vllm-v0.6.2/vllm/lora/ops/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm-v0.6.2/vllm/lora/ops/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm-v0.6.2/vllm/lora/ops/__pycache__/utils.cpython-310.pyc
Normal file
BIN
vllm-v0.6.2/vllm/lora/ops/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
168
vllm-v0.6.2/vllm/lora/ops/bgmv_expand.py
Normal file
168
vllm-v0.6.2/vllm/lora/ops/bgmv_expand.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .utils import get_lora_op_configs
|
||||
|
||||
|
||||
@triton.jit
def _bgmv_expand_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    lora_indices,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_N: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):
    """
    GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's
    performance.

    Program axis 0 picks the N-split and program axis 1 picks the batch row.
    Rows whose LoRA index is -1 return without writing anything.
    """
    split_id = tl.program_id(axis=0)
    row = tl.program_id(axis=1)
    adapter = tl.load(lora_indices + row)
    if adapter == -1:
        return
    k_offsets = tl.arange(0, BLOCK_K)
    n_offsets = tl.arange(0, BLOCK_N)
    x_ptrs = input_ptr + row * xm_stride + k_offsets * xk_stride
    if EVEN_K:
        # K fills the whole BLOCK_K tile, so no mask is needed.
        x_vec = tl.load(x_ptrs)  # [BLOCK_K]
    else:
        x_vec = tl.load(x_ptrs, mask=k_offsets < K, other=0)  # [BLOCK_K]
    # N must be divisible by SPLIT_N
    n_per_split = tl.cdiv(N, SPLIT_N)
    if CAST_TYPE:
        x_vec = x_vec.to(lora_ptr.dtype.element_ty)
    # sliding to next row-block
    w_base = (lora_ptr + l0_stride * adapter +
              split_id * n_per_split * lora_k_stride)
    out_base = out_ptr + row * cm_stride + split_id * n_per_split
    for n_start in range(0, n_per_split, BLOCK_N):
        n_idx = n_start + n_offsets
        n_idx_hint = tl.max_contiguous(n_idx, BLOCK_N)
        w_mask = (n_idx[:, None] < n_per_split) & (k_offsets[None, :] < K)
        out_mask = n_idx < n_per_split
        w_tile = tl.load(
            w_base + n_idx_hint[:, None] * lora_k_stride +
            k_offsets[None, :] * lora_n_stride,
            mask=w_mask,
            other=0.0,
        )  # [BLOCK_N,BLOCK_K]
        out_ptrs = out_base + n_idx * cn_stride
        if ADD_INPUTS:
            # Accumulate on top of what the output already holds.
            previous = tl.load(out_ptrs, mask=out_mask)
            result = tl.sum(x_vec * w_tile, 1) + previous
        else:
            result = tl.sum(x_vec * w_tile, 1)

        tl.store(out_ptrs, result, mask=out_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
def _bgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    add_inputs: bool = True,
) -> None:
    """
    Args:
        inputs (torch.Tensor): input tensor
        lora_b_weights (torch.Tensor): lora'b weight
        output_tensor (torch.Tensor): output tensor
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch, An index of -1 means no lora should be
            applied.
        add_inputs (bool, optional): Defaults to True, adds the final lora
            results to the output.
    """
    half_dtypes = [torch.float16, torch.bfloat16]
    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    assert lora_b_weights.dtype in half_dtypes
    assert inputs.size(1) == lora_b_weights.size(-1)

    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    if lora_b_weights.ndim == 4:  # shape:(lora_num,1,size,rank)
        assert lora_b_weights.size(1) == 1
        lora_b_weights = lora_b_weights.squeeze(dim=1)
    else:
        assert lora_b_weights.ndim == 3  # shape:(lora_num,size,rank)
    assert lora_b_weights.is_contiguous()

    # TODO tuning this config
    N, K = lora_b_weights.shape[-2:]  # K= rank,N=hidden_size
    BLOCK_K = triton.next_power_of_2(K)
    EVEN_K = K % BLOCK_K == 0
    # float32 inputs are cast to the weight dtype inside the kernel.
    CAST_TYPE = (inputs.dtype == torch.float32
                 and lora_b_weights.dtype in half_dtypes)
    batches = lora_indices_tensor.size(0)
    config = get_lora_op_configs("expand", batches, N)

    def grid(META):
        # One program per (N-split, batch row).
        return (META["SPLIT_N"], batches)

    _bgmv_expand_kernel[grid](
        inputs,
        lora_b_weights,
        output_tensor,
        N,
        K,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        lora_b_weights.stride(0),
        lora_b_weights.stride(1),
        lora_b_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        BLOCK_K=BLOCK_K,
        EVEN_K=EVEN_K,
        ADD_INPUTS=add_inputs,
        CAST_TYPE=CAST_TYPE,
        **config,
    )
|
||||
|
||||
|
||||
# Expose the op through torch.library.custom_op; fall back to the plain
# Python function when that lookup raises AttributeError.
try:
    bgmv_expand = torch.library.custom_op(
        "lora::bgmv_expand",
        _bgmv_expand,
        mutates_args=["output_tensor"],
    )
except AttributeError:
    bgmv_expand = _bgmv_expand
|
||||
181
vllm-v0.6.2/vllm/lora/ops/bgmv_expand_slice.py
Normal file
181
vllm-v0.6.2/vllm/lora/ops/bgmv_expand_slice.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .utils import get_lora_op_configs
|
||||
|
||||
|
||||
@triton.jit
def _bgmv_expand_slice_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    lora_indices,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    slice_offset,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_N: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):
    """
    GroupGEMV, additionally, introducing SPLIT_N can improve large hidden_size's
    performance.

    Program axis 0 picks the N-split and program axis 1 picks the batch row.
    Results are written starting `slice_offset` columns into each output row.
    Rows whose LoRA index is -1 return without writing anything.
    """
    split_id = tl.program_id(axis=0)
    row = tl.program_id(axis=1)
    adapter = tl.load(lora_indices + row)
    if adapter == -1:
        return
    k_offsets = tl.arange(0, BLOCK_K)
    n_offsets = tl.arange(0, BLOCK_N)
    x_ptrs = input_ptr + row * xm_stride + k_offsets * xk_stride
    if EVEN_K:
        # K fills the whole BLOCK_K tile, so no mask is needed.
        x_vec = tl.load(x_ptrs)  # [BLOCK_K]
    else:
        x_vec = tl.load(x_ptrs, mask=k_offsets < K, other=0)  # [BLOCK_K]
    # N must be divisible by SPLIT_N
    n_per_split = tl.cdiv(N, SPLIT_N)
    if CAST_TYPE:
        x_vec = x_vec.to(lora_ptr.dtype.element_ty)
    # sliding to next row-block
    w_base = (lora_ptr + l0_stride * adapter +
              split_id * n_per_split * lora_k_stride)
    out_base = (out_ptr + row * cm_stride + split_id * n_per_split +
                slice_offset * cn_stride)

    for n_start in range(0, n_per_split, BLOCK_N):
        n_idx = n_start + n_offsets
        w_mask = (n_idx[:, None] < n_per_split) & (k_offsets[None, :] < K)
        out_mask = n_idx < n_per_split
        w_tile = tl.load(
            w_base + n_idx[:, None] * lora_k_stride +
            k_offsets[None, :] * lora_n_stride,
            mask=w_mask,
            other=0.0,
        )  # [BLOCK_N,BLOCK_K]

        out_ptrs = out_base + n_idx * cn_stride
        if ADD_INPUTS:
            # Accumulate on top of what the output already holds.
            previous = tl.load(out_ptrs, mask=out_mask)
            result = tl.sum(x_vec * w_tile, 1) + previous
        else:
            result = tl.sum(x_vec * w_tile, 1)

        tl.store(out_ptrs, result, mask=out_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
def _bgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = True,
) -> None:
    """
    Args:
        inputs (torch.Tensor): input tensor
        lora_b_weights (torch.Tensor): lora'b weight
        output_tensor (torch.Tensor): output tensor
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch, An index of -1 means no lora should be
            applied.
        slice_offset (int): output_tensor's offset
        slice_size (int): current output_tensor's size
        add_inputs (bool, optional): Defaults to True, adds the final lora
            results to the output.
    """
    half_dtypes = [torch.float16, torch.bfloat16]
    assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
    assert lora_b_weights.dtype in half_dtypes
    assert inputs.size(1) == lora_b_weights.size(-1)

    assert slice_size == lora_b_weights.size(-2)
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    if lora_b_weights.ndim == 4:  # shape:(lora_num,1,size,rank)
        assert lora_b_weights.size(1) == 1
        lora_b_weights = lora_b_weights.squeeze(dim=1)
    else:
        assert lora_b_weights.ndim == 3  # shape:(lora_num,size,rank)

    assert lora_b_weights.is_contiguous()

    # TODO tuning this config

    N, K = lora_b_weights.shape[-2:]  # K= rank,N=hidden_size
    BLOCK_K = triton.next_power_of_2(K)
    EVEN_K = K % BLOCK_K == 0
    # float32 inputs are cast to the weight dtype inside the kernel.
    CAST_TYPE = (inputs.dtype == torch.float32
                 and lora_b_weights.dtype in half_dtypes)

    batches = lora_indices_tensor.size(0)

    config = get_lora_op_configs("expand", batches, N)

    def grid(META):
        # One program per (N-split, batch row).
        return (META["SPLIT_N"], batches)

    _bgmv_expand_slice_kernel[grid](
        inputs,
        lora_b_weights,
        output_tensor,
        N,
        K,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        lora_b_weights.stride(0),
        lora_b_weights.stride(1),
        lora_b_weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        slice_offset,
        BLOCK_K=BLOCK_K,
        EVEN_K=EVEN_K,
        ADD_INPUTS=add_inputs,
        CAST_TYPE=CAST_TYPE,
        **config,
    )
|
||||
|
||||
|
||||
# Expose the op through torch.library.custom_op; fall back to the plain
# Python function when that lookup raises AttributeError.
try:
    bgmv_expand_slice = torch.library.custom_op(
        "lora::bgmv_expand_slice",
        _bgmv_expand_slice,
        mutates_args=["output_tensor"],
    )
except AttributeError:
    bgmv_expand_slice = _bgmv_expand_slice
|
||||
150
vllm-v0.6.2/vllm/lora/ops/bgmv_shrink.py
Normal file
150
vllm-v0.6.2/vllm/lora/ops/bgmv_shrink.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .utils import get_lora_op_configs
|
||||
|
||||
|
||||
@triton.jit
def _bgmv_shrink_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    lora_indices,
    scaling,
    xm_stride,
    xk_stride,
    l0_stride,
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    """
    GroupGEMV, additionally, introducing SPLIT-K can improve large hidden_size's
    performance.

    Program axis 0 picks the K-split and program axis 1 picks the batch row.
    Rows whose LoRA index is -1 return without writing anything.
    """
    split_id = tl.program_id(axis=0)
    row = tl.program_id(axis=1)
    adapter = tl.load(lora_indices + row)
    if adapter == -1:
        return

    n_offsets = tl.arange(0, BLOCK_N)
    # Each split starts at its own BLOCK_K-wide window of K.
    k_offsets = tl.arange(0, BLOCK_K) + split_id * BLOCK_K
    x_base = input_ptr + row * xm_stride
    w_base = lora_ptr + l0_stride * adapter
    acc = tl.zeros((BLOCK_N, ), dtype=tl.float32)
    for k_start in range(0, K, BLOCK_K * SPLIT_K):
        k_idx = k_start + k_offsets
        k_idx_hint = tl.max_contiguous(k_idx, BLOCK_K)
        x_vec = tl.load(
            x_base + k_idx_hint,
            mask=k_idx < K,
            other=0.0,
        )  # [BLOCK_K]
        w_mask = (n_offsets[:, None] < N) & (k_idx[None, :] < K)

        w_tile = tl.load(
            w_base + n_offsets[:, None] * lora_k_stride +
            k_idx[None, :] * lora_n_stride,
            mask=w_mask,
            other=0.0,
        )  # [BLOCK_N,BLOCK_K]

        acc += tl.sum(x_vec * w_tile, 1)
    acc *= scaling
    out_offsets = tl.arange(0, BLOCK_N)
    out_ptrs = out_ptr + row * cm_stride + out_offsets * cn_stride
    out_mask = out_offsets < N
    if SPLIT_K == 1:
        tl.store(out_ptrs, acc, mask=out_mask)
    else:
        # Several K-splits contribute to the same outputs, so add atomically.
        tl.atomic_add(out_ptrs, acc, mask=out_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
def _bgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    scaling: float = 1.0,
) -> None:
    """
    Validate the arguments and launch `_bgmv_shrink_kernel`.

    Args:
        inputs (torch.Tensor): contiguous fp16/bf16 input whose dimension 1
            equals the last dimension of `lora_a_weights`.
        lora_a_weights (torch.Tensor): contiguous LoRA A weights with the same
            dtype as `inputs`, shaped (lora_num, rank, size) or
            (lora_num, 1, rank, size).
        output_tensor (torch.Tensor): contiguous output tensor, written in
            place by the kernel.
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch. An index of -1 means no lora should
            be applied. Its length is used as the batch size of the launch.
        scaling (float): Scaling factor.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert inputs.dtype == lora_a_weights.dtype
    assert inputs.dtype in half_dtypes
    assert lora_a_weights.dtype in half_dtypes
    assert inputs.size(1) == lora_a_weights.size(-1)
    assert inputs.is_contiguous()

    weights = lora_a_weights
    if weights.ndim != 4:
        assert weights.ndim == 3  # shape:(lora_num,rank, size)
    else:  # shape:(lora_num,1,rank, size)
        assert weights.size(1) == 1
        weights = weights.squeeze(dim=1)
    assert weights.is_contiguous()
    assert output_tensor.is_contiguous()

    # TODO tuning this config
    batches = lora_indices_tensor.size(0)
    N, K = weights.shape[-2:]  # K=hidden_size,N=rank
    BLOCK_N = triton.next_power_of_2(N)
    # First try to load optimal config from the file
    config = get_lora_op_configs("bgmv_shrink", batches, K)

    def grid(META):
        # One program per (split-K slice, batch row).
        return (
            META["SPLIT_K"],
            batches,
        )

    _bgmv_shrink_kernel[grid](
        inputs,
        weights,
        output_tensor,
        N,
        K,
        lora_indices_tensor,
        scaling,
        inputs.stride(0),
        inputs.stride(1),
        weights.stride(0),
        weights.stride(1),
        weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        BLOCK_N=BLOCK_N,
        **config,
    )
|
||||
|
||||
|
||||
# Expose the op through torch.library.custom_op, declaring that it mutates
# `output_tensor`. If the attribute lookup fails (presumably a PyTorch build
# without `torch.library.custom_op`), export the plain Python function.
try:
    bgmv_shrink = torch.library.custom_op(
        "lora::bgmv_shrink",
        _bgmv_shrink,
        mutates_args=["output_tensor"],
    )
except AttributeError:
    bgmv_shrink = _bgmv_shrink
|
||||
201
vllm-v0.6.2/vllm/lora/ops/sgmv_expand.py
Normal file
201
vllm-v0.6.2/vllm/lora/ops/sgmv_expand.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def _sgmv_expand_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    b_seq_start_loc,
    seq_lens,
    lora_indices,
    xm_stride,
    xk_stride,  # 1
    l0_stride,  # hidden_size*max_rank
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):
    """
    The sgmv's expand triton kernel is based on GroupGEMM.

    Launch layout: program axis 0 enumerates the (row-tile, column-tile) pairs
    of one sequence's output, program axis 1 enumerates the sequences. Each
    program computes one [BLOCK_M, BLOCK_N] output tile for its sequence:
    the sequence's input rows are multiplied by the LoRA weight slab selected
    through `lora_indices`, accumulated in float32 over K, cast to the LoRA
    weight element type and stored (added to the existing output when
    ADD_INPUTS is set).

    Notes:
        * Programs whose row tile starts beyond the sequence length, or whose
          LoRA index is -1, exit without writing.
        * `lora_k_stride` is applied to the N (output column) offsets and
          `lora_n_stride` to the K (reduction) offsets.
        * EVEN_K skips the K-bounds masks on the loads.
        * CAST_TYPE casts the input tile to the LoRA weight element type
          before the dot product.
    """
    pid = tl.program_id(axis=0)
    cur_batch = tl.program_id(axis=1)
    # Split the flat program id into (row-tile, column-tile).
    cta_n_num = tl.cdiv(N, BLOCK_N)
    pid_m = pid // cta_n_num
    pid_n = pid % cta_n_num
    M = tl.load(seq_lens + cur_batch)
    if pid_m * BLOCK_M > M:
        return
    lora_index = tl.load(lora_indices + cur_batch)
    if lora_index == -1:
        return
    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_k = tl.arange(0, BLOCK_K)
    # Wrap row/column offsets with a modulo so the unmasked loads stay inside
    # the sequence / weight slab; surplus rows and columns are masked at the
    # final store.
    ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # NOTE: no trailing comma inside the parentheses. The pointer must be a
    # tensor expression (as in the sgmv shrink kernel), not a 1-tuple.
    a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
             offset_k[None, :] * xk_stride)
    b_ptr = (lora_ptr + l0_stride * lora_index +
             offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            tiled_a = tl.load(a_ptr)
            tiled_b = tl.load(b_ptr)
        else:
            # Mask the tail of K on the last iteration.
            tiled_a = tl.load(a_ptr,
                              mask=offset_k[None, :] < K - k * BLOCK_K,
                              other=0)
            tiled_b = tl.load(b_ptr,
                              mask=offset_k[:, None] < K - k * BLOCK_K,
                              other=0)
        if CAST_TYPE:
            tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
        accumulator += tl.dot(
            tiled_a,
            tiled_b,
        )
        a_ptr += BLOCK_K * xk_stride
        b_ptr += BLOCK_K * lora_n_stride
    tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
             offset_cn[None, :] * cn_stride)
    # Same value as the M loaded above; kept as in the original kernel.
    M = tl.load(seq_lens + cur_batch)
    # Only rows of this sequence and columns below N are written.
    c_mask = (offset_cm[:, None] <
              (cur_seq_start + M)) & (offset_cn[None, :] < N)
    if ADD_INPUTS:
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
def _sgmv_expand(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    add_inputs: bool = False,
) -> None:
    """
    Validate the arguments and launch `_sgmv_expand_kernel`.

    Args:
        inputs (torch.Tensor): contiguous fp16/bf16/fp32 input with
            `token_nums` rows; dimension 1 must equal the last dimension of
            `lora_b_weights`.
        lora_b_weights (torch.Tensor): contiguous fp16/bf16 LoRA B weights,
            shaped (lora_num, size, rank) or (lora_num, 1, size, rank).
        output_tensor (torch.Tensor): contiguous output tensor, written in
            place by the kernel.
        b_seq_start_loc (torch.Tensor): (batch_size,). Start offset of each
            sequence in the token dimension. E.g., if the sequence lengths
            are [4, 6], it is [0, 4].
        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
            length of the sequences in the batch.
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch. An index of -1 means no lora should
            be applied.
        batches (int): batch size
        max_seq_length (int): The max sequence length of the sequences in the
            batch; sizes the row-tile dimension of the launch grid.
        token_nums (int): The token numbers in the batch. Used to verify if
            the token numbers in the inputs matches the one in the metadata.
        add_inputs (bool, optional): Defaults to False. When True the kernel
            adds its result to the existing contents of `output_tensor`.
    """
    half_dtypes = (torch.float16, torch.bfloat16)

    assert inputs.dtype in half_dtypes + (torch.float32, )
    assert lora_b_weights.dtype in half_dtypes
    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_b_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    weights = lora_b_weights
    if weights.ndim != 4:
        assert weights.ndim == 3  # shape:(lora_num,size,rank)
    else:  # shape:(lora_num,1,size,rank)
        assert weights.size(1) == 1
        weights = weights.squeeze(dim=1)

    assert weights.is_contiguous()

    # TODO tuning this config

    N, K = weights.shape[-2:]  # K= rank,N=hidden_size
    BLOCK_M = 32
    BLOCK_N = 32
    BLOCK_K = 16
    EVEN_K = K % BLOCK_K == 0
    ADD_INPUTS = add_inputs
    # fp32 inputs are cast to the (half precision) weight dtype in the kernel.
    CAST_TYPE = inputs.dtype == torch.float32 and weights.dtype in half_dtypes
    # Axis 0: (row-tile, column-tile) pairs; axis 1: sequences.
    grid = (
        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        batches,
    )
    _sgmv_expand_kernel[grid](
        inputs,
        weights,
        output_tensor,
        N,
        K,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        weights.stride(0),
        weights.stride(1),
        weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        ADD_INPUTS,
        CAST_TYPE,
    )
|
||||
|
||||
|
||||
# Expose the op through torch.library.custom_op, declaring that it mutates
# `output_tensor`. If the attribute lookup fails (presumably a PyTorch build
# without `torch.library.custom_op`), export the plain Python function.
try:
    sgmv_expand = torch.library.custom_op(
        "lora::sgmv_expand",
        _sgmv_expand,
        mutates_args=["output_tensor"],
    )
except AttributeError:
    sgmv_expand = _sgmv_expand
|
||||
214
vllm-v0.6.2/vllm/lora/ops/sgmv_expand_slice.py
Normal file
214
vllm-v0.6.2/vllm/lora/ops/sgmv_expand_slice.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def _sgmv_expand_slice_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    b_seq_start_loc,
    seq_lens,
    lora_indices,
    xm_stride,
    xk_stride,  # 1
    l0_stride,  # hidden_size*max_rank
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    slice_offset,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    ADD_INPUTS: tl.constexpr,
    CAST_TYPE: tl.constexpr,
):
    """

    Similar to the 'sgmv_expand' operator, but with an added parameter
    'slice_offset'. The reason for not reusing the 'sgmv_expand' operator
    might be that in the future, we could implement a fusion operator to
    achieve the current functionality instead of having to call it multiple
    times.

    Launch layout: program axis 0 enumerates the (row-tile, column-tile) pairs
    of one sequence's output, program axis 1 enumerates the sequences. Each
    program computes one [BLOCK_M, BLOCK_N] tile and writes it to output
    columns shifted by `slice_offset`, i.e. into columns
    [slice_offset, slice_offset + N).

    Notes:
        * Programs whose row tile starts beyond the sequence length, or whose
          LoRA index is -1, exit without writing.
        * `lora_k_stride` is applied to the N (output column) offsets and
          `lora_n_stride` to the K (reduction) offsets.
        * EVEN_K skips the K-bounds masks on the loads; CAST_TYPE casts the
          input tile to the LoRA weight element type before the dot product;
          ADD_INPUTS adds the result to the existing output.
    """
    pid = tl.program_id(axis=0)
    cur_batch = tl.program_id(axis=1)
    # Split the flat program id into (row-tile, column-tile).
    cta_n_num = tl.cdiv(N, BLOCK_N)
    pid_m = pid // cta_n_num
    pid_n = pid % cta_n_num
    M = tl.load(seq_lens + cur_batch)
    if pid_m * BLOCK_M > M:
        return
    lora_index = tl.load(lora_indices + cur_batch)
    if lora_index == -1:
        return
    cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
    offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
    offset_k = tl.arange(0, BLOCK_K)
    # Wrap row/column offsets with a modulo so the unmasked loads stay inside
    # the sequence / weight slab; surplus rows and columns are masked at the
    # final store.
    ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
    rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)

    # NOTE: no trailing comma inside the parentheses. The pointer must be a
    # tensor expression (as in the sgmv shrink kernel), not a 1-tuple.
    a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
             offset_k[None, :] * xk_stride)
    b_ptr = (lora_ptr + l0_stride * lora_index +
             offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(tl.cdiv(K, BLOCK_K)):
        if EVEN_K:
            tiled_a = tl.load(a_ptr)
            tiled_b = tl.load(b_ptr)
        else:
            # Mask the tail of K on the last iteration.
            tiled_a = tl.load(a_ptr,
                              mask=offset_k[None, :] < K - k * BLOCK_K,
                              other=0)
            tiled_b = tl.load(b_ptr,
                              mask=offset_k[:, None] < K - k * BLOCK_K,
                              other=0)
        if CAST_TYPE:
            tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
        accumulator += tl.dot(
            tiled_a,
            tiled_b,
        )
        a_ptr += BLOCK_K * xk_stride
        b_ptr += BLOCK_K * lora_n_stride
    tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
    offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
    # Output columns are shifted by slice_offset.
    offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset
    c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
             offset_cn[None, :] * cn_stride)
    # Same value as the M loaded above; kept as in the original kernel.
    M = tl.load(seq_lens + cur_batch)
    # Only rows of this sequence and columns inside the slice are written.
    c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
                                                           (slice_offset + N))
    if ADD_INPUTS:
        tiled_out = tl.load(c_ptr, mask=c_mask)
        tiled_c += tiled_out
    tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
def _sgmv_expand_slice(
    inputs: torch.Tensor,
    lora_b_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    slice_offset: int,
    slice_size: int,
    add_inputs: bool = False,
) -> None:
    """
    Validate the arguments and launch `_sgmv_expand_slice_kernel`.

    Args:
        inputs (torch.Tensor): contiguous fp16/bf16/fp32 input with
            `token_nums` rows; dimension 1 must equal the last dimension of
            `lora_b_weights`.
        lora_b_weights (torch.Tensor): contiguous fp16/bf16 LoRA B weights,
            shaped (lora_num, size, rank) or (lora_num, 1, size, rank).
        output_tensor (torch.Tensor): contiguous output tensor, written in
            place by the kernel.
        b_seq_start_loc (torch.Tensor): (batch_size,). Start offset of each
            sequence in the token dimension. E.g., if the sequence lengths
            are [4, 6], it is [0, 4].
        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
            length of the sequences in the batch
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch. An index of -1 means no lora should
            be applied.
        batches (int): batch size
        max_seq_length (int): The max sequence length of the sequences in the
            batch; sizes the row-tile dimension of the launch grid.
        token_nums (int): The token numbers in the batch. Used to verify if
            the token numbers in the inputs matches the one in the metadata.
        slice_offset (int): column offset into `output_tensor` at which the
            slice is written.
        slice_size (int): width of the slice; must equal the second-to-last
            dimension of `lora_b_weights`.
        add_inputs (bool, optional): Defaults to False. When True the kernel
            adds its result to the existing contents of `output_tensor`.
    """
    half_dtypes = (torch.float16, torch.bfloat16)

    assert inputs.dtype in half_dtypes + (torch.float32, )
    assert lora_b_weights.dtype in half_dtypes
    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_b_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
    assert slice_size == lora_b_weights.size(-2)
    assert inputs.is_contiguous()
    assert output_tensor.is_contiguous()

    weights = lora_b_weights
    if weights.ndim != 4:
        assert weights.ndim == 3  # shape:(lora_num,size,rank)
    else:  # shape:(lora_num,1,size,rank)
        assert weights.size(1) == 1
        weights = weights.squeeze(dim=1)

    assert weights.is_contiguous()

    # TODO tuning this config
    N, K = weights.shape[-2:]  # K= rank,N=hidden_size

    BLOCK_M = 32
    BLOCK_N = 32
    BLOCK_K = 16
    EVEN_K = K % BLOCK_K == 0
    ADD_INPUTS = add_inputs
    # fp32 inputs are cast to the (half precision) weight dtype in the kernel.
    CAST_TYPE = inputs.dtype == torch.float32 and weights.dtype in half_dtypes
    # Axis 0: (row-tile, column-tile) pairs; axis 1: sequences.
    grid = (
        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        batches,
    )
    _sgmv_expand_slice_kernel[grid](
        inputs,
        weights,
        output_tensor,
        N,
        K,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        inputs.stride(0),
        inputs.stride(1),
        weights.stride(0),
        weights.stride(1),
        weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        slice_offset,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        ADD_INPUTS,
        CAST_TYPE,
    )
|
||||
|
||||
|
||||
# Expose the op through torch.library.custom_op, declaring that it mutates
# `output_tensor`. If the attribute lookup fails (presumably a PyTorch build
# without `torch.library.custom_op`), export the plain Python function.
try:
    sgmv_expand_slice = torch.library.custom_op(
        "lora::sgmv_expand_slice",
        _sgmv_expand_slice,
        mutates_args=["output_tensor"],
    )
except AttributeError:
    sgmv_expand_slice = _sgmv_expand_slice
|
||||
198
vllm-v0.6.2/vllm/lora/ops/sgmv_shrink.py
Normal file
198
vllm-v0.6.2/vllm/lora/ops/sgmv_shrink.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
def _sgmv_shrink_kernel(
    input_ptr,
    lora_ptr,
    out_ptr,
    N,
    K,
    b_seq_start_loc,
    seq_lens,
    lora_indices,
    scaling,
    xm_stride,  # hidden_size
    xk_stride,  # 1
    l0_stride,  # hidden_size*max_rank
    lora_k_stride,
    lora_n_stride,
    cm_stride,
    cn_stride,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_K: tl.constexpr,
):
    """
    The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
    The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
    introducing SPLIT-K can improve performance

    Launch layout: program axis 0 enumerates the (row-tile, column-tile) pairs
    of one sequence's output, axis 1 is the split-K slice and axis 2 the
    sequence. Each program accumulates, in float32, the partial product of
    its [BLOCK_M, BLOCK_N] tile over the K positions owned by its slice,
    scales it by `scaling` and writes it out: a plain store when
    SPLIT_K == 1, `tl.atomic_add` otherwise.

    Programs whose row tile starts beyond the sequence length, or whose LoRA
    index is -1, exit without writing. `lora_k_stride` is applied to the N
    offsets and `lora_n_stride` to the K offsets.
    """
    tile_id = tl.program_id(axis=0)
    split_id = tl.program_id(axis=1)
    seq_id = tl.program_id(axis=2)
    # Split the flat tile id into (row-tile, column-tile).
    n_tiles = tl.cdiv(N, BLOCK_N)
    tile_m = tile_id // n_tiles
    tile_n = tile_id % n_tiles

    M = tl.load(seq_lens + seq_id)
    if tile_m * BLOCK_M > M:
        return
    lora_index = tl.load(lora_indices + seq_id)
    if lora_index == -1:
        return
    seq_start = tl.load(b_seq_start_loc + seq_id)
    m_idx = tl.arange(0, BLOCK_M) + tile_m * BLOCK_M
    n_idx = tl.arange(0, BLOCK_N) + tile_n * BLOCK_N
    k_idx = split_id * BLOCK_K + tl.arange(0, BLOCK_K)

    # Wrap row/column offsets with a modulo so the unmasked loads stay inside
    # the sequence / weight slab; surplus rows and columns are masked at the
    # final write.
    rows_a = tl.max_contiguous(tl.multiple_of(m_idx % M, BLOCK_M), BLOCK_M)
    cols_b = tl.max_contiguous(tl.multiple_of(n_idx % N, BLOCK_N), BLOCK_N)

    x_ptr = (input_ptr + seq_start * xm_stride + rows_a[:, None] * xm_stride +
             k_idx[None, :] * xk_stride)
    w_ptr = (lora_ptr + l0_stride * lora_index +
             cols_b[None, :] * lora_k_stride + k_idx[:, None] * lora_n_stride)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            x_tile = tl.load(x_ptr)
            w_tile = tl.load(w_ptr)
        else:
            # Mask the tail of K on the last iteration.
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            x_tile = tl.load(x_ptr,
                             mask=k_idx[None, :] < k_remaining,
                             other=0.0)
            w_tile = tl.load(w_ptr,
                             mask=k_idx[:, None] < k_remaining,
                             other=0.0)
        acc += tl.dot(x_tile, w_tile)

        # Advance past the K positions handled by all slices in this step.
        x_ptr += BLOCK_K * SPLIT_K * xk_stride
        w_ptr += BLOCK_K * SPLIT_K * lora_n_stride
    out_m = seq_start + tl.arange(0, BLOCK_M) + tile_m * BLOCK_M

    out_n = tl.arange(0, BLOCK_N) + tile_n * BLOCK_N
    out_tile_ptr = (out_ptr + out_m[:, None] * cm_stride +
                    out_n[None, :] * cn_stride)
    # Only rows of this sequence and columns below N are written.
    out_mask = (out_m[:, None] < (seq_start + M)) & (out_n[None, :] < N)
    acc *= scaling
    # handles write-back with reduction-splitting
    if SPLIT_K == 1:
        tl.store(out_tile_ptr, acc, mask=out_mask)
    else:
        tl.atomic_add(out_tile_ptr, acc, mask=out_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
def _sgmv_shrink(
    inputs: torch.Tensor,
    lora_a_weights: torch.Tensor,
    output_tensor: torch.Tensor,
    b_seq_start_loc: torch.Tensor,
    seq_len_tensor: torch.Tensor,
    lora_indices_tensor: torch.Tensor,
    batches: int,
    max_seq_length: int,
    token_nums: int,
    scaling: float,
) -> None:
    """
    Validate the arguments and launch `_sgmv_shrink_kernel`.

    Args:
        inputs (torch.Tensor): contiguous fp16/bf16 input with `token_nums`
            rows; dimension 1 must equal the last dimension of
            `lora_a_weights`.
        lora_a_weights (torch.Tensor): contiguous LoRA A weights with the same
            dtype as `inputs`, shaped (lora_num, rank, size) or
            (lora_num, 1, rank, size).
        output_tensor (torch.Tensor): contiguous output tensor, written in
            place by the kernel. The launch uses SPLIT_K = 8, for which the
            kernel combines partial results with atomic adds.
        b_seq_start_loc (torch.Tensor): (batch_size,). Start offset of each
            sequence in the token dimension. E.g., if the sequence lengths
            are [4, 6], it is [0, 4].
        seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence
            length of the sequences in the batch.
        lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
            corresponding to each batch. An index of -1 means no lora should
            be applied.
        batches (int): batch size
        max_seq_length (int): The max sequence length of the sequences in the
            batch; sizes the row-tile dimension of the launch grid.
        token_nums (int): The token numbers in the batch. Used to verify if
            the token numbers in the inputs matches the one in the metadata.
        scaling (float): Scaling factor.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert inputs.dtype == lora_a_weights.dtype
    assert inputs.dtype in half_dtypes
    assert lora_a_weights.dtype in half_dtypes
    assert inputs.size(0) == token_nums
    assert inputs.size(1) == lora_a_weights.size(-1)
    assert b_seq_start_loc.size(0) == batches
    assert lora_indices_tensor.size(0) == batches
    assert inputs.is_contiguous()

    weights = lora_a_weights
    if weights.ndim != 4:
        assert weights.ndim == 3  # shape:(lora_num,rank, size)
    else:  # shape:(lora_num,1,rank, size)
        assert weights.size(1) == 1
        weights = weights.squeeze(dim=1)
    assert weights.is_contiguous()
    assert output_tensor.is_contiguous()

    # TODO tuning this config
    N, K = weights.shape[-2:]  # K=hidden_size,N=rank
    BLOCK_M = 32
    BLOCK_N = 16
    BLOCK_K = 32
    SPLIT_K = 8
    EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
    # Axis 0: (row-tile, column-tile) pairs; axis 1: split-K slices;
    # axis 2: sequences.
    grid = (
        triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
        SPLIT_K,
        batches,
    )

    _sgmv_shrink_kernel[grid](
        inputs,
        weights,
        output_tensor,
        N,
        K,
        b_seq_start_loc,
        seq_len_tensor,
        lora_indices_tensor,
        scaling,
        inputs.stride(0),
        inputs.stride(1),
        weights.stride(0),
        weights.stride(1),
        weights.stride(2),
        output_tensor.stride(0),
        output_tensor.stride(1),
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        EVEN_K,
        SPLIT_K,
    )
|
||||
|
||||
|
||||
# Expose the op through torch.library.custom_op, declaring that it mutates
# `output_tensor`. If the attribute lookup fails (presumably a PyTorch build
# without `torch.library.custom_op`), export the plain Python function.
try:
    sgmv_shrink = torch.library.custom_op(
        "lora::sgmv_shrink",
        _sgmv_shrink,
        mutates_args=["output_tensor"],
    )
except AttributeError:
    sgmv_shrink = _sgmv_shrink
|
||||
46
vllm-v0.6.2/vllm/lora/ops/utils.py
Normal file
46
vllm-v0.6.2/vllm/lora/ops/utils.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import functools
|
||||
from typing import Dict
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def _get_op_configs(op_type: str, batch: int, hidden_size: int):
|
||||
# TODO: add optimal configurations
|
||||
return None
|
||||
|
||||
|
||||
def _check_divisibility(hidden_size: int):
|
||||
# The bgmv_expand kernel requires that the hidden_size be divisible by
|
||||
# the number below.
|
||||
divisibility = [2, 4, 8, 16, 32, 64]
|
||||
divisibility.sort(reverse=True)
|
||||
for div in divisibility:
|
||||
if hidden_size % div == 0:
|
||||
return div
|
||||
# hidden_size is an odd number
|
||||
return 1
|
||||
|
||||
|
||||
def _get_default_config(op_type: str, batch: int, hidden_size: int):
|
||||
if op_type == "expand":
|
||||
return {
|
||||
"BLOCK_N": 256,
|
||||
"SPLIT_N": _check_divisibility(hidden_size),
|
||||
"num_warps": 8
|
||||
}
|
||||
else:
|
||||
return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8}
|
||||
|
||||
|
||||
def get_lora_op_configs(op_type: str, batch: int,
                        hidden_size: int) -> Dict[str, int]:
    """Inspired by `fused_moe_kernel`
    The return value will be a dictionary mapping an irregular grid of batch
    sizes and hidden_size to configurations of the bgmv-related kernel.
    NOTE: It currently only supports the default configuration. We plan to
    generate optimal configurations for different hardware in the future using
    scripts similar to `benchmark_moe.py`.
    """
    # Prefer a tuned configuration; fall back to the default when the lookup
    # yields nothing (None or an empty dict).
    tuned = _get_op_configs(op_type, batch, hidden_size)
    return tuned or _get_default_config(op_type, batch, hidden_size)
|
||||
611
vllm-v0.6.2/vllm/lora/punica.py
Normal file
611
vllm-v0.6.2/vllm/lora/punica.py
Normal file
@@ -0,0 +1,611 @@
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.lora.ops.bgmv_expand import bgmv_expand
|
||||
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
|
||||
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
|
||||
from vllm.lora.ops.sgmv_expand import sgmv_expand
|
||||
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
|
||||
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# avoid circuit import
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.models import LongContextLoRAContext
|
||||
|
||||
|
||||
def compute_meta(
    token_lora_tensor: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]:
    """
    Get the information required for the sgmv kernel. With the features:
    1. If consecutive requests in the batch use the same LoRA, this function
    will combine them into a single request, improving sgmv kernel inference
    performance.
    2. At the beginning of each prefill stage inference, recalculations are
    needed based on the input, but only once.

    Returns:
        A tuple (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
        batch_size, max_length, token_nums, no_lora): the start offset,
        length and LoRA index of every run of equal consecutive values in
        `token_lora_tensor`, the number of runs, the longest run, the total
        token count, and whether the whole input is a single run of -1.
    """
    lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
        token_lora_tensor, return_counts=True)
    # Exclusive prefix sum of the run lengths: each run starts where the
    # preceding runs end.
    b_seq_start_tensor = (torch.cumsum(seq_length_tensor, dim=0) -
                          seq_length_tensor)
    max_length = seq_length_tensor.max().item()
    token_nums = seq_length_tensor.sum().item()
    batch_size = lora_indices_tensor.size(0)
    # -1 means no lora should be applied. Use `no_lora` to determine whether
    # the current step requires LoRA. If LoRA is not needed, the prefill stage
    # does not need to launch the triton kernel, which can improve performance
    no_lora = batch_size == 1 and bool(lora_indices_tensor[0] == -1)
    return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
            batch_size, max_length, token_nums, no_lora)
|
||||
|
||||
|
||||
# TODO see if this can be vectorized
|
||||
def convert_mapping(
    mapping: "LoRAMapping",
    lora_index_to_id: List[Optional[int]],
    max_loras: int,
    vocab_size: int,
    extra_vocab_size: int,
    device: torch.device,
    long_lora_context: Optional["LongContextLoRAContext"] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
           Optional[torch.Tensor], List[int]]:
    """Converts LoRAMapping to index tensors.

    Args:
        mapping: LoRAMapping mapping rows in a batch to LoRA ids. Its
            `index_mapping` and `prompt_mapping` attributes are read.
        lora_index_to_id: List whose position is a LoRA index (slot) and whose
            value is the LoRA id stored there. Ids greater than 0 are
            translated to the first position holding them; any other id maps
            to -1.
        max_loras: Maximum number of LoRAs.
        vocab_size: Model vocab size.
        extra_vocab_size: Extra vocab size each LoRA can have.
        device: Device on which the returned tensors are created.
        long_lora_context: Passed if there are long context lora in a batch.

    Returns:
        A tuple of tensors:
            base_indices: Tensor of shape [batch_size] mapping batch rows to
                LoRA indices (-1 for rows without a LoRA).
            sampler_indices: Tensor of shape [batch_size] mapping requests to
                LoRA indices for sampler, built from `prompt_mapping`.
            sampler_indices_padded: Same as sampler_indices, but -1 is
                replaced with max_loras - 1 and every entry i is then turned
                into the flat index `i + lora_index * batch_size`.
            embeddings_indices: Tensor of shape [2, batch_size]. With e the
                LoRA index of a row (0 for rows without a LoRA), the first
                row is `e * extra_vocab_size` and the second row is
                `e * (vocab_size + extra_vocab_size)`.
            long_lora_indices: Tensor of shape [batch_size] holding, per row,
                the offset registered for its LoRA id in
                `long_lora_context.offsets_by_lora_id` (0 when absent).
                None if long context lora doesn't exist.
            indices_len: List of lengths of the above tensors. It contains
                (base_indices, sampler_indices, sampler_indices_padded,
                embeddings_indices, long_lora_indices); the last entry is
                None when long_lora_indices is None.

    Raises:
        ValueError: If a LoRA id greater than 0 is not present in
            `lora_index_to_id`.
    """
    # Build the id -> index table once instead of scanning the list for
    # every row. setdefault keeps the first position of a repeated value,
    # which is what list.index would return.
    index_of_id = {}
    for slot, lora_id in enumerate(lora_index_to_id):
        index_of_id.setdefault(lora_id, slot)

    def _to_index(lora_id: int) -> int:
        # Ids <= 0 mean "no LoRA".
        if lora_id > 0:
            try:
                return index_of_id[lora_id]
            except KeyError:
                # Same exception type list.index raises for a missing id.
                raise ValueError(f"{lora_id} is not in list") from None
        return -1

    index_mapping_indices: List[int] = list(mapping.index_mapping)
    prompt_mapping: List[int] = [
        _to_index(lora_id) for lora_id in mapping.prompt_mapping
    ]
    lora_indices: List[int] = [
        _to_index(lora_id) for lora_id in index_mapping_indices
    ]
    # Rows without a LoRA use embedding slot 0 rather than -1.
    embedding_indices: List[int] = [
        lora_idx if lora_id > 0 else 0
        for lora_id, lora_idx in zip(index_mapping_indices, lora_indices)
    ]

    indices_list: List[List[int]] = [
        index_mapping_indices,
        lora_indices,
        embedding_indices,
    ]
    if long_lora_context:
        # Collect the offsets as plain ints; they become row 3 of `indices`
        # below, avoiding one tensor write per row.
        offsets_by_lora_id = long_lora_context.offsets_by_lora_id
        indices_list.append([
            offsets_by_lora_id.get(lora_id, 0)
            for lora_id in index_mapping_indices
        ])
    indices = torch.tensor(indices_list, dtype=torch.long, device=device)
    prompt_mapping_tensor = torch.tensor(prompt_mapping,
                                         dtype=torch.long,
                                         device=device)
    embeddings_indices = torch.stack([
        indices[2] * extra_vocab_size,
        indices[2] * (vocab_size + extra_vocab_size),
    ])
    embeddings_indices[embeddings_indices == -1] = max_loras - 1
    base_indices = indices[1]
    sampler_indices = prompt_mapping_tensor
    sampler_indices_padded = sampler_indices.clone()
    sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
    # Flat index: position + lora_index * number_of_requests.
    sampler_indices_padded = torch.arange(
        0, len(sampler_indices_padded), device=device, dtype=torch.long) + (
            sampler_indices_padded * len(sampler_indices_padded))
    long_lora_indices = None
    long_lora_indices_len: Optional[int] = None
    if long_lora_context:
        long_lora_indices = indices[3]
        long_lora_indices_len = long_lora_indices.shape[-1]
    # Contain length of indices tensors. Used to index into each tensor.
    # If long_lora doesn't exist, the last entry is None.
    indices_len = [
        base_indices.shape[-1],
        sampler_indices.shape[-1],
        sampler_indices_padded.shape[-1],
        embeddings_indices.shape[-1],
        long_lora_indices_len,
    ]

    return (
        base_indices,
        sampler_indices,
        sampler_indices_padded,
        embeddings_indices,
        long_lora_indices,
        indices_len,
    )
|
||||
|
||||
|
||||
class PunicaWrapper:
    """
    Keeps the Multi-LoRA index state that the punica kernels need and offers
    thin wrappers that forward that state to the kernels.

    The index buffers are allocated once in `__init__`; `update_metadata`
    refills them, and each property returns the leading part of its buffer
    whose length was recorded in `indices_len` by the last update.
    """

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: Union[torch.device, str]):

        def _long_buffer(*shape: int) -> torch.Tensor:
            # Uninitialised int64 storage on `device`.
            return torch.empty(*shape, dtype=torch.long, device=device)

        # Buffers sized by the maximum number of batched tokens.
        self._token_lora_indices = _long_buffer(max_num_batched_tokens)
        self._sampler_indices = _long_buffer(max_num_batched_tokens)
        self._sampler_indices_padded = _long_buffer(max_num_batched_tokens)
        self._embeddings_indices = _long_buffer(2, max_num_batched_tokens)
        self._long_lora_indices = _long_buffer(max_num_batched_tokens)

        # One length slot per index tensor, in this order: base_indices,
        # sampler_indices, sampler_indices_padded, embeddings_indices,
        # long_lora_indices.
        self.indices_len: List[Optional[int]] = [None] * 5

        # Buffers holding the information required by the sgmv kernels.
        self._seq_start_locs = _long_buffer(max_batches)
        self._seq_lengths = _long_buffer(max_batches)
        self._lora_indices_per_batch = _long_buffer(max_batches)

        self.device: torch.device = device
        self.max_length: int = 0
        self.token_nums: int = 0
        self.batch_size: int = -1
        self.is_prefill = False
        self.no_lora = False

    def update_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):
        """
        Refresh the index buffers from `mapping` and, when the mapping is a
        prefill one, also refresh the prefill (sgmv) metadata.
        """
        self._update_base_metadata(mapping, lora_index_to_id, max_loras,
                                   vocab_size, extra_vocab_size,
                                   long_lora_context)
        if not mapping.is_prefill:
            self.is_prefill = False
            return
        # Prefill operators need the per-sequence metadata as well.
        self._update_prefill_metada(self.token_lora_indices)
        self.is_prefill = True

    def _update_base_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):
        """
        Run `convert_mapping` and copy each resulting tensor into the front
        of its preallocated buffer, then record the lengths in `indices_len`.
        """
        (
            token_indices,
            sampler,
            sampler_padded,
            embeddings,
            long_lora_offsets,
            lengths,
        ) = convert_mapping(
            mapping,
            lora_index_to_id,
            max_loras,
            vocab_size,
            extra_vocab_size,
            self.device,
            long_lora_context,
        )

        self._token_lora_indices[:token_indices.shape[0]].copy_(token_indices)
        self._sampler_indices[:sampler.shape[0]].copy_(sampler)
        self._sampler_indices_padded[:sampler_padded.shape[0]].copy_(
            sampler_padded)

        rows, cols = embeddings.shape[0], embeddings.shape[1]
        self._embeddings_indices[:rows, :cols].copy_(embeddings)

        if long_lora_offsets is None:
            # No long-context LoRA in this batch: clear the whole buffer.
            self._long_lora_indices.zero_()
        else:
            self._long_lora_indices[:long_lora_offsets.shape[0]].copy_(
                long_lora_offsets)

        self.indices_len[:] = lengths

    def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:
        """
        Store the values returned by `compute_meta(token_lora_tensor)`: the
        three tensors go into their buffers, the scalars into attributes.
        """
        meta = compute_meta(token_lora_tensor)
        (starts, lengths, lora_ids, batch_size, max_length, token_nums,
         no_lora) = meta

        self._seq_start_locs[:starts.shape[0]].copy_(starts)
        self._seq_lengths[:lengths.shape[0]].copy_(lengths)
        self._lora_indices_per_batch[:lora_ids.shape[0]].copy_(lora_ids)

        self.batch_size = batch_size
        self.max_length = max_length
        self.token_nums = token_nums
        self.no_lora = no_lora

    @property
    def prefill_metadata(
        self
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int]:
        """
        Metadata needed by the prefill-related kernels, as a 6-tuple:
            1. seq_start_locs: Tensor of sequence start positions.
            2. seq_lengths: Tensor of sequence lengths.
            3. lora_indices_per_batch: Tensor of lora indices, and an index
                of -1 means no lora should be applied.
            4. batch_size: Batch size after clustering identical lora indices.
            5. max_length: The maximum sequence length in the batch.
            6. token_nums: The token numbers in the batch.
        """
        n = self.batch_size
        return (
            self._seq_start_locs[:n],
            self._seq_lengths[:n],
            self._lora_indices_per_batch[:n],
            n,
            self.max_length,
            self.token_nums,
        )

    @property
    def token_lora_indices(self) -> torch.Tensor:
        """
        The lora index of each token in the batch. An index of -1 means no
        lora should be applied.
        """
        return self._token_lora_indices[:self.indices_len[0]]

    @property
    def sampler_indices(self) -> torch.Tensor:
        """
        The lora indices used specifically by LogitsProcessorWithLoRA.
        """
        return self._sampler_indices[:self.indices_len[1]]

    @property
    def sampler_indices_padded(self) -> torch.Tensor:
        """
        The padded sampler indices.
        """
        return self._sampler_indices_padded[:self.indices_len[2]]

    @property
    def embeddings_indices(self) -> torch.Tensor:
        """
        The indices used for lora embeddings, specifically by
        VocabParallelEmbeddingWithLoRA.
        """
        return self._embeddings_indices[:, :self.indices_len[3]]

    @property
    def long_lora_indices(self) -> torch.Tensor:
        """
        The indices used for long context lora, specifically by
        LinearScalingRotaryEmbeddingWithLora.
        """
        return self._long_lora_indices[:self.indices_len[4]]

    def shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """Call `sgmv_shrink` unless the batch has no LoRA request."""
        if not self.no_lora:
            sgmv_shrink(x, w_t_all, y, *self.prefill_metadata, scale)

    def shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """Call `bgmv_shrink` with the per-token lora indices."""
        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)

    def expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        """Call `sgmv_expand` unless the batch has no LoRA request."""
        if not self.no_lora:
            sgmv_expand(x, w_t_all, y, *self.prefill_metadata, add_input)

    def expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        """Call `bgmv_expand` with the per-token lora indices."""
        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)

    def expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        """Call `sgmv_expand_slice` unless the batch has no LoRA request."""
        if not self.no_lora:
            sgmv_expand_slice(x, w_t_all, y, *self.prefill_metadata, y_offset,
                              y_slice_size, add_input)

    def expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        """Call `bgmv_expand_slice` with the per-token lora indices."""
        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                          y_slice_size, add_input)

    def add_shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """
        Perform the `y+=x@w_t_all` computation, which is suitable for the
        GEMM of lora'a.
        Dispatches to `shrink_prefill` when `is_prefill` is true (prefill
        stage) and to `shrink_decode` otherwise (decode stage).
        """
        if self.is_prefill:
            self.shrink_prefill(y, x, w_t_all, scale)
        else:
            self.shrink_decode(y, x, w_t_all, scale)

    def add_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool = True,
    ):
        """
        Perform the `y+=x@w_t_all` computation, which is suitable for the
        GEMM of lora'b.
        Dispatches to `expand_prefill` when `is_prefill` is true (prefill
        stage) and to `expand_decode` otherwise (decode stage).
        """
        if self.is_prefill:
            self.expand_prefill(y, x, w_t_all, add_input)
        else:
            self.expand_decode(y, x, w_t_all, add_input)

    def add_expand_slice(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
                         w_t_all: torch.Tensor,
                         y_offset: Optional[int],
                         y_slice_size: Optional[int],
                         add_input: bool = True):
        """
        Similar to `add_expand`, dispatching to the `expand_slice_*` pair.
        """
        if self.is_prefill:
            self.expand_slice_prefill(y, x, w_t_all, y_offset, y_slice_size,
                                      add_input)
        else:
            self.expand_slice_decode(y, x, w_t_all, y_offset, y_slice_size,
                                     add_input)

    def add_lora(self,
                 y: torch.Tensor,
                 x: torch.Tensor,
                 wa_t_all: torch.Tensor,
                 wb_t_all: torch.Tensor,
                 scale: float,
                 y_offset: Optional[int] = None,
                 y_slice_size: Optional[int] = None,
                 *,
                 buffer: Optional[torch.Tensor] = None) -> None:
        """
        Semantics:
        y[i] += (
            x[i].unsqueeze(0)
            @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
            @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
            * scale
            ).squeeze(0)
        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor
            wa_t_all (torch.Tensor): lora_a's weight
            wb_t_all (torch.Tensor): lora_b's weight
            scale (float): Scaling factor.
            y_offset (Optional[int], optional): Offset to apply to the starting
                column of y.
            y_slice_size (Optional[int], optional): Size of the y column slice.
            buffer (Optional[torch.Tensor], optional): Defaults to None.
        """
        # 2-D views; results are written through y_2d into the caller's y.
        y_2d = y.view(-1, y.shape[-1])
        x_2d = x.view(-1, x.shape[-1])
        rank = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x_2d.size(0), rank),
                                 dtype=torch.float32,
                                 device=x_2d.device)

        self.add_shrink(buffer, x_2d, wa_t_all, scale)

        whole_output = y_offset is None and y_slice_size is None
        if whole_output:
            self.add_expand(y_2d, buffer, wb_t_all, add_input=True)
        else:
            self.add_expand_slice(y_2d,
                                  buffer,
                                  wb_t_all,
                                  y_offset,
                                  y_slice_size,
                                  add_input=True)

    def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
                               lora_a_stacked: Tuple[torch.Tensor,
                                                     torch.Tensor,
                                                     torch.Tensor],
                               lora_b_stacked: Tuple[torch.Tensor,
                                                     torch.Tensor,
                                                     torch.Tensor],
                               scale: float,
                               output_slices: Tuple[int, ...]) -> None:
        """
        Applies lora to each input. Similar to add_lora, this method is
        used for layers that are composed of multiple sublayers
        (slices) packed together.
        """
        x_2d = x.view(-1, x.shape[-1])
        y_2d = y.view(-1, y.shape[-1])
        column_offset = 0
        # TODO fuse these kernels
        for slice_idx, slice_size in enumerate(output_slices):
            self.add_lora(y_2d, x_2d, lora_a_stacked[slice_idx],
                          lora_b_stacked[slice_idx], scale, column_offset,
                          slice_size)
            column_offset += slice_size

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        wa_t_all: torch.Tensor,
                        wb_t_all: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None) -> None:
        """
        LogitsProcessorWithLoRA always using bgmv
        """
        # 2-D views; results are written through y_2d into the caller's y.
        y_2d = y.view(-1, y.shape[-1])
        x_2d = x.view(-1, x.shape[-1])
        rank = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x_2d.size(0), rank),
                                 dtype=torch.float32,
                                 device=x_2d.device)

        bgmv_shrink(x_2d, wa_t_all, buffer, self.sampler_indices, scale)
        bgmv_expand(buffer,
                    wb_t_all,
                    y_2d,
                    self.sampler_indices,
                    add_inputs=True)
|
||||
95
vllm-v0.6.2/vllm/lora/request.py
Normal file
95
vllm-v0.6.2/vllm/lora/request.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import msgspec
|
||||
|
||||
from vllm.adapter_commons.request import AdapterRequest
|
||||
|
||||
|
||||
class LoRARequest(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """
    Request for a LoRA adapter.

    Note that this class should be used internally. For online
    serving, it is recommended to not allow users to use this class but
    instead provide another layer of abstraction to prevent users from
    accessing unauthorized LoRA adapters.

    lora_int_id must be globally unique for a given adapter.
    This is currently not enforced in vLLM.
    """
    __metaclass__ = AdapterRequest

    lora_name: str
    lora_int_id: int
    lora_path: str = ""
    # Deprecated alias of lora_path; only consulted in __post_init__.
    lora_local_path: Optional[str] = msgspec.field(default=None)
    long_lora_max_len: Optional[int] = None
    base_model_name: Optional[str] = msgspec.field(default=None)

    def __post_init__(self):
        """
        Fold the deprecated `lora_local_path` into `lora_path` and check
        that a path was supplied.

        The deprecation warning is emitted only when the caller actually
        passed `lora_local_path`. The field is always declared on the
        struct, so a membership test against `__struct_fields__` would
        warn on every construction.
        """
        if self.lora_local_path:
            warnings.warn(
                "The 'lora_local_path' attribute is deprecated "
                "and will be removed in a future version. "
                "Please use 'lora_path' instead.",
                DeprecationWarning,
                stacklevel=2)
            if not self.lora_path:
                self.lora_path = self.lora_local_path

        # Ensure lora_path is not empty
        assert self.lora_path, "lora_path cannot be empty"

    @property
    def adapter_id(self):
        """Alias for `lora_int_id`."""
        return self.lora_int_id

    @property
    def name(self):
        """Alias for `lora_name`."""
        return self.lora_name

    @property
    def path(self):
        """Alias for `lora_path`."""
        return self.lora_path

    @property
    def local_path(self):
        """Deprecated alias for `lora_path`; warns on every read."""
        warnings.warn(
            "The 'local_path' attribute is deprecated "
            "and will be removed in a future version. "
            "Please use 'path' instead.",
            DeprecationWarning,
            stacklevel=2)
        return self.lora_path

    @local_path.setter
    def local_path(self, value):
        """Deprecated setter; warns and stores `value` in `lora_path`."""
        warnings.warn(
            "The 'local_path' attribute is deprecated "
            "and will be removed in a future version. "
            "Please use 'path' instead.",
            DeprecationWarning,
            stacklevel=2)
        self.lora_path = value

    def __eq__(self, value: object) -> bool:
        """
        Overrides the equality method to compare LoRARequest
        instances based on lora_name. This allows for identification
        and comparison of lora adapters across engines.
        """
        return isinstance(value,
                          self.__class__) and self.lora_name == value.lora_name

    def __hash__(self) -> int:
        """
        Overrides the hash method to hash LoRARequest instances
        based on lora_name. This ensures that LoRARequest instances
        can be used in hash-based collections such as sets and dictionaries,
        identified by their names across engines.
        """
        return hash(self.lora_name)
|
||||
192
vllm-v0.6.2/vllm/lora/utils.py
Normal file
192
vllm-v0.6.2/vllm/lora/utils.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import os
|
||||
import re
|
||||
from typing import List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
import huggingface_hub
|
||||
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
|
||||
HFValidationError, RepositoryNotFoundError)
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.fully_sharded_layers import (
|
||||
ColumnParallelLinearWithShardedLoRA,
|
||||
MergedColumnParallelLinearWithShardedLoRA,
|
||||
MergedQKVParallelLinearWithShardedLora, QKVParallelLinearWithShardedLora,
|
||||
RowParallelLinearWithShardedLoRA)
|
||||
# being imported for _all_lora_classes below
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
|
||||
LinearScalingRotaryEmbeddingWithLora,
|
||||
LogitsProcessorWithLoRA,
|
||||
MergedColumnParallelLinearWithLoRA,
|
||||
MergedQKVParallelLinearWithLora,
|
||||
QKVParallelLinearWithLora,
|
||||
ReplicatedLinearWithLoRA,
|
||||
RowParallelLinearWithLoRA,
|
||||
VocabParallelEmbeddingWithLoRA)
|
||||
# yapf: enable
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Every LoRA layer wrapper class that from_layer() probes, via its
# can_replace_layer() hook, when deciding how to wrap a module.
_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = {
    VocabParallelEmbeddingWithLoRA,
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    QKVParallelLinearWithLora,
    MergedQKVParallelLinearWithLora,
    RowParallelLinearWithLoRA,
    ReplicatedLinearWithLoRA,
    LogitsProcessorWithLoRA,
    # Variants imported from vllm.lora.fully_sharded_layers.
    ColumnParallelLinearWithShardedLoRA,
    QKVParallelLinearWithShardedLora,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithShardedLora,
    RowParallelLinearWithShardedLoRA,
    LinearScalingRotaryEmbeddingWithLora,
}
|
||||
|
||||
|
||||
def from_layer(layer: nn.Module,
               max_loras: int,
               lora_config: LoRAConfig,
               packed_modules_list: List,
               model_config: Optional[PretrainedConfig] = None) -> nn.Module:
    """Wrap `layer` in the first registered LoRA class that accepts it.

    Each class in `_all_lora_classes` is asked through `can_replace_layer`
    whether it can replace `layer`. The first one that says yes is
    instantiated around the layer, has its LoRA weights created, and is
    returned. If no class accepts, `layer` is returned unchanged.
    """
    for candidate_cls in _all_lora_classes:
        # specifying kwargs so they can be easily accessed in decorator
        accepted = candidate_cls.can_replace_layer(
            source_layer=layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config)
        if not accepted:
            continue
        wrapped = candidate_cls(layer)
        wrapped.create_lora_weights(max_loras, lora_config, model_config)
        return wrapped
    return layer
|
||||
|
||||
|
||||
def from_layer_logits_processor(
    layer: LogitsProcessor,
    lm_head: ParallelLMHead,
    max_loras: int,
    lora_config: LoRAConfig,
    model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
    """Build a `LogitsProcessorWithLoRA` around `layer`.

    The wrapper is constructed from `lm_head`'s embedding dimension, weight
    dtype, weight device and sharded-to-full mapping. Its LoRA weights are
    created before it is returned.
    """
    hidden_size = lm_head.embedding_dim
    weight_dtype = lm_head.weight.dtype
    weight_device = lm_head.weight.device
    sharded_to_full = lm_head.get_sharded_to_full_mapping()

    wrapped = LogitsProcessorWithLoRA(layer, hidden_size, weight_dtype,
                                      weight_device, sharded_to_full)
    wrapped.create_lora_weights(max_loras, lora_config, model_config)
    return wrapped
|
||||
|
||||
|
||||
def replace_submodule(model: nn.Module, module_name: str,
                      new_module: nn.Module) -> nn.Module:
    """Swap the submodule at dotted path `module_name` for `new_module`.

    The parent is looked up with `model.get_submodule` and the last path
    component is rebound on it. Returns `new_module`.
    """
    # "a.b.c" -> parent path "a.b", attribute "c"; "x" -> parent path "".
    parent_path, _, attr_name = module_name.rpartition(".")
    owner = model.get_submodule(parent_path)
    setattr(owner, attr_name, new_module)
    return new_module
|
||||
|
||||
|
||||
def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool, bool]:
    """Parse the name of lora weights.

    The first two dot-separated components of `name` are dropped without
    being inspected, and the trailing LoRA-specific components are stripped.

    args:
        name: the name of the fine-tuned LoRA tensor, e.g.
            base_model.model.dense1.lora_A.weight
    return:
        Tuple(module_name, is_lora_a, is_bias):
            module_name: the name of the module, e.g. dense1,
            is_lora_a whether the tensor is lora_a or lora_b.
            is_bias whether the tensor is lora bias.
    raises:
        ValueError: if `name` does not end in one of the supported
            suffixes (`lora_A.weight`, `lora_B.weight`,
            `lora_embedding_A`, `lora_embedding_B`, `bias`).
    """
    parts = name.split(".")
    leaf = parts[-1]
    # A single-component name has no parent. Guarding here makes such
    # names raise the ValueError below instead of an IndexError.
    parent = parts[-2] if len(parts) > 1 else None

    if leaf == "weight" and parent in ("lora_A", "lora_B"):
        return ".".join(parts[2:-2]), parent == "lora_A", False

    if leaf in ("lora_embedding_A", "lora_embedding_B"):
        return ".".join(parts[2:-1]), leaf == "lora_embedding_A", False

    if leaf == "bias":
        return ".".join(parts[2:-2]), False, True

    raise ValueError(f"{name} is unsupported LoRA weight")
|
||||
|
||||
|
||||
def is_regex_target_modules(load_modules: Union[str, List[str]],
                            expected_lora_modules: List[str]) -> bool:
    """
    PEFT supports passing `target_modules` in the form of regular expressions,
    such as `model.*(q_proj|k_proj|v_proj)$`. This function reports whether
    every alternative in the trailing parenthesised group of such an
    expression is present in `expected_lora_modules`.

    Returns False when `load_modules` is not a `str`, is not a valid regular
    expression, or does not end in a parenthesised group (optionally
    followed by `$`).
    """
    # Similar to PEFT's processing logic, regex-related operations are only
    # executed when the load_modules is a `str`.
    if not isinstance(load_modules, str):
        return False

    try:
        re.compile(load_modules)
    except re.error:
        return False

    trailing_group = re.search(r"\((.*?)\)\$?$", load_modules)
    if trailing_group is None:
        return False

    alternatives = trailing_group.group(1).split("|")
    return set(alternatives) <= set(expected_lora_modules)
|
||||
|
||||
|
||||
def get_adapter_absolute_path(lora_path: str) -> str:
    """
    Resolve `lora_path` to a local path for the lora model.

    The checks run in this order:
      1. An absolute path is returned unchanged, whether or not it exists.
      2. A path starting with `~` is returned with the home directory
         expanded.
      3. A relative path that exists locally is returned as an absolute path.
      4. Anything else is treated as a Hugging Face repo id and downloaded;
         the local snapshot path is returned. If the download fails with one
         of the handled Hugging Face errors, the error is logged and the
         original `lora_path` is returned instead of raising.

    Parameters:
    lora_path (str): The path to the lora model, which can be an absolute path,
                     a relative path, or a Hugging Face model identifier.

    Returns:
    str: The resolved local path to the lora model.
    """
    if os.path.isabs(lora_path):
        return lora_path

    if lora_path.startswith('~'):
        return os.path.expanduser(lora_path)

    if os.path.exists(lora_path):
        return os.path.abspath(lora_path)

    # Not found locally, so assume it's a Hugging Face repo.
    hub_errors = (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
                  HFValidationError)
    try:
        return huggingface_hub.snapshot_download(repo_id=lora_path)
    except hub_errors:
        # Return the original path instead of throwing an error here.
        logger.exception("Error downloading the HuggingFace model")
        return lora_path
|
||||
214
vllm-v0.6.2/vllm/lora/worker_manager.py
Normal file
214
vllm-v0.6.2/vllm/lora/worker_manager.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.adapter_commons.utils import (add_adapter_worker,
|
||||
apply_adapters_worker,
|
||||
list_adapters_worker,
|
||||
set_active_adapters_worker)
|
||||
from vllm.adapter_commons.worker_manager import AbstractWorkerManager
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.models import (LoRAModel, LoRAModelManager,
|
||||
LRUCacheLoRAModelManager, create_lora_manager)
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.utils import get_adapter_absolute_path
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class WorkerLoRAManager(AbstractWorkerManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Every request, the requested LoRAs will be loaded (unless they are already
    loaded), and every other LoRA will be unloaded."""

    _manager_cls: Type[LoRAModelManager] = LoRAModelManager

    def __init__(
        self,
        max_num_seqs: int,
        max_num_batched_tokens: int,
        vocab_size: int,
        lora_config: LoRAConfig,
        device: torch.device,
        embedding_modules: Dict[str, str],
        embedding_padding_modules: List[str],
        lora_model_cls: Type[LoRAModel] = LoRAModel,
        max_position_embeddings: Optional[int] = None,
    ):
        self._lora_model_cls = lora_model_cls
        self.embedding_modules = embedding_modules
        self.embedding_padding_modules = embedding_padding_modules
        # False: caching disabled. None: caching enabled but nothing cached
        # yet. LoRAModel: the cached dummy lora (see add_dummy_lora).
        self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.vocab_size = vocab_size
        self.lora_config = lora_config
        self.max_position_embeddings = max_position_embeddings
        super().__init__(device)
        # Lazily initialized by create_lora_manager.
        self._adapter_manager: LoRAModelManager

    @contextmanager
    def dummy_lora_cache(self):
        """Use this context manager to reuse the dummy lora model
        to avoid creating it repeatedly.

        Caching is switched off again when the block exits, including when
        it exits with an exception."""
        self._cached_dummy_lora = None
        try:
            yield
        finally:
            self._cached_dummy_lora = False

    @property
    def is_enabled(self) -> bool:
        return True

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Create the adapter manager for `model`, keep it on this worker
        manager, and return the manager's `model` attribute."""
        lora_manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
        )
        self._adapter_manager = lora_manager
        return lora_manager.model

    def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
        """Load the LoRA checkpoint named by `lora_request` onto the CPU.

        Raises:
            RuntimeError: if resolving the path or loading the checkpoint
                fails; the original exception is chained as the cause.
            ValueError: if the loaded LoRA's rank or extra vocab size
                exceeds the limits in `lora_config`.
        """
        # Bound before the try block so the error message below can always
        # be formatted, even if the failure happens before path resolution.
        lora_path = lora_request.lora_path
        try:
            model = self._adapter_manager.model
            supported_lora_modules = model.supported_lora_modules
            packed_modules_mapping = model.packed_modules_mapping
            # Packed modules are expanded into the modules they pack.
            expected_lora_modules: List[str] = []
            for module in supported_lora_modules:
                if module in packed_modules_mapping:
                    expected_lora_modules.extend(
                        packed_modules_mapping[module])
                else:
                    expected_lora_modules.append(module)
            lora_path = get_adapter_absolute_path(lora_request.lora_path)
            lora = self._lora_model_cls.from_local_checkpoint(
                lora_path,
                expected_lora_modules,
                max_position_embeddings=self.max_position_embeddings,
                lora_model_id=lora_request.lora_int_id,
                device="cpu",
                dtype=self.lora_config.lora_dtype,
                target_embedding_padding=self.vocab_size +
                self.lora_config.lora_extra_vocab_size,
                embedding_modules=self.embedding_modules,
                embedding_padding_modules=self.embedding_padding_modules,
            )
        except Exception as e:
            raise RuntimeError(f"Loading lora {lora_path} failed") from e
        if lora.rank > self.lora_config.max_lora_rank:
            raise ValueError(
                f"LoRA rank {lora.rank} is greater than max_lora_rank "
                f"{self.lora_config.max_lora_rank}.")
        if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
            raise ValueError(f"LoRA added vocab size {lora.extra_vocab_size} "
                             f"is greater than lora_extra_vocab_size "
                             f"{self.lora_config.lora_extra_vocab_size}.")
        return lora

    def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
        """Add a dummy lora for `lora_request`.

        Returns False without doing anything if the id is already listed.
        A cached dummy lora is cloned when one is available; otherwise a new
        one is created and, if caching is enabled, stored for reuse."""
        if lora_request.lora_int_id in self.list_adapters():
            return False
        if isinstance(self._cached_dummy_lora, LoRAModel):
            dummy_lora = self._cached_dummy_lora.clone(
                lora_request.lora_int_id)
        else:
            dummy_lora = self._adapter_manager.create_dummy_lora(
                lora_request.lora_int_id, rank, 1, self.embedding_modules)
            if self._cached_dummy_lora is None:
                self._cached_dummy_lora = dummy_lora
        return self._adapter_manager.add_adapter(dummy_lora)

    def pin_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.pin_adapter(adapter_id)

    def set_active_adapters(self, requests: Set[Any],
                            mapping: Optional[Any]) -> None:
        set_active_adapters_worker(requests, mapping, self._apply_adapters,
                                   self._adapter_manager.set_adapter_mapping)

    def _apply_adapters(self, adapter_requests: Set[Any]) -> None:
        apply_adapters_worker(adapter_requests, self.list_adapters,
                              self._adapter_manager.adapter_slots,
                              self.remove_adapter, self.add_adapter)

    def add_adapter(self, adapter_request: Any) -> bool:
        return add_adapter_worker(adapter_request, self.list_adapters,
                                  self._load_adapter,
                                  self._adapter_manager.add_adapter,
                                  self._adapter_manager.activate_adapter)

    def remove_adapter(self, adapter_id: int) -> bool:
        return self._adapter_manager.remove_adapter(adapter_id)

    def remove_all_adapters(self):
        self._adapter_manager.remove_all_adapters()

    def list_adapters(self) -> Set[int]:
        return list_adapters_worker(self._adapter_manager.list_adapters)
|
||||
|
||||
|
||||
class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
    """WorkerLoRAManager that manages LoRA models on the worker side.

    Uses an LRU Cache. Every request, the requested LoRAs will be loaded
    (unless they are already loaded) and least recently used LoRAs will
    be unloaded if the cache is above capacity."""

    _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager

    def create_lora_manager(
        self,
        model: torch.nn.Module,
    ) -> Any:
        """Create the LRU adapter manager for `model`, keep it on this
        worker manager, and return the manager's `model` attribute."""
        manager = create_lora_manager(
            model,
            max_num_seqs=self.max_num_seqs,
            max_num_batched_tokens=self.max_num_batched_tokens,
            vocab_size=self.vocab_size,
            lora_config=self.lora_config,
            device=self.device,
            lora_manager_cls=self._manager_cls,
        )
        self._adapter_manager = manager
        return manager.model

    def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None:
        """Add every truthy request, keyed by `lora_int_id`.

        Raises RuntimeError if more distinct ids are requested than the
        adapter manager has lora slots."""
        requested: Dict[int, LoRARequest] = {}
        for request in lora_requests:
            if request:
                requested[request.lora_int_id] = request

        slots = self._adapter_manager.lora_slots
        if len(requested) > slots:
            raise RuntimeError(
                f"Number of requested LoRAs ({len(requested)}) is greater "
                "than the number of GPU LoRA slots "
                f"({self._adapter_manager.lora_slots}).")

        for request in requested.values():
            self.add_adapter(request)

    def add_adapter(self, lora_request: LoRARequest) -> bool:
        """Load (or touch) the requested lora and activate it.

        Returns the result of adding a newly loaded lora, or whether an
        already listed lora could be fetched from the adapter manager."""
        lora_id = lora_request.lora_int_id
        if lora_id in self.list_adapters():
            # If the lora is already loaded, just touch it to
            # update its position in the caches
            loaded = self._adapter_manager.get_adapter(lora_id) is not None
        else:
            # Remove before we load the new lora to save memory
            if len(self._adapter_manager) + 1 > self._adapter_manager.capacity:
                assert isinstance(self._adapter_manager,
                                  LRUCacheLoRAModelManager)
                self._adapter_manager.remove_oldest_adapter()
            new_lora = self._load_adapter(lora_request)
            loaded = self._adapter_manager.add_adapter(new_lora)
        self._adapter_manager.activate_adapter(lora_id)
        return loaded
|
||||
Reference in New Issue
Block a user