update
This commit is contained in:
@@ -1,41 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from vllm.lora.layers.base import BaseLayerWithLoRA
|
||||
from vllm.lora.layers.column_parallel_linear import (
|
||||
ColumnParallelLinearWithLoRA,
|
||||
ColumnParallelLinearWithShardedLoRA,
|
||||
MergedColumnParallelLinearWithLoRA,
|
||||
MergedColumnParallelLinearWithShardedLoRA,
|
||||
MergedQKVParallelLinearWithLoRA,
|
||||
MergedQKVParallelLinearWithShardedLoRA,
|
||||
QKVParallelLinearWithLoRA,
|
||||
QKVParallelLinearWithShardedLoRA,
|
||||
)
|
||||
from vllm.lora.layers.fused_moe import FusedMoEWithLoRA
|
||||
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
|
||||
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
|
||||
from vllm.lora.layers.row_parallel_linear import (
|
||||
RowParallelLinearWithLoRA,
|
||||
RowParallelLinearWithShardedLoRA,
|
||||
)
|
||||
from vllm.lora.layers.utils import LoRAMapping
|
||||
from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA
|
||||
|
||||
__all__ = [
|
||||
"BaseLayerWithLoRA",
|
||||
"VocabParallelEmbeddingWithLoRA",
|
||||
"LogitsProcessorWithLoRA",
|
||||
"ColumnParallelLinearWithLoRA",
|
||||
"ColumnParallelLinearWithShardedLoRA",
|
||||
"MergedColumnParallelLinearWithLoRA",
|
||||
"MergedColumnParallelLinearWithShardedLoRA",
|
||||
"MergedQKVParallelLinearWithLoRA",
|
||||
"MergedQKVParallelLinearWithShardedLoRA",
|
||||
"QKVParallelLinearWithLoRA",
|
||||
"QKVParallelLinearWithShardedLoRA",
|
||||
"RowParallelLinearWithLoRA",
|
||||
"RowParallelLinearWithShardedLoRA",
|
||||
"ReplicatedLinearWithLoRA",
|
||||
"LoRAMapping",
|
||||
"FusedMoEWithLoRA",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,67 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.lora.punica_wrapper import PunicaWrapperBase
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
|
||||
def slice_lora_a(
|
||||
self, lora_a: torch.Tensor | list[torch.Tensor | None]
|
||||
) -> torch.Tensor | list[torch.Tensor | None]:
|
||||
"""Slice lora a if splitting for tensor parallelism."""
|
||||
...
|
||||
|
||||
def slice_lora_b(
|
||||
self, lora_b: torch.Tensor | list[torch.Tensor | None]
|
||||
) -> torch.Tensor | list[torch.Tensor | None]:
|
||||
"""Slice lora b if splitting with tensor parallelism."""
|
||||
...
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> None:
|
||||
"""Initializes lora matrices."""
|
||||
...
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
"""Resets the lora weights at index back to 0."""
|
||||
...
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: torch.Tensor | None,
|
||||
):
|
||||
"""Overwrites lora tensors at index."""
|
||||
...
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
punica_wrapper,
|
||||
):
|
||||
self.punica_wrapper: PunicaWrapperBase = punica_wrapper
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
"""Returns True if the layer can be replaced by this LoRA layer."""
|
||||
raise NotImplementedError
|
||||
@@ -1,164 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.distributed.utils import divide
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
LinearBase,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .base import BaseLayerWithLoRA
|
||||
from .utils import _get_lora_device
|
||||
|
||||
|
||||
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(self, base_layer: LinearBase):
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.input_size = self.base_layer.input_size
|
||||
# Ensure tp_size and tp_rank consistency with the base_layer.
|
||||
self.tp_size = self.base_layer.tp_size
|
||||
self.tp_rank = self.base_layer.tp_rank
|
||||
self.device = _get_lora_device(self.base_layer)
|
||||
self.output_slices: tuple[int, ...]
|
||||
self.output_size: int
|
||||
self.n_slices: int
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> None:
|
||||
self.lora_config = lora_config
|
||||
#
|
||||
if isinstance(self.base_layer, ReplicatedLinear):
|
||||
lora_a_out_size = lora_config.max_lora_rank
|
||||
lora_b_out_size = self.output_size
|
||||
|
||||
elif isinstance(self.base_layer, ColumnParallelLinear):
|
||||
lora_a_out_size = (
|
||||
lora_config.max_lora_rank
|
||||
if not lora_config.fully_sharded_loras
|
||||
else divide(lora_config.max_lora_rank, self.tp_size)
|
||||
)
|
||||
lora_b_out_size = self.output_size
|
||||
|
||||
elif isinstance(self.base_layer, RowParallelLinear):
|
||||
lora_a_out_size = lora_config.max_lora_rank
|
||||
lora_b_out_size = (
|
||||
self.output_size
|
||||
if not lora_config.fully_sharded_loras
|
||||
else divide(self.output_size, self.tp_size)
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.lora_a_stacked = tuple(
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_a_out_size,
|
||||
self.input_size,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.n_slices)
|
||||
)
|
||||
self.lora_b_stacked = tuple(
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_b_out_size,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.n_slices)
|
||||
)
|
||||
self.output_slices = (self.lora_b_stacked[0].shape[2],)
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
for s_index in range(self.n_slices):
|
||||
self.lora_a_stacked[s_index][index] = 0
|
||||
self.lora_b_stacked[s_index][index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: torch.Tensor | None,
|
||||
):
|
||||
# Except for QKVParallelLinearWithLoRA and
|
||||
# MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
|
||||
# store weights in a tuple of size 1. These two layers will
|
||||
# override this function.
|
||||
assert (
|
||||
len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
|
||||
)
|
||||
|
||||
self.reset_lora(index)
|
||||
if self.tp_size > 1:
|
||||
lora_a = self.slice_lora_a(lora_a)
|
||||
lora_b = self.slice_lora_b(lora_b)
|
||||
|
||||
self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
|
||||
lora_a, non_blocking=True
|
||||
)
|
||||
self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
|
||||
lora_b, non_blocking=True
|
||||
)
|
||||
|
||||
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
|
||||
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
||||
|
||||
# In Transformers modeling backend, x and output have extra batch dimension like
|
||||
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
|
||||
# therefore we need to flatten the batch dimensions.
|
||||
if x.ndim == 3 and output.ndim == 3:
|
||||
output = output.flatten(0, 1)
|
||||
x = x.flatten(0, 1)
|
||||
|
||||
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear(
|
||||
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
|
||||
)
|
||||
if not current_platform.can_update_inplace():
|
||||
output = lora_output
|
||||
|
||||
return output
|
||||
|
||||
@property
|
||||
def weight(self) -> torch.Tensor:
|
||||
# unquantizedLinear
|
||||
if hasattr(self.base_layer, "weight"):
|
||||
return self.base_layer.weight
|
||||
# Compressed Tensor
|
||||
elif hasattr(self.base_layer, "weight_packed"):
|
||||
return self.base_layer.weight_packed
|
||||
# GPTQ/AWQ
|
||||
elif hasattr(self.base_layer, "qweight"):
|
||||
return self.base_layer.qweight
|
||||
# marlin
|
||||
elif hasattr(self.base_layer, "B"):
|
||||
return self.base_layer.B
|
||||
# HQQ marlin
|
||||
elif hasattr(self.base_layer, "W_q"):
|
||||
return self.base_layer.W_q
|
||||
else:
|
||||
raise ValueError(f"Unsupported base layer: {self.base_layer}")
|
||||
|
||||
@property
|
||||
def bias(self) -> torch.Tensor | None:
|
||||
if hasattr(self.base_layer, "bias"):
|
||||
return self.base_layer.bias
|
||||
else:
|
||||
return None
|
||||
@@ -1,578 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.distributed import tensor_model_parallel_all_gather
|
||||
from vllm.distributed.utils import divide
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .base_linear import BaseLinearLayerWithLoRA
|
||||
from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace
|
||||
|
||||
|
||||
def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
|
||||
"""
|
||||
For `ColumnParallelLinearWithLoRA` or classes that inherit from
|
||||
`ColumnParallelLinearWithLoRA`, they share the same `apply` logic.
|
||||
"""
|
||||
assert (
|
||||
layer.n_slices
|
||||
== len(layer.lora_a_stacked)
|
||||
== len(layer.lora_b_stacked)
|
||||
== len(layer.output_slices)
|
||||
)
|
||||
|
||||
output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)
|
||||
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
|
||||
|
||||
# Since communication is needed, the buffer is directly initialized as a
|
||||
# tensor rather than a tuple of tensor.
|
||||
buffers = torch.zeros(
|
||||
(layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
shrunk_buffers: torch.Tensor | None = layer.punica_wrapper.add_shrink(
|
||||
buffers, x, layer.lora_a_stacked, 1.0
|
||||
)
|
||||
|
||||
if not current_platform.can_update_inplace():
|
||||
buffers = shrunk_buffers
|
||||
|
||||
buffers = tensor_model_parallel_all_gather(buffers)
|
||||
|
||||
lora_output: torch.Tensor | None = layer.punica_wrapper.add_expand(
|
||||
output,
|
||||
buffers,
|
||||
layer.lora_b_stacked,
|
||||
layer.output_slices,
|
||||
offset_start=0,
|
||||
add_input=True,
|
||||
)
|
||||
|
||||
if not current_platform.can_update_inplace():
|
||||
output = lora_output
|
||||
|
||||
output = output.view(*out_orig_shape)
|
||||
# now have column partitioned and packed output
|
||||
return output
|
||||
|
||||
|
||||
class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
"""
|
||||
LoRA on top of ColumnParallelLinear layer.
|
||||
LoRA B is sliced for tensor parallelism.
|
||||
There are two types for the `base_layer`:
|
||||
1. ColumnParallelLinear, e.g.`dense_h_to_4h` in `FalconForCausalLM`.
|
||||
2. MergedColumnParallelLinear, e.g.`gate_up_proj` in `Phi3ForCausalLM`.
|
||||
"""
|
||||
|
||||
def __init__(self, base_layer: ColumnParallelLinear) -> None:
|
||||
super().__init__(base_layer)
|
||||
# The base_layer type is ColumnParallelLinear or
|
||||
# MergedColumnParallelLinear, their weight sharding logic is
|
||||
# inconsistent when TP is greater than 1.
|
||||
self.is_merged_col_linear = type(base_layer) is MergedColumnParallelLinear
|
||||
self.output_size = self.base_layer.output_size_per_partition
|
||||
# There is only one LoRA layer
|
||||
self.n_slices = 1
|
||||
|
||||
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
|
||||
return lora_a
|
||||
|
||||
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
|
||||
# Applicable to cases where the base_layer is
|
||||
# MergedColumnParallelLinear.
|
||||
if self.is_merged_col_linear:
|
||||
shard_size = self.output_size // 2
|
||||
offset = lora_b.shape[0] // 2
|
||||
|
||||
left_weight = lora_b[
|
||||
self.tp_rank * shard_size : (self.tp_rank + 1) * shard_size, :
|
||||
]
|
||||
right_weight = lora_b[
|
||||
offset + self.tp_rank * shard_size : offset
|
||||
+ (self.tp_rank + 1) * shard_size,
|
||||
:,
|
||||
]
|
||||
lora_b = torch.cat([left_weight, right_weight], dim=0)
|
||||
# Applicable to cases where the base_layer is
|
||||
# ColumnParallelLinear.
|
||||
else:
|
||||
shard_size = self.output_size
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
lora_b = lora_b[start_idx:end_idx, :]
|
||||
return lora_b
|
||||
|
||||
def forward(
|
||||
self, input_: torch.Tensor
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""Forward of ColumnParallelLinear
|
||||
|
||||
Args:
|
||||
input_: Tensor whose last dimension is `input_size`.
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply(input_, bias)
|
||||
if self.base_layer.gather_output and self.tp_size > 1:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
|
||||
if not self.base_layer.return_bias:
|
||||
return output
|
||||
|
||||
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
return type(source_layer) is ColumnParallelLinear or (
|
||||
type(source_layer) is MergedColumnParallelLinear
|
||||
and len(packed_modules_list) == 1
|
||||
)
|
||||
|
||||
|
||||
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
|
||||
packed together (e.g. gate_proj + up_proj -> gate_up_proj).
|
||||
|
||||
This means we have 2 LoRAs, each applied to one half of the layer.
|
||||
|
||||
Both slices must have the same size.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, base_layer: MergedColumnParallelLinear | QKVParallelLinear
|
||||
) -> None:
|
||||
super().__init__(base_layer)
|
||||
# There are two LoRA layers
|
||||
# the output_sizes in MergedColumnParallelLinear is not sharded by tp
|
||||
# we need to divide it by the tp_size to get correct slices size
|
||||
output_sizes = self.base_layer.output_sizes
|
||||
self.output_slices = tuple(
|
||||
divide(output_size, self.tp_size) for output_size in output_sizes
|
||||
)
|
||||
self.n_slices = len(self.output_slices)
|
||||
self.output_ids = (self.tp_rank,) * self.n_slices
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
The main reason for overriding this function is to enhance code
|
||||
maintainability.
|
||||
"""
|
||||
self.lora_config = lora_config
|
||||
|
||||
lora_a_output_size_per_partition = (
|
||||
lora_config.max_lora_rank
|
||||
if not lora_config.fully_sharded_loras
|
||||
else divide(lora_config.max_lora_rank, self.tp_size)
|
||||
)
|
||||
|
||||
self.lora_a_stacked = tuple(
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_a_output_size_per_partition,
|
||||
self.input_size,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.n_slices)
|
||||
)
|
||||
self.lora_b_stacked = tuple(
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
output_size,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for output_size in self.output_slices
|
||||
)
|
||||
|
||||
def slice_lora_a(
|
||||
self, lora_a: list[torch.Tensor | None]
|
||||
) -> list[torch.Tensor | None]:
|
||||
return lora_a
|
||||
|
||||
def slice_lora_b(
|
||||
self, lora_b: list[torch.Tensor | None]
|
||||
) -> list[torch.Tensor | None]:
|
||||
sliced_lora_b = [None] * self.n_slices
|
||||
for i, (shard_id, shard_size) in enumerate(
|
||||
zip(self.output_ids, self.output_slices)
|
||||
):
|
||||
if (lora_b_i := lora_b[i]) is not None:
|
||||
sliced_lora_b[i] = lora_b_i[
|
||||
shard_size * shard_id : shard_size * (shard_id + 1), :
|
||||
]
|
||||
return sliced_lora_b
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: torch.Tensor | None,
|
||||
):
|
||||
self.reset_lora(index)
|
||||
|
||||
if self.tp_size > 1:
|
||||
lora_a = self.slice_lora_a(lora_a)
|
||||
lora_b = self.slice_lora_b(lora_b)
|
||||
|
||||
for i in range(self.n_slices):
|
||||
if (lora_a_i := lora_a[i]) is not None:
|
||||
self.lora_a_stacked[i][
|
||||
index, 0, : lora_a_i.shape[0], : lora_a_i.shape[1]
|
||||
].copy_(lora_a_i, non_blocking=True)
|
||||
if (lora_b_i := lora_b[i]) is not None:
|
||||
self.lora_b_stacked[i][
|
||||
index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1]
|
||||
].copy_(lora_b_i, non_blocking=True)
|
||||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
return (
|
||||
type(source_layer) is MergedColumnParallelLinear
|
||||
and len(packed_modules_list) == 2
|
||||
)
|
||||
|
||||
|
||||
class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
"""
|
||||
ColumnParallelLinear layer that is specifically designed for
|
||||
qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
|
||||
only contains a single LoRA within their qkv_proj layer.
|
||||
|
||||
During inference with Tensor Parallel, the weights of lora_b
|
||||
must be accurately partitioned according to the respective ranks.
|
||||
|
||||
Q slice may have different shape than K and V slices (which both have
|
||||
the same shape).
|
||||
"""
|
||||
|
||||
def __init__(self, base_layer: QKVParallelLinear) -> None:
|
||||
super().__init__(base_layer)
|
||||
self.q_proj_total_size = (
|
||||
self.base_layer.total_num_heads * self.base_layer.head_size
|
||||
)
|
||||
self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
|
||||
self.kv_proj_shard_size = (
|
||||
self.base_layer.num_kv_heads * self.base_layer.head_size
|
||||
)
|
||||
self.kv_proj_total_size = (
|
||||
self.base_layer.total_num_kv_heads * self.base_layer.head_size
|
||||
)
|
||||
# There is only one LoRA layer
|
||||
self.n_slices = 1
|
||||
|
||||
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
|
||||
self.q_shard_id = self.tp_rank
|
||||
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
|
||||
lora_b_q = lora_b[
|
||||
self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size
|
||||
* (self.q_shard_id + 1),
|
||||
:,
|
||||
]
|
||||
k_offset = self.q_proj_total_size
|
||||
lora_b_k = lora_b[
|
||||
k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset
|
||||
+ self.kv_proj_shard_size * (self.kv_shard_id + 1),
|
||||
:,
|
||||
]
|
||||
v_offset = k_offset + self.kv_proj_total_size
|
||||
lora_b_v = lora_b[
|
||||
v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset
|
||||
+ self.kv_proj_shard_size * (self.kv_shard_id + 1),
|
||||
:,
|
||||
]
|
||||
lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
|
||||
return lora_b
|
||||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1
|
||||
|
||||
|
||||
class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
|
||||
"""MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
|
||||
packed together in qkv proj fashion
|
||||
(q_proj + k_proj + v_proj -> qkv_proj).
|
||||
|
||||
This means we have 3 LoRAs, each applied to one slice of the layer.
|
||||
|
||||
Q slice may have different shape than K and V slices (which both have
|
||||
the same shape).
|
||||
"""
|
||||
|
||||
def __init__(self, base_layer: QKVParallelLinear) -> None:
|
||||
super().__init__(base_layer)
|
||||
# There are three LoRA layer.
|
||||
self.n_slices = len(self.base_layer.output_sizes)
|
||||
|
||||
self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
|
||||
self.kv_proj_shard_size = (
|
||||
self.base_layer.num_kv_heads * self.base_layer.head_size
|
||||
)
|
||||
self.q_shard_id = self.tp_rank
|
||||
self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
|
||||
|
||||
self.output_slices = (
|
||||
self.q_proj_shard_size,
|
||||
self.kv_proj_shard_size,
|
||||
self.kv_proj_shard_size,
|
||||
)
|
||||
self.output_ids = (
|
||||
self.q_shard_id,
|
||||
self.kv_shard_id,
|
||||
self.kv_shard_id,
|
||||
)
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
The main reason for overloading this function is to handle inconsistent
|
||||
weight dimensions in qkv lora.
|
||||
"""
|
||||
super().create_lora_weights(max_loras, lora_config, model_config)
|
||||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3
|
||||
|
||||
|
||||
# These following layers are based on the tensor parallelism strategy given in
|
||||
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
|
||||
# https://arxiv.org/abs/2311.03285.
|
||||
|
||||
|
||||
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
|
||||
"""
|
||||
Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.
|
||||
|
||||
Based on S-LoRA, slicing happens along the rank dim.
|
||||
"""
|
||||
|
||||
# For all LoRA layers where the `base_layer` is `ColumnParallelLinear`,
|
||||
# their `lora_a` and `lora_b` have different sharding patterns. After
|
||||
# completing the `lora_a` GEMM , a gather operation is performed.
|
||||
# Therefore, the sharding of `lora_a` only needs to correspond with the
|
||||
# gather operation.
|
||||
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
|
||||
shard_size = self.lora_a_stacked[0].shape[2]
|
||||
start_idx = self.tp_rank * shard_size
|
||||
lora_a = lora_a[start_idx : start_idx + shard_size, :]
|
||||
return lora_a
|
||||
|
||||
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
|
||||
return _mcp_apply(x, bias, self)
|
||||
|
||||
@classmethod
|
||||
@_fully_sharded_can_replace
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
# specifying kwargs so they can be easily accessed in decorator
|
||||
return super().can_replace_layer(
|
||||
source_layer=source_layer,
|
||||
lora_config=lora_config,
|
||||
packed_modules_list=packed_modules_list,
|
||||
model_config=model_config,
|
||||
decorate=False,
|
||||
)
|
||||
|
||||
|
||||
class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLoRA):
|
||||
"""
|
||||
Differs from MergedColumnParallelLinearWithLoRA by slicing the
|
||||
LoRA A's also.
|
||||
|
||||
Based on S-LoRA, slicing happens along the rank dim.
|
||||
"""
|
||||
|
||||
def slice_lora_a(
|
||||
self, lora_a: list[torch.Tensor | None]
|
||||
) -> list[torch.Tensor | None]:
|
||||
# NOTE: lora_a contains 2 subloras, and each sublora could be None.
|
||||
output_shard_size = self.lora_a_stacked[0].shape[2]
|
||||
output_start_idx = self.tp_rank * output_shard_size
|
||||
lora_a = [
|
||||
lora_a[0][output_start_idx : output_start_idx + output_shard_size, :]
|
||||
if lora_a[0] is not None
|
||||
else None,
|
||||
lora_a[1][output_start_idx : output_start_idx + output_shard_size, :]
|
||||
if lora_a[1] is not None
|
||||
else None,
|
||||
]
|
||||
return lora_a
|
||||
|
||||
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
|
||||
return _mcp_apply(x, bias, self)
|
||||
|
||||
@classmethod
|
||||
@_fully_sharded_can_replace
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
# specifying kwargs so they can be easily accessed in decorator
|
||||
return super().can_replace_layer(
|
||||
source_layer=source_layer,
|
||||
lora_config=lora_config,
|
||||
packed_modules_list=packed_modules_list,
|
||||
model_config=model_config,
|
||||
decorate=False,
|
||||
)
|
||||
|
||||
|
||||
class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
|
||||
"""
|
||||
Differs from QKVParallelLinearWithLoRA by slicing the
|
||||
LoRA A's also.
|
||||
|
||||
Based on S-LoRA, slicing happens along the rank dim.
|
||||
"""
|
||||
|
||||
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
|
||||
shard_size = self.lora_a_stacked[0].shape[2]
|
||||
start_idx = self.tp_rank * shard_size
|
||||
lora_a = lora_a[start_idx : start_idx + shard_size, :]
|
||||
return lora_a
|
||||
|
||||
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
|
||||
return _mcp_apply(x, bias, self)
|
||||
|
||||
@classmethod
|
||||
@_fully_sharded_can_replace
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
# specifying kwargs so they can be easily accessed in decorator
|
||||
return super().can_replace_layer(
|
||||
source_layer=source_layer,
|
||||
lora_config=lora_config,
|
||||
packed_modules_list=packed_modules_list,
|
||||
model_config=model_config,
|
||||
decorate=False,
|
||||
)
|
||||
|
||||
|
||||
class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
|
||||
"""
|
||||
Differs from MergedQKVParallelLinearWithLoRA by slicing the
|
||||
LoRA A's also.
|
||||
|
||||
Based on S-LoRA, slicing happens along the rank dim.
|
||||
"""
|
||||
|
||||
def slice_lora_a(
|
||||
self, lora_a: list[torch.Tensor | None]
|
||||
) -> list[torch.Tensor | None]:
|
||||
# NOTE: lora_a contains 3 subloras, and each sublora could be None.
|
||||
shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
|
||||
start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
|
||||
lora_a = [
|
||||
lora_a[0][start_idx[0] : start_idx[0] + shard_size[0], :]
|
||||
if lora_a[0] is not None
|
||||
else None,
|
||||
lora_a[1][start_idx[1] : start_idx[1] + shard_size[1], :]
|
||||
if lora_a[1] is not None
|
||||
else None,
|
||||
lora_a[2][start_idx[2] : start_idx[2] + shard_size[2], :]
|
||||
if lora_a[2] is not None
|
||||
else None,
|
||||
]
|
||||
return lora_a
|
||||
|
||||
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
|
||||
return _mcp_apply(x, bias, self)
|
||||
|
||||
@classmethod
|
||||
@_fully_sharded_can_replace
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
# specifying kwargs so they can be easily accessed in decorator
|
||||
return super().can_replace_layer(
|
||||
source_layer=source_layer,
|
||||
lora_config=lora_config,
|
||||
packed_modules_list=packed_modules_list,
|
||||
model_config=model_config,
|
||||
decorate=False,
|
||||
)
|
||||
@@ -1,472 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.lora.layers.base import BaseLayerWithLoRA
|
||||
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
_get_config_dtype_str,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
modular_marlin_fused_moe,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
modular_triton_fused_moe,
|
||||
try_get_optimal_moe_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
|
||||
FusedMoEModularMethod,
|
||||
)
|
||||
|
||||
|
||||
class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(self, base_layer: FusedMoE) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
|
||||
assert not self.base_layer.use_ep, (
|
||||
"EP support for Fused MoE LoRA is not implemented yet."
|
||||
)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.device = base_layer.w2_weight.device
|
||||
self._inject_lora_into_fused_moe()
|
||||
|
||||
def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
|
||||
normalized_config = {}
|
||||
for key, value in config.items():
|
||||
if key.islower():
|
||||
if key.startswith("block_"):
|
||||
normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
|
||||
else:
|
||||
normalized_key = key.upper()
|
||||
else:
|
||||
normalized_key = key
|
||||
normalized_config[normalized_key] = value
|
||||
return normalized_config
|
||||
|
||||
def _get_lora_moe_configs(
|
||||
self,
|
||||
op_prefix: str,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
num_slices: int,
|
||||
M: int,
|
||||
layer: FusedMoE,
|
||||
top_k: int,
|
||||
config_dtype: str,
|
||||
):
|
||||
if envs.VLLM_TUNED_CONFIG_FOLDER:
|
||||
shrink_config = get_lora_op_configs(
|
||||
op_type=f"fused_moe_lora_{op_prefix}_shrink",
|
||||
max_loras=lora_a_stacked.shape[0],
|
||||
batch=M,
|
||||
hidden_size=lora_a_stacked.shape[-1],
|
||||
rank=lora_a_stacked.shape[-2],
|
||||
num_slices=num_slices,
|
||||
moe_intermediate_size=lora_b_stacked.shape[-2],
|
||||
)
|
||||
expand_config = get_lora_op_configs(
|
||||
op_type=f"fused_moe_lora_{op_prefix}_expand",
|
||||
max_loras=lora_a_stacked.shape[0],
|
||||
batch=M,
|
||||
hidden_size=lora_a_stacked.shape[-1],
|
||||
rank=lora_a_stacked.shape[-2],
|
||||
num_slices=num_slices,
|
||||
moe_intermediate_size=lora_b_stacked.shape[-2],
|
||||
)
|
||||
else: # fall back to the default config
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
layer.w13_weight.size(),
|
||||
layer.w2_weight.size(),
|
||||
top_k,
|
||||
config_dtype,
|
||||
block_shape=layer.quant_method.moe_quant_config.block_shape,
|
||||
)
|
||||
shrink_config = get_config_func(M)
|
||||
expand_config = get_config_func(M)
|
||||
shrink_config = self._normalize_keys(shrink_config)
|
||||
expand_config = self._normalize_keys(expand_config)
|
||||
return shrink_config, expand_config
|
||||
|
||||
def _inject_lora_into_fused_moe(self):
|
||||
moe_state_dict = {}
|
||||
top_k = self.base_layer.top_k
|
||||
|
||||
self.base_layer.ensure_moe_quant_config_init()
|
||||
quant_config = self.base_layer.quant_method.moe_quant_config
|
||||
|
||||
m_fused_moe_fn = (
|
||||
modular_triton_fused_moe(
|
||||
quant_config, shared_experts=self.base_layer.shared_experts
|
||||
)
|
||||
if not quant_config.use_mxfp4_w4a16
|
||||
else modular_marlin_fused_moe(
|
||||
quant_config, shared_experts=self.base_layer.shared_experts
|
||||
)
|
||||
)
|
||||
|
||||
def fwd_decorator(layer, func):
|
||||
def wrapper(*args, **kwargs):
|
||||
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
|
||||
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
|
||||
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
|
||||
moe_state_dict["expert_map"] = kwargs["expert_map"]
|
||||
moe_state_dict["apply_router_weight_on_input"] = kwargs[
|
||||
"apply_router_weight_on_input"
|
||||
]
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
def act_decorator(layer, func):
|
||||
def wrapper(*args, **kwargs):
|
||||
_, output, input = args
|
||||
|
||||
hidden_states = moe_state_dict["hidden_states"]
|
||||
topk_weights = moe_state_dict["topk_weights"]
|
||||
curr_topk_ids = moe_state_dict["topk_ids"]
|
||||
|
||||
expert_map = moe_state_dict["expert_map"]
|
||||
|
||||
config_dtype = _get_config_dtype_str(
|
||||
dtype=hidden_states.dtype,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
)
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
num_tokens = hidden_states.size(0)
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
|
||||
shrink_config, expand_config = self._get_lora_moe_configs(
|
||||
op_prefix="w13",
|
||||
lora_a_stacked=self.w1_lora_a_stacked,
|
||||
lora_b_stacked=self.w1_lora_b_stacked,
|
||||
num_slices=2,
|
||||
M=M,
|
||||
layer=layer,
|
||||
top_k=top_k,
|
||||
config_dtype=config_dtype,
|
||||
)
|
||||
|
||||
# get the block size of m from customized config or default config
|
||||
max_loras = self.w1_lora_a_stacked.shape[0]
|
||||
(
|
||||
sorted_token_ids_lora,
|
||||
expert_ids_lora,
|
||||
num_tokens_post_padded_lora,
|
||||
) = self.punica_wrapper.moe_lora_align_block_size(
|
||||
curr_topk_ids,
|
||||
num_tokens,
|
||||
shrink_config["BLOCK_SIZE_M"],
|
||||
self.base_layer.local_num_experts,
|
||||
max_loras,
|
||||
self.adapter_enabled,
|
||||
expert_map,
|
||||
)
|
||||
|
||||
moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
|
||||
moe_state_dict["expert_ids_lora"] = expert_ids_lora
|
||||
moe_state_dict["num_tokens_post_padded_lora"] = (
|
||||
num_tokens_post_padded_lora
|
||||
)
|
||||
|
||||
w13_lora_a_stacked = [self.w1_lora_a_stacked, self.w3_lora_a_stacked]
|
||||
w13_lora_b_stacked = [self.w1_lora_b_stacked, self.w3_lora_b_stacked]
|
||||
max_lora_rank = self.w1_lora_a_stacked.shape[-2]
|
||||
expert_ids_lora = expert_ids_lora.view(max_loras, -1)
|
||||
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
|
||||
|
||||
self.punica_wrapper.add_lora_fused_moe(
|
||||
input.view(-1, top_k, input.shape[-1]),
|
||||
hidden_states,
|
||||
w13_lora_a_stacked,
|
||||
w13_lora_b_stacked,
|
||||
topk_weights,
|
||||
sorted_token_ids_lora,
|
||||
expert_ids_lora,
|
||||
num_tokens_post_padded_lora,
|
||||
max_lora_rank,
|
||||
top_k,
|
||||
shrink_config, ## pass the shrink config
|
||||
expand_config, ## pass the expand config
|
||||
self.adapter_enabled,
|
||||
)
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
moe_state_dict["intermediate_cache2"] = output
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
def moe_sum_decorator(layer, func):
|
||||
def wrapper(*args, **kwargs):
|
||||
hidden_states = moe_state_dict["hidden_states"]
|
||||
topk_weights = moe_state_dict["topk_weights"]
|
||||
|
||||
config_dtype = _get_config_dtype_str(
|
||||
dtype=hidden_states.dtype,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
)
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
num_tokens = hidden_states.size(0)
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
|
||||
shrink_config, expand_config = self._get_lora_moe_configs(
|
||||
op_prefix="w2",
|
||||
lora_a_stacked=self.w2_lora_a_stacked,
|
||||
lora_b_stacked=self.w2_lora_b_stacked,
|
||||
num_slices=1,
|
||||
M=M,
|
||||
layer=layer,
|
||||
top_k=top_k,
|
||||
config_dtype=config_dtype,
|
||||
)
|
||||
|
||||
sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
|
||||
expert_ids_lora = moe_state_dict["expert_ids_lora"]
|
||||
num_tokens_post_padded_lora = moe_state_dict[
|
||||
"num_tokens_post_padded_lora"
|
||||
]
|
||||
max_loras = self.w1_lora_a_stacked.shape[0]
|
||||
expert_ids_lora = expert_ids_lora.view(max_loras, -1)
|
||||
sorted_token_ids_lora = sorted_token_ids_lora.view(max_loras, -1)
|
||||
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
|
||||
intermediate_cache3 = args[0]
|
||||
max_lora_rank = self.w1_lora_a_stacked.shape[-2]
|
||||
self.punica_wrapper.add_lora_fused_moe(
|
||||
intermediate_cache3,
|
||||
intermediate_cache2,
|
||||
[self.w2_lora_a_stacked],
|
||||
[self.w2_lora_b_stacked],
|
||||
topk_weights,
|
||||
sorted_token_ids_lora,
|
||||
expert_ids_lora,
|
||||
num_tokens_post_padded_lora,
|
||||
max_lora_rank,
|
||||
top_k,
|
||||
shrink_config, ## pass the shrink config
|
||||
expand_config, ## pass the expand config
|
||||
self.adapter_enabled,
|
||||
True,
|
||||
)
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
fused_experts = m_fused_moe_fn.fused_experts
|
||||
|
||||
m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
|
||||
fused_experts.activation = act_decorator(
|
||||
self.base_layer, fused_experts.activation
|
||||
)
|
||||
fused_experts.moe_sum = moe_sum_decorator(
|
||||
self.base_layer, fused_experts.moe_sum
|
||||
)
|
||||
|
||||
self.base_layer.quant_method = FusedMoEModularMethod(
|
||||
self.base_layer.quant_method, m_fused_moe_fn
|
||||
)
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> None:
|
||||
"""Initializes lora matrices."""
|
||||
|
||||
self.adapter_enabled = torch.tensor(
|
||||
[0] * (max_loras + 1), dtype=torch.int, device=self.device
|
||||
)
|
||||
|
||||
self.w1_lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.local_num_experts,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.hidden_size,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.w1_lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.local_num_experts,
|
||||
self.base_layer.intermediate_size_per_partition,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.w2_lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.local_num_experts,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.intermediate_size_per_partition,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.w2_lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.local_num_experts,
|
||||
self.base_layer.hidden_size,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.w3_lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.local_num_experts,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.hidden_size,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.w3_lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.local_num_experts,
|
||||
self.base_layer.intermediate_size_per_partition,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
|
||||
# to create a dummy LoRA weights.
|
||||
self.lora_a_stacked = []
|
||||
self.lora_b_stacked = []
|
||||
for lora_id in range(max_loras):
|
||||
for experts_id in range(self.base_layer.local_num_experts):
|
||||
# gate_proj,down_proj,up_proj
|
||||
self.lora_a_stacked.append(self.w1_lora_a_stacked[lora_id][experts_id])
|
||||
self.lora_a_stacked.append(self.w2_lora_a_stacked[lora_id][experts_id])
|
||||
self.lora_a_stacked.append(self.w3_lora_a_stacked[lora_id][experts_id])
|
||||
|
||||
self.lora_b_stacked.append(self.w1_lora_b_stacked[lora_id][experts_id])
|
||||
self.lora_b_stacked.append(self.w2_lora_b_stacked[lora_id][experts_id])
|
||||
self.lora_b_stacked.append(self.w3_lora_b_stacked[lora_id][experts_id])
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
"""Resets the lora weights at index back to 0."""
|
||||
self.w1_lora_a_stacked[index] = 0
|
||||
self.w1_lora_b_stacked[index] = 0
|
||||
self.w3_lora_a_stacked[index] = 0
|
||||
self.w3_lora_b_stacked[index] = 0
|
||||
self.w2_lora_a_stacked[index] = 0
|
||||
self.w2_lora_b_stacked[index] = 0
|
||||
self.adapter_enabled[index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: torch.Tensor | None,
|
||||
bias: torch.Tensor | None = None,
|
||||
):
|
||||
"""Overwrites lora tensors at index."""
|
||||
self.reset_lora(index)
|
||||
self.adapter_enabled[index] = 1
|
||||
for eid in range(len(lora_a) // 3):
|
||||
w1_lora_a = lora_a[eid * 3]
|
||||
w2_lora_a = lora_a[eid * 3 + 1]
|
||||
w3_lora_a = lora_a[eid * 3 + 2]
|
||||
w1_lora_b = lora_b[eid * 3]
|
||||
w2_lora_b = lora_b[eid * 3 + 1]
|
||||
w3_lora_b = lora_b[eid * 3 + 2]
|
||||
|
||||
# Handle the case of adding LoRA to only a subset of experts
|
||||
if w1_lora_a is None or w2_lora_a is None or w3_lora_a is None:
|
||||
continue
|
||||
|
||||
if self.tp_size > 1:
|
||||
shard_size = self.base_layer.intermediate_size_per_partition
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
|
||||
w1_lora_b = w1_lora_b[start_idx:end_idx, :]
|
||||
w3_lora_b = w3_lora_b[start_idx:end_idx, :]
|
||||
w2_lora_a = w2_lora_a[:, start_idx:end_idx]
|
||||
|
||||
self.w1_lora_a_stacked[
|
||||
index, eid, : w1_lora_a.shape[0], : w1_lora_a.shape[1]
|
||||
].copy_(w1_lora_a, non_blocking=True)
|
||||
|
||||
self.w3_lora_a_stacked[
|
||||
index, eid, : w3_lora_a.shape[0], : w3_lora_a.shape[1]
|
||||
].copy_(w3_lora_a, non_blocking=True)
|
||||
|
||||
self.w2_lora_b_stacked[
|
||||
index, eid, : w2_lora_b.shape[0], : w2_lora_b.shape[1]
|
||||
].copy_(w2_lora_b, non_blocking=True)
|
||||
|
||||
self.w1_lora_b_stacked[
|
||||
index, eid, : w1_lora_b.shape[0], : w1_lora_b.shape[1]
|
||||
].copy_(w1_lora_b, non_blocking=True)
|
||||
self.w3_lora_b_stacked[
|
||||
index, eid, : w3_lora_b.shape[0], : w3_lora_b.shape[1]
|
||||
].copy_(w3_lora_b, non_blocking=True)
|
||||
self.w2_lora_a_stacked[
|
||||
index, eid, : w2_lora_a.shape[0], : w2_lora_a.shape[1]
|
||||
].copy_(w2_lora_a, non_blocking=True)
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
"""Returns True if the layer can be replaced by this LoRA layer."""
|
||||
# return type(source_layer) is FusedMoE
|
||||
return isinstance(source_layer, FusedMoE)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.base_layer.forward(*args, **kwargs)
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
|
||||
return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def _shared_experts(self):
|
||||
return self.base_layer._shared_experts
|
||||
|
||||
@property
|
||||
def quant_method(self):
|
||||
return self.base_layer.quant_method
|
||||
|
||||
@property
|
||||
def is_internal_router(self) -> bool:
|
||||
return self.base_layer.is_internal_router
|
||||
@@ -1,252 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .base import BaseLayerWithLoRA
|
||||
|
||||
|
||||
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
||||
"""
|
||||
LoRA wrapper for LogitsProcessor, with extra logic to handle the
|
||||
application of the LoRA adapter and added LoRA vocabulary.
|
||||
|
||||
Args:
|
||||
base_layer: LogitsProcessor layer
|
||||
hidden_size: hidden size of the model
|
||||
dtype: data type of the model
|
||||
device: device of the model
|
||||
sharded_to_full_mapping: index mapping from sharded vocab to full vocab
|
||||
received from base_layer.get_sharded_to_full_mapping(). If None,
|
||||
no reindexing will be done.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: LogitsProcessor,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
sharded_to_full_mapping: list[int] | None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.sharded_to_full_mapping = sharded_to_full_mapping
|
||||
|
||||
@property
|
||||
def logits_as_input(self):
|
||||
return self.base_layer.logits_as_input
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.base_layer.vocab_size
|
||||
|
||||
@property
|
||||
def scale(self):
|
||||
return self.base_layer.scale
|
||||
|
||||
@property
|
||||
def soft_cap(self):
|
||||
return self.base_layer.soft_cap
|
||||
|
||||
@property
|
||||
def use_all_gather(self):
|
||||
return self.base_layer.use_all_gather
|
||||
|
||||
@property
|
||||
def org_vocab_size(self):
|
||||
return self.base_layer.org_vocab_size
|
||||
|
||||
@property
|
||||
def include_gpu_probs_tensor(self):
|
||||
return self.base_layer.include_gpu_probs_tensor
|
||||
|
||||
@property
|
||||
def should_modify_greedy_probs_inplace(self):
|
||||
return self.base_layer.should_modify_greedy_probs_inplace
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> None:
|
||||
# TODO: Verify if this condition can be further relaxed
|
||||
if 32000 < self.base_layer.vocab_size > 257024:
|
||||
raise ValueError(
|
||||
"When using LoRA, vocab size must be 32000 >= vocab_size <= 257024"
|
||||
)
|
||||
self.lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.hidden_size,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
# Pad for kernel compatibility
|
||||
math.ceil(
|
||||
self.base_layer.vocab_size / lora_config.lora_vocab_padding_size
|
||||
)
|
||||
* lora_config.lora_vocab_padding_size,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.embeddings_tensors = torch.full(
|
||||
(max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
|
||||
fill_value=float("-inf"),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
if self.sharded_to_full_mapping is not None:
|
||||
self.sharded_to_full_mapping_gpu = torch.tensor(
|
||||
self.sharded_to_full_mapping, device=self.device, dtype=torch.long
|
||||
)
|
||||
else:
|
||||
self.sharded_to_full_mapping_gpu = None
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
self.lora_b_stacked[index] = 0
|
||||
self.embeddings_tensors[index] = float("-inf")
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: torch.Tensor | None,
|
||||
):
|
||||
self.reset_lora(index)
|
||||
self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
|
||||
lora_a, non_blocking=True
|
||||
)
|
||||
self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
|
||||
lora_b, non_blocking=True
|
||||
)
|
||||
if embeddings_tensor is not None:
|
||||
self.embeddings_tensors[
|
||||
index,
|
||||
: embeddings_tensor.shape[0],
|
||||
: embeddings_tensor.shape[1],
|
||||
] = embeddings_tensor
|
||||
|
||||
def _get_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
lm_head: VocabParallelEmbedding,
|
||||
embedding_bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | None:
|
||||
# Get the logits for the next tokens.
|
||||
logits = lm_head.quant_method.apply(lm_head, hidden_states)
|
||||
if embedding_bias is not None:
|
||||
logits += embedding_bias
|
||||
|
||||
# Gather logits for TP
|
||||
logits = self.base_layer._gather_logits(logits)
|
||||
|
||||
if logits is None:
|
||||
return None
|
||||
|
||||
if self.sharded_to_full_mapping_gpu is not None:
|
||||
# Reindex full logits tensor to ensure 1:1 mapping between
|
||||
# index and token_id
|
||||
# Example for:
|
||||
# org_vocab_size = 4
|
||||
# added_vocab_size = 2
|
||||
# pad_to_size = 8
|
||||
# tp_size = 2
|
||||
|
||||
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
# token_id: [0, 1, 4, -1, 2, 3, 5, -1]
|
||||
|
||||
# Therefore, the mapping is expected to be:
|
||||
# [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
|
||||
# we get:
|
||||
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
|
||||
# token_id: [0, 1, 2, 3, 4, 5, -1, -1]
|
||||
logits = logits[:, self.sharded_to_full_mapping_gpu]
|
||||
|
||||
lora_logits = torch.empty(
|
||||
self.embeddings_tensors.shape[0] + 1,
|
||||
self.embeddings_tensors.shape[1],
|
||||
hidden_states.shape[0],
|
||||
dtype=self.embeddings_tensors.dtype,
|
||||
device=self.embeddings_tensors.device,
|
||||
)
|
||||
torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1])
|
||||
|
||||
neg_inf, pos_inf = current_platform.get_infinity_values(lora_logits.dtype)
|
||||
|
||||
lora_logits[-1] = neg_inf
|
||||
lora_logits = lora_logits.mT
|
||||
indices_padded = self.punica_wrapper.sampler_indices_padded
|
||||
|
||||
if current_platform.is_tpu() or current_platform.is_xpu():
|
||||
indices_padded = indices_padded[: logits.size(0)]
|
||||
|
||||
lora_logits = (
|
||||
lora_logits.reshape(
|
||||
lora_logits.shape[0] * lora_logits.shape[1],
|
||||
lora_logits.shape[2],
|
||||
)
|
||||
.index_select(0, indices_padded)
|
||||
.nan_to_num_(nan=neg_inf, posinf=pos_inf, neginf=neg_inf)
|
||||
)
|
||||
|
||||
logits[
|
||||
:,
|
||||
self.base_layer.org_vocab_size : self.base_layer.org_vocab_size
|
||||
+ lora_logits.shape[1],
|
||||
] = lora_logits
|
||||
|
||||
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits(
|
||||
logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0
|
||||
)
|
||||
|
||||
if not current_platform.can_update_inplace():
|
||||
logits = lora_output
|
||||
|
||||
# Remove paddings in vocab (if any).
|
||||
logits = logits[:, : self.base_layer.vocab_size]
|
||||
return logits
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return type(self.base_layer).forward(self, *args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
# Special handling for the LogitsProcessor.
|
||||
return False
|
||||
@@ -1,70 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
|
||||
from .base_linear import BaseLinearLayerWithLoRA
|
||||
|
||||
|
||||
class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
def __init__(self, base_layer: ReplicatedLinear) -> None:
|
||||
super().__init__(
|
||||
base_layer,
|
||||
)
|
||||
# To ensure interface compatibility, set to 1 always.
|
||||
self.output_size = self.base_layer.output_size
|
||||
self.n_slices = 1
|
||||
|
||||
def forward(
|
||||
self, input_: torch.Tensor
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""Forward of ReplicatedLinearWithLoRA
|
||||
|
||||
Args:
|
||||
input_: Tensor whose last dimension is `input_size`.
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
output = self.apply(input_, bias)
|
||||
|
||||
output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
|
||||
|
||||
if not self.base_layer.return_bias:
|
||||
return output
|
||||
|
||||
return output, output_bias
|
||||
|
||||
# ReplicatedLinear should always be replaced, regardless of the fully
|
||||
# sharded LoRAs setting, because it is, by definition, copied per GPU.
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
return type(source_layer) is ReplicatedLinear
|
||||
|
||||
def slice_lora_a(
|
||||
self, lora_a: torch.Tensor | list[torch.Tensor | None]
|
||||
) -> torch.Tensor | list[torch.Tensor | None]:
|
||||
"""Slice lora a if splitting for tensor parallelism."""
|
||||
return lora_a
|
||||
|
||||
def slice_lora_b(
|
||||
self, lora_b: torch.Tensor | list[torch.Tensor | None]
|
||||
) -> torch.Tensor | list[torch.Tensor | None]:
|
||||
"""Slice lora b if splitting with tensor parallelism."""
|
||||
return lora_b
|
||||
@@ -1,181 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.distributed import (
|
||||
split_tensor_along_last_dim,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .base_linear import BaseLinearLayerWithLoRA
|
||||
from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace
|
||||
|
||||
|
||||
class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
|
||||
def __init__(self, base_layer: RowParallelLinear) -> None:
|
||||
super().__init__(base_layer)
|
||||
|
||||
# reset input_size
|
||||
self.input_size = self.base_layer.input_size_per_partition
|
||||
self.output_size = self.base_layer.output_size
|
||||
# There is only one LoRA layer.
|
||||
self.n_slices = 1
|
||||
|
||||
def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
|
||||
shard_size = self.input_size
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
lora_a = lora_a[:, start_idx:end_idx]
|
||||
return lora_a
|
||||
|
||||
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
|
||||
return lora_b
|
||||
|
||||
def forward(
|
||||
self, input_: torch.Tensor
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
|
||||
"""Forward of RowParallelLinear
|
||||
|
||||
Args:
|
||||
input_: tensor whose last dimension is `input_size`. If
|
||||
`input_is_parallel` is set, then the last dimension
|
||||
is `input_size // tp_size`.
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
# set up backprop all-reduce.
|
||||
if self.base_layer.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
# TODO: simplify code below
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size
|
||||
)
|
||||
input_parallel = splitted_input[self.tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply(input_parallel)
|
||||
if self.base_layer.reduce_results and self.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output_ = output_parallel
|
||||
|
||||
if not self.base_layer.skip_bias_add:
|
||||
output = (
|
||||
output_ + self.base_layer.bias
|
||||
if self.base_layer.bias is not None
|
||||
else output_
|
||||
)
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.base_layer.bias
|
||||
|
||||
if not self.base_layer.return_bias:
|
||||
return output
|
||||
|
||||
return output, output_bias
|
||||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
return type(source_layer) is RowParallelLinear
|
||||
|
||||
|
||||
# The following layer is based on the tensor parallelism strategy given in
|
||||
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
|
||||
# https://arxiv.org/abs/2311.03285.
|
||||
|
||||
|
||||
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
|
||||
"""
|
||||
Differs from RowParallelLinearWithLoRA by slicing the
|
||||
LoRA B's also.
|
||||
|
||||
Based on S-LoRA, slicing happens along the output dim.
|
||||
This yields a combined partial sum from the row parallel base
|
||||
layer and column partitioned output from the LoRA.
|
||||
"""
|
||||
|
||||
def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
|
||||
shard_size = self.lora_b_stacked[0].shape[2]
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
lora_b = lora_b[start_idx:end_idx, :]
|
||||
return lora_b
|
||||
|
||||
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
|
||||
output = self.base_layer.quant_method.apply(self.base_layer, x)
|
||||
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
|
||||
buffer = torch.zeros(
|
||||
(self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
shrunk_buffer: torch.Tensor | None = self.punica_wrapper.add_shrink(
|
||||
buffer, x, self.lora_a_stacked, 1.0
|
||||
)
|
||||
if not current_platform.can_update_inplace():
|
||||
buffer = shrunk_buffer
|
||||
if self.tp_size > 1:
|
||||
buffer = tensor_model_parallel_all_reduce(buffer)
|
||||
|
||||
# following S-LoRA, allows the fusing of all_gather and all_reduce
|
||||
# by adding the column partitioned lora output to a slice of output
|
||||
# tensor, which is a partial sum due to row parallel. All that
|
||||
# remains is a standard all_reduce. User should be aware though that
|
||||
# the output is not the same as a normal row_parallel, it should be
|
||||
# reduced before being used
|
||||
# NOTE offset are based on the rank.
|
||||
shard_size = self.lora_b_stacked[0].shape[2]
|
||||
offset_start = self.tp_rank * shard_size
|
||||
lora_output: torch.Tensor | None = self.punica_wrapper.add_expand(
|
||||
output,
|
||||
buffer,
|
||||
self.lora_b_stacked,
|
||||
self.output_slices,
|
||||
offset_start=offset_start,
|
||||
add_input=True,
|
||||
)
|
||||
|
||||
if not current_platform.can_update_inplace():
|
||||
output = lora_output
|
||||
|
||||
output = output.view(*out_orig_shape)
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
@_fully_sharded_can_replace
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
# specifying kwargs so they can be easily accessed in decorator
|
||||
return super().can_replace_layer(
|
||||
source_layer=source_layer,
|
||||
lora_config=lora_config,
|
||||
packed_modules_list=packed_modules_list,
|
||||
model_config=model_config,
|
||||
decorate=False,
|
||||
)
|
||||
@@ -1,65 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAMapping:
|
||||
index_mapping: tuple[int, ...]
|
||||
prompt_mapping: tuple[int, ...]
|
||||
is_prefill: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
self.index_mapping = tuple(self.index_mapping)
|
||||
self.prompt_mapping = tuple(self.prompt_mapping)
|
||||
|
||||
|
||||
def _get_lora_device(base_layer: nn.Module) -> torch.device:
|
||||
# code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
|
||||
"""Returns the device for where to place the LoRA tensors."""
|
||||
# unquantizedLinear
|
||||
if hasattr(base_layer, "weight"):
|
||||
return base_layer.weight.device
|
||||
# Compressed Tensor
|
||||
elif hasattr(base_layer, "weight_packed"):
|
||||
return base_layer.weight_packed.device
|
||||
# GPTQ/AWQ
|
||||
elif hasattr(base_layer, "qweight"):
|
||||
return base_layer.qweight.device
|
||||
# HQQ marlin
|
||||
elif hasattr(base_layer, "W_q"):
|
||||
return base_layer.W_q.device
|
||||
else:
|
||||
raise ValueError(f"Unsupported base layer: {base_layer}")
|
||||
|
||||
|
||||
def _not_fully_sharded_can_replace(can_replace):
|
||||
"""
|
||||
decorator which adds the condition of not using fully sharded loras
|
||||
intended to wrap can_replace_layer()
|
||||
"""
|
||||
|
||||
def dec(*args, **kwargs):
|
||||
decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
|
||||
condition = not kwargs["lora_config"].fully_sharded_loras if decorate else True
|
||||
return can_replace(*args, **kwargs) and condition
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
def _fully_sharded_can_replace(can_replace):
|
||||
"""
|
||||
decorator which adds the condition of fully sharded loras
|
||||
intended to wrap can_replace_layer()
|
||||
"""
|
||||
|
||||
def dec(*args, **kwargs):
|
||||
return (
|
||||
can_replace(*args, **kwargs) and kwargs["lora_config"].fully_sharded_loras
|
||||
)
|
||||
|
||||
return dec
|
||||
@@ -1,166 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .base import BaseLayerWithLoRA
|
||||
|
||||
|
||||
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.embeddings_slice: tuple[int, int] | None
|
||||
self.embeddings_weights: torch.Tensor | None
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> None:
|
||||
if self.base_layer.num_added_embeddings_per_partition > 0:
|
||||
# We can start adding lora weights
|
||||
self.embeddings_weights = self.base_layer.weight.data[
|
||||
self.base_layer.num_org_embeddings_per_partition : self.base_layer.num_org_embeddings_per_partition # noqa: E501
|
||||
+ self.base_layer.num_added_embeddings_per_partition
|
||||
]
|
||||
self.embeddings_slice = (
|
||||
self.base_layer.shard_indices.added_vocab_start_index
|
||||
- self.base_layer.org_vocab_size,
|
||||
self.base_layer.shard_indices.added_vocab_end_index
|
||||
- self.base_layer.org_vocab_size,
|
||||
)
|
||||
self.base_layer.weight.data[
|
||||
self.base_layer.num_org_embeddings_per_partition :
|
||||
].fill_(0)
|
||||
else:
|
||||
self.embeddings_slice = None
|
||||
self.embeddings_weights = None
|
||||
|
||||
self.embeddings_tensors = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
self.base_layer.embedding_dim,
|
||||
),
|
||||
dtype=self.base_layer.weight.dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.org_vocab_size + lora_config.lora_extra_vocab_size,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
self.base_layer.embedding_dim,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.lora_a_stacked_2d = self.lora_a_stacked.view(
|
||||
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
|
||||
self.lora_a_stacked.shape[2],
|
||||
)
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
self.lora_b_stacked[index] = 0
|
||||
self.embeddings_tensors[index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: torch.Tensor | None,
|
||||
):
|
||||
self.reset_lora(index)
|
||||
# NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
|
||||
# so we need transpose here
|
||||
self.lora_a_stacked[index, : lora_a.shape[1], : lora_a.shape[0]].copy_(
|
||||
lora_a.T, non_blocking=True
|
||||
)
|
||||
self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
|
||||
lora_b, non_blocking=True
|
||||
)
|
||||
if embeddings_tensor is not None:
|
||||
self.embeddings_tensors[
|
||||
index,
|
||||
: embeddings_tensor.shape[0],
|
||||
: embeddings_tensor.shape[1],
|
||||
].copy_(embeddings_tensor, non_blocking=True)
|
||||
if self.embeddings_slice is not None:
|
||||
# TODO(yard1): Optimize this copy, we don't need to copy
|
||||
# everything, just the modified part
|
||||
embeddings = self.embeddings_tensors.view(
|
||||
self.embeddings_tensors.shape[0] * self.embeddings_tensors.shape[1],
|
||||
self.embeddings_tensors.shape[2],
|
||||
)[self.embeddings_slice[0] : self.embeddings_slice[1]]
|
||||
assert self.embeddings_weights is not None
|
||||
self.embeddings_weights[: embeddings.shape[0]].copy_(embeddings)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0)
|
||||
|
||||
# NB: Don't use torch.narrow here. torch.narrow triggers some
|
||||
# Dynamic Shape specialization in torch.compile
|
||||
num_tokens = x.shape[0]
|
||||
indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]
|
||||
indices_0 = self.punica_wrapper._embeddings_indices[0][:num_tokens]
|
||||
|
||||
full_lora_a_embeddings = F.embedding(
|
||||
x + indices_1,
|
||||
self.lora_a_stacked_2d,
|
||||
)
|
||||
full_output = self.base_layer.forward(x + (indices_0 * added_tokens_mask))
|
||||
|
||||
full_output_org = full_output
|
||||
if full_output.ndim == 3:
|
||||
full_output = full_output.view(
|
||||
full_output.shape[0] * full_output.shape[1], -1
|
||||
)
|
||||
if full_lora_a_embeddings.ndim == 3:
|
||||
full_lora_a_embeddings = full_lora_a_embeddings.view(
|
||||
full_lora_a_embeddings.shape[0] * full_lora_a_embeddings.shape[1],
|
||||
-1,
|
||||
)
|
||||
|
||||
lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_embedding(
|
||||
full_output, full_lora_a_embeddings, self.lora_b_stacked, add_input=True
|
||||
)
|
||||
|
||||
if not current_platform.can_update_inplace():
|
||||
full_output = lora_output
|
||||
|
||||
return full_output.view_as(full_output_org)
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None,
|
||||
) -> bool:
|
||||
return type(source_layer) is VocabParallelEmbedding
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.base_layer.weight
|
||||
Reference in New Issue
Block a user