Sync from v0.13
vllm/lora/layers/__init__.py (new file, 42 lines)
@@ -0,0 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.layers.base import BaseLayerWithLoRA
from vllm.lora.layers.column_parallel_linear import (
    ColumnParallelLinearWithLoRA,
    ColumnParallelLinearWithShardedLoRA,
    MergedColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithShardedLoRA,
    MergedQKVParallelLinearWithLoRA,
    MergedQKVParallelLinearWithShardedLoRA,
    QKVParallelLinearWithLoRA,
    QKVParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.fused_moe import FusedMoE3DWithLoRA, FusedMoEWithLoRA
from vllm.lora.layers.logits_processor import LogitsProcessorWithLoRA
from vllm.lora.layers.replicated_linear import ReplicatedLinearWithLoRA
from vllm.lora.layers.row_parallel_linear import (
    RowParallelLinearWithLoRA,
    RowParallelLinearWithShardedLoRA,
)
from vllm.lora.layers.utils import LoRAMapping
from vllm.lora.layers.vocal_parallel_embedding import VocabParallelEmbeddingWithLoRA

__all__ = [
    "BaseLayerWithLoRA",
    "VocabParallelEmbeddingWithLoRA",
    "LogitsProcessorWithLoRA",
    "ColumnParallelLinearWithLoRA",
    "ColumnParallelLinearWithShardedLoRA",
    "MergedColumnParallelLinearWithLoRA",
    "MergedColumnParallelLinearWithShardedLoRA",
    "MergedQKVParallelLinearWithLoRA",
    "MergedQKVParallelLinearWithShardedLoRA",
    "QKVParallelLinearWithLoRA",
    "QKVParallelLinearWithShardedLoRA",
    "RowParallelLinearWithLoRA",
    "RowParallelLinearWithShardedLoRA",
    "ReplicatedLinearWithLoRA",
    "LoRAMapping",
    "FusedMoEWithLoRA",
    "FusedMoE3DWithLoRA",
]
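For orientation (not part of the diff), a minimal sketch of how these exported wrappers are typically chosen, assuming only the `can_replace_layer` interface defined in base.py below; `wrap_with_lora` is a hypothetical helper, not vLLM API:

# Illustrative sketch only.
from vllm.lora.layers import ColumnParallelLinearWithLoRA, RowParallelLinearWithLoRA

_CANDIDATE_WRAPPERS = (ColumnParallelLinearWithLoRA, RowParallelLinearWithLoRA)

def wrap_with_lora(layer, lora_config, packed_modules_list, model_config=None):
    # Try each wrapper's can_replace_layer() and wrap with the first match.
    for wrapper_cls in _CANDIDATE_WRAPPERS:
        if wrapper_cls.can_replace_layer(
            source_layer=layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
        ):
            return wrapper_cls(layer)
    return layer  # leave the module unchanged if nothing matches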
vllm/lora/layers/base.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig

if TYPE_CHECKING:
    from vllm.lora.punica_wrapper import PunicaWrapperBase


class BaseLayerWithLoRA(nn.Module):
    def slice_lora_a(
        self, lora_a: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora a if splitting for tensor parallelism."""
        ...

    def slice_lora_b(
        self, lora_b: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora b if splitting with tensor parallelism."""
        ...

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """Initializes lora matrices."""
        ...

    def reset_lora(self, index: int):
        """Resets the lora weights at index back to 0."""
        ...

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        """Overwrites lora tensors at index."""
        ...

    def set_mapping(
        self,
        punica_wrapper,
    ):
        self.punica_wrapper: PunicaWrapperBase = punica_wrapper

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        """Returns True if the layer can be replaced by this LoRA layer."""
        raise NotImplementedError
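As a quick illustration of this interface (a sketch, not part of the commit), a minimal concrete subclass might look like the following; the toy class and its shapes are made-up placeholders:

# Illustrative sketch: a toy subclass exercising the BaseLayerWithLoRA contract.
import torch
import torch.nn as nn

from vllm.lora.layers import BaseLayerWithLoRA


class ToyLinearWithLoRA(BaseLayerWithLoRA):
    def __init__(self, base_layer: nn.Linear):
        super().__init__()
        self.base_layer = base_layer

    def create_lora_weights(self, max_loras, lora_config, model_config=None) -> None:
        # One A/B slot per LoRA index; shapes follow the base layer's in/out sizes.
        self.lora_a_stacked = torch.zeros(
            max_loras, 1, lora_config.max_lora_rank, self.base_layer.in_features
        )
        self.lora_b_stacked = torch.zeros(
            max_loras, 1, self.base_layer.out_features, lora_config.max_lora_rank
        )

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(self, index: int, lora_a, lora_b):
        self.reset_lora(index)
        self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]] = lora_a
        self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]] = lora_b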
vllm/lora/layers/base_linear.py (new file, 165 lines)
@@ -0,0 +1,165 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
from vllm.distributed.utils import divide
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    LinearBase,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.platforms import current_platform

from .base import BaseLayerWithLoRA
from .utils import _get_lora_device


class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
    def __init__(self, base_layer: LinearBase):
        super().__init__()
        self.base_layer = base_layer
        self.input_size = self.base_layer.input_size
        # Ensure tp_size and tp_rank consistency with the base_layer.
        self.tp_size = self.base_layer.tp_size
        self.tp_rank = self.base_layer.tp_rank
        self.device = _get_lora_device(self.base_layer)
        self.output_slices: tuple[int, ...]
        self.output_size: int
        self.n_slices: int

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        self.lora_config = lora_config
        #
        if isinstance(self.base_layer, ReplicatedLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, ColumnParallelLinear):
            lora_a_out_size = (
                lora_config.max_lora_rank
                if not lora_config.fully_sharded_loras
                else divide(lora_config.max_lora_rank, self.tp_size)
            )
            lora_b_out_size = self.output_size

        elif isinstance(self.base_layer, RowParallelLinear):
            lora_a_out_size = lora_config.max_lora_rank
            lora_b_out_size = (
                self.output_size
                if not lora_config.fully_sharded_loras
                else divide(self.output_size, self.tp_size)
            )
        else:
            raise NotImplementedError

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_out_size,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_b_out_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        self.output_slices = (self.lora_b_stacked[0].shape[2],)

    def reset_lora(self, index: int):
        for s_index in range(self.n_slices):
            self.lora_a_stacked[s_index][index] = 0
            self.lora_b_stacked[s_index][index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        # Except for QKVParallelLinearWithLoRA and
        # MergedColumnParallelLinearWithLoRA, all other linear LoRA layers
        # store weights in a tuple of size 1. These two layers will
        # override this function.
        assert isinstance(lora_a, torch.Tensor)
        assert isinstance(lora_b, torch.Tensor)
        assert (
            len(self.lora_a_stacked) == len(self.lora_b_stacked) == self.n_slices == 1
        )

        self.reset_lora(index)
        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        self.lora_a_stacked[0][index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
            lora_a, non_blocking=True
        )
        self.lora_b_stacked[0][index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
            lora_b, non_blocking=True
        )

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        # In Transformers modeling backend, x and output have extra batch dimension like
        # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
        # therefore we need to flatten the batch dimensions.
        if x.ndim == 3 and output.ndim == 3:
            output = output.flatten(0, 1)
            x = x.flatten(0, 1)

        lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_linear(
            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, self.output_slices
        )
        if not current_platform.can_update_inplace():
            output = lora_output

        return output

    @property
    def weight(self) -> torch.Tensor:
        # unquantizedLinear
        if hasattr(self.base_layer, "weight"):
            return self.base_layer.weight
        # Compressed Tensor
        elif hasattr(self.base_layer, "weight_packed"):
            return self.base_layer.weight_packed
        # GPTQ/AWQ
        elif hasattr(self.base_layer, "qweight"):
            return self.base_layer.qweight
        # marlin
        elif hasattr(self.base_layer, "B"):
            return self.base_layer.B
        # HQQ marlin
        elif hasattr(self.base_layer, "W_q"):
            return self.base_layer.W_q
        else:
            raise ValueError(f"Unsupported base layer: {self.base_layer}")

    @property
    def bias(self) -> torch.Tensor | None:
        if hasattr(self.base_layer, "bias"):
            return self.base_layer.bias
        else:
            return None
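For reference, a small sketch (with assumed sizes, not taken from the diff) of the buffer shapes that `create_lora_weights` allocates for the simple single-slice case:

# Illustrative sketch: shapes of the stacked LoRA buffers for a non-sharded,
# single-slice linear layer. All sizes below are assumptions, not from the commit.
import torch

max_loras, rank, input_size, output_size = 4, 8, 1024, 4096
n_slices = 1

lora_a_stacked = tuple(
    torch.zeros(max_loras, 1, rank, input_size) for _ in range(n_slices)
)  # one (rank x input_size) A-matrix slot per LoRA index
lora_b_stacked = tuple(
    torch.zeros(max_loras, 1, output_size, rank) for _ in range(n_slices)
)  # one (output_size x rank) B-matrix slot per LoRA index

assert lora_a_stacked[0].shape == (max_loras, 1, rank, input_size)
assert lora_b_stacked[0].shape == (max_loras, 1, output_size, rank)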
vllm/lora/layers/column_parallel_linear.py (new file, 577 lines)
@@ -0,0 +1,577 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
from vllm.distributed import tensor_model_parallel_all_gather
from vllm.distributed.utils import divide
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
)
from vllm.platforms import current_platform

from .base_linear import BaseLinearLayerWithLoRA
from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace


def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"):
    """
    For `ColumnParallelLinearWithLoRA` or classes that inherit from
    `ColumnParallelLinearWithLoRA`, they share the same `apply` logic.
    """
    assert (
        layer.n_slices
        == len(layer.lora_a_stacked)
        == len(layer.lora_b_stacked)
        == len(layer.output_slices)
    )

    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape

    # Since communication is needed, the buffer is directly initialized as a
    # tensor rather than a tuple of tensor.
    buffers = torch.zeros(
        (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]),
        dtype=torch.float32,
        device=x.device,
    )

    shrunk_buffers: torch.Tensor | None = layer.punica_wrapper.add_shrink(
        buffers, x, layer.lora_a_stacked, 1.0
    )

    if not current_platform.can_update_inplace():
        buffers = shrunk_buffers

    buffers = tensor_model_parallel_all_gather(buffers)

    lora_output: torch.Tensor | None = layer.punica_wrapper.add_expand(
        output,
        buffers,
        layer.lora_b_stacked,
        layer.output_slices,
        offset_start=0,
        add_input=True,
    )

    if not current_platform.can_update_inplace():
        output = lora_output

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
    return output


class ColumnParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    """
    LoRA on top of ColumnParallelLinear layer.
    LoRA B is sliced for tensor parallelism.
    There are two types for the `base_layer`:
    1. ColumnParallelLinear, e.g. `dense_h_to_4h` in `FalconForCausalLM`.
    2. MergedColumnParallelLinear, e.g. `gate_up_proj` in `Phi3ForCausalLM`.
    """

    def __init__(self, base_layer: ColumnParallelLinear) -> None:
        super().__init__(base_layer)
        # The base_layer type is ColumnParallelLinear or
        # MergedColumnParallelLinear, their weight sharding logic is
        # inconsistent when TP is greater than 1.
        self.is_merged_col_linear = type(base_layer) is MergedColumnParallelLinear
        self.output_size = self.base_layer.output_size_per_partition
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        # Applicable to cases where the base_layer is
        # MergedColumnParallelLinear.
        if self.is_merged_col_linear:
            shard_size = self.output_size // 2
            offset = lora_b.shape[0] // 2

            left_weight = lora_b[
                self.tp_rank * shard_size : (self.tp_rank + 1) * shard_size, :
            ]
            right_weight = lora_b[
                offset + self.tp_rank * shard_size : offset
                + (self.tp_rank + 1) * shard_size,
                :,
            ]
            lora_b = torch.cat([left_weight, right_weight], dim=0)
        # Applicable to cases where the base_layer is
        # ColumnParallelLinear.
        else:
            shard_size = self.output_size
            start_idx = self.tp_rank * shard_size
            end_idx = (self.tp_rank + 1) * shard_size
            lora_b = lora_b[start_idx:end_idx, :]
        return lora_b

    def forward(
        self, input_: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Forward of ColumnParallelLinear

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None

        # Matrix multiply.
        output_parallel = self.apply(input_, bias)
        if self.base_layer.gather_output and self.tp_size > 1:
            # All-gather across the partitions.
            output = tensor_model_parallel_all_gather(output_parallel)
        else:
            output = output_parallel

        if not self.base_layer.return_bias:
            return output

        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is ColumnParallelLinear or (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 1
        )


class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """ColumnParallelLinear layer that is composed of 2 sublayers (slices)
    packed together (e.g. gate_proj + up_proj -> gate_up_proj).

    This means we have 2 LoRAs, each applied to one half of the layer.

    Both slices must have the same size.
    """

    def __init__(
        self, base_layer: MergedColumnParallelLinear | QKVParallelLinear
    ) -> None:
        super().__init__(base_layer)
        # There are two LoRA layers.
        # The output_sizes in MergedColumnParallelLinear is not sharded by tp,
        # so we need to divide it by the tp_size to get the correct slice sizes.
        output_sizes = self.base_layer.output_sizes
        self.output_slices = tuple(
            divide(output_size, self.tp_size) for output_size in output_sizes
        )
        self.n_slices = len(self.output_slices)
        self.output_ids = (self.tp_rank,) * self.n_slices

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """
        The main reason for overriding this function is to enhance code
        maintainability.
        """
        self.lora_config = lora_config

        lora_a_output_size_per_partition = (
            lora_config.max_lora_rank
            if not lora_config.fully_sharded_loras
            else divide(lora_config.max_lora_rank, self.tp_size)
        )

        self.lora_a_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                lora_a_output_size_per_partition,
                self.input_size,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for _ in range(self.n_slices)
        )
        self.lora_b_stacked = tuple(
            torch.zeros(
                max_loras,
                1,
                output_size,
                lora_config.max_lora_rank,
                dtype=lora_config.lora_dtype,
                device=self.device,
            )
            for output_size in self.output_slices
        )

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        return lora_a

    def slice_lora_b(
        self, lora_b: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        sliced_lora_b = [None] * self.n_slices
        for i, (shard_id, shard_size) in enumerate(
            zip(self.output_ids, self.output_slices)
        ):
            if (lora_b_i := lora_b[i]) is not None:
                sliced_lora_b[i] = lora_b_i[
                    shard_size * shard_id : shard_size * (shard_id + 1), :
                ]
        return sliced_lora_b

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        self.reset_lora(index)

        if self.tp_size > 1:
            lora_a = self.slice_lora_a(lora_a)
            lora_b = self.slice_lora_b(lora_b)

        for i in range(self.n_slices):
            if (lora_a_i := lora_a[i]) is not None:
                self.lora_a_stacked[i][
                    index, 0, : lora_a_i.shape[0], : lora_a_i.shape[1]
                ].copy_(lora_a_i, non_blocking=True)
            if (lora_b_i := lora_b[i]) is not None:
                self.lora_b_stacked[i][
                    index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1]
                ].copy_(lora_b_i, non_blocking=True)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return (
            type(source_layer) is MergedColumnParallelLinear
            and len(packed_modules_list) == 2
        )


class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
    """
    ColumnParallelLinear layer that is specifically designed for
    qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
    only contain a single LoRA within their qkv_proj layer.

    During inference with Tensor Parallel, the weights of lora_b
    must be accurately partitioned according to the respective ranks.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        self.q_proj_total_size = (
            self.base_layer.total_num_heads * self.base_layer.head_size
        )
        self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
        self.kv_proj_shard_size = (
            self.base_layer.num_kv_heads * self.base_layer.head_size
        )
        self.kv_proj_total_size = (
            self.base_layer.total_num_kv_heads * self.base_layer.head_size
        )
        # There is only one LoRA layer
        self.n_slices = 1

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas
        lora_b_q = lora_b[
            self.q_proj_shard_size * self.q_shard_id : self.q_proj_shard_size
            * (self.q_shard_id + 1),
            :,
        ]
        k_offset = self.q_proj_total_size
        lora_b_k = lora_b[
            k_offset + self.kv_proj_shard_size * self.kv_shard_id : k_offset
            + self.kv_proj_shard_size * (self.kv_shard_id + 1),
            :,
        ]
        v_offset = k_offset + self.kv_proj_total_size
        lora_b_v = lora_b[
            v_offset + self.kv_proj_shard_size * self.kv_shard_id : v_offset
            + self.kv_proj_shard_size * (self.kv_shard_id + 1),
            :,
        ]
        lora_b = torch.cat([lora_b_q, lora_b_k, lora_b_v], dim=0)
        return lora_b

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 1


class MergedQKVParallelLinearWithLoRA(MergedColumnParallelLinearWithLoRA):
    """MergedColumnParallelLinear layer that is composed of 3 sublayers (slices)
    packed together in qkv proj fashion
    (q_proj + k_proj + v_proj -> qkv_proj).

    This means we have 3 LoRAs, each applied to one slice of the layer.

    Q slice may have different shape than K and V slices (which both have
    the same shape).
    """

    def __init__(self, base_layer: QKVParallelLinear) -> None:
        super().__init__(base_layer)
        # There are three LoRA layers.
        self.n_slices = len(self.base_layer.output_sizes)

        self.q_proj_shard_size = self.base_layer.num_heads * self.base_layer.head_size
        self.kv_proj_shard_size = (
            self.base_layer.num_kv_heads * self.base_layer.head_size
        )
        self.q_shard_id = self.tp_rank
        self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas

        self.output_slices = (
            self.q_proj_shard_size,
            self.kv_proj_shard_size,
            self.kv_proj_shard_size,
        )
        self.output_ids = (
            self.q_shard_id,
            self.kv_shard_id,
            self.kv_shard_id,
        )

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        """
        The main reason for overloading this function is to handle inconsistent
        weight dimensions in qkv lora.
        """
        super().create_lora_weights(max_loras, lora_config, model_config)

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is QKVParallelLinear and len(packed_modules_list) == 3


# These following layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.


class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`,
    # their `lora_a` and `lora_b` have different sharding patterns. After
    # completing the `lora_a` GEMM, a gather operation is performed.
    # Therefore, the sharding of `lora_a` only needs to correspond with the
    # gather operation.
    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_a_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        lora_a = lora_a[start_idx : start_idx + shard_size, :]
        return lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        # NOTE: lora_a contains 2 subloras, and each sublora could be None.
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        lora_a = [
            lora_a[0][output_start_idx : output_start_idx + output_shard_size, :]
            if lora_a[0] is not None
            else None,
            lora_a[1][output_start_idx : output_start_idx + output_shard_size, :]
            if lora_a[1] is not None
            else None,
        ]
        return lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class QKVParallelLinearWithShardedLoRA(QKVParallelLinearWithLoRA):
    """
    Differs from QKVParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_a_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        lora_a = lora_a[start_idx : start_idx + shard_size, :]
        return lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class MergedQKVParallelLinearWithShardedLoRA(MergedQKVParallelLinearWithLoRA):
    """
    Differs from MergedQKVParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
        self, lora_a: list[torch.Tensor | None]
    ) -> list[torch.Tensor | None]:
        # NOTE: lora_a contains 3 subloras, and each sublora could be None.
        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
        start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
        lora_a = [
            lora_a[0][start_idx[0] : start_idx[0] + shard_size[0], :]
            if lora_a[0] is not None
            else None,
            lora_a[1][start_idx[1] : start_idx[1] + shard_size[1], :]
            if lora_a[1] is not None
            else None,
            lora_a[2][start_idx[2] : start_idx[2] + shard_size[2], :]
            if lora_a[2] is not None
            else None,
        ]
        return lora_a

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        return _mcp_apply(x, bias, self)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
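A brief sketch (assumed sizes, not from the diff) of the row partitioning that `slice_lora_b` performs for a plain ColumnParallelLinear under tensor parallelism:

# Illustrative sketch: how LoRA B rows are split across TP ranks for a
# ColumnParallelLinear-style layer. Sizes are assumptions, not from the commit.
import torch

tp_size, rank = 2, 8
output_size_full = 4096                       # full (unsharded) output dim
shard_size = output_size_full // tp_size      # per-rank output slice
lora_b = torch.randn(output_size_full, rank)  # (output_size, rank)

for tp_rank in range(tp_size):
    start, end = tp_rank * shard_size, (tp_rank + 1) * shard_size
    lora_b_shard = lora_b[start:end, :]       # rows owned by this rank
    assert lora_b_shard.shape == (shard_size, rank)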
vllm/lora/layers/fused_moe.py (new file, 747 lines)
@@ -0,0 +1,747 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import functools
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config.lora import LoRAConfig
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.distributed.utils import divide
|
||||
from vllm.lora.layers.base import BaseLayerWithLoRA
|
||||
from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
_get_config_dtype_str,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
|
||||
MarlinExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
TritonExperts,
|
||||
try_get_optimal_moe_config,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe_modular_method import (
|
||||
FusedMoEModularMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
|
||||
UnfusedOAITritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEModularKernel,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP,
|
||||
)
|
||||
|
||||
from .utils import _get_lora_device
|
||||
|
||||
|
||||
class FusedMoEWithLoRA(BaseLayerWithLoRA):
|
||||
def __init__(self, base_layer: FusedMoE) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
|
||||
assert not self.base_layer.use_ep, (
|
||||
"EP support for Fused MoE LoRA is not implemented yet."
|
||||
)
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.device = _get_lora_device(base_layer)
|
||||
self._w13_slices = 2
|
||||
self._inject_lora_into_fused_moe()
|
||||
|
||||
def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
|
||||
normalized_config = {}
|
||||
for key, value in config.items():
|
||||
if key.islower():
|
||||
if key.startswith("block_"):
|
||||
normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
|
||||
else:
|
||||
normalized_key = key.upper()
|
||||
else:
|
||||
normalized_key = key
|
||||
normalized_config[normalized_key] = value
|
||||
return normalized_config
|
||||
|
||||
def _get_lora_moe_configs(
|
||||
self,
|
||||
op_prefix: str,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
num_slices: int,
|
||||
M: int,
|
||||
layer: FusedMoE,
|
||||
top_k: int,
|
||||
config_dtype: str,
|
||||
):
|
||||
if envs.VLLM_TUNED_CONFIG_FOLDER:
|
||||
hidden_size = layer.hidden_size
|
||||
intermediate_size = layer.intermediate_size_per_partition
|
||||
shrink_config = get_lora_op_configs(
|
||||
op_type=f"fused_moe_lora_{op_prefix}_shrink",
|
||||
max_loras=num_loras,
|
||||
batch=M,
|
||||
hidden_size=hidden_size,
|
||||
rank=rank,
|
||||
num_slices=num_slices,
|
||||
moe_intermediate_size=intermediate_size,
|
||||
)
|
||||
expand_config = get_lora_op_configs(
|
||||
op_type=f"fused_moe_lora_{op_prefix}_expand",
|
||||
max_loras=num_loras,
|
||||
batch=M,
|
||||
hidden_size=hidden_size, # lora_a_stacked.shape[-1],
|
||||
rank=rank,
|
||||
num_slices=num_slices,
|
||||
moe_intermediate_size=intermediate_size, # lora_b_stacked.shape[-2],
|
||||
)
|
||||
else: # fall back to the default config
|
||||
get_config_func = functools.partial(
|
||||
try_get_optimal_moe_config,
|
||||
layer.w13_weight.size(),
|
||||
layer.w2_weight.size(),
|
||||
top_k,
|
||||
config_dtype,
|
||||
block_shape=layer.quant_method.moe_quant_config.block_shape,
|
||||
)
|
||||
shrink_config = get_config_func(M)
|
||||
expand_config = get_config_func(M)
|
||||
shrink_config = self._normalize_keys(shrink_config)
|
||||
expand_config = self._normalize_keys(expand_config)
|
||||
return shrink_config, expand_config
|
||||
|
||||
def _inject_lora_into_fused_moe(self):
|
||||
moe_state_dict = {}
|
||||
top_k = self.base_layer.top_k
|
||||
|
||||
self.base_layer.ensure_moe_quant_config_init()
|
||||
quant_config = self.base_layer.quant_method.moe_quant_config
|
||||
|
||||
prepare_finalize = MoEPrepareAndFinalizeNoEP()
|
||||
m_fused_moe_fn = FusedMoEModularKernel(
|
||||
prepare_finalize,
|
||||
self.base_layer.quant_method.select_gemm_impl(
|
||||
prepare_finalize, self.base_layer
|
||||
),
|
||||
self.base_layer.shared_experts,
|
||||
getattr(self.base_layer, "shared_experts_stream", None),
|
||||
)
|
||||
if quant_config.use_mxfp4_w4a16:
|
||||
assert isinstance(
|
||||
m_fused_moe_fn.fused_experts, (MarlinExperts, UnfusedOAITritonExperts)
|
||||
)
|
||||
else:
|
||||
assert isinstance(
|
||||
m_fused_moe_fn.fused_experts, (MarlinExperts, TritonExperts)
|
||||
)
|
||||
|
||||
def fwd_decorator(layer, func):
|
||||
def wrapper(*args, **kwargs):
|
||||
moe_state_dict["hidden_states"] = kwargs["hidden_states"]
|
||||
moe_state_dict["topk_ids"] = kwargs["topk_ids"]
|
||||
moe_state_dict["topk_weights"] = kwargs["topk_weights"]
|
||||
moe_state_dict["expert_map"] = kwargs["expert_map"]
|
||||
moe_state_dict["apply_router_weight_on_input"] = kwargs[
|
||||
"apply_router_weight_on_input"
|
||||
]
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
def act_decorator(layer, func):
|
||||
def wrapper(*args, **kwargs):
|
||||
_, output, input = args
|
||||
|
||||
hidden_states = moe_state_dict["hidden_states"]
|
||||
topk_weights = moe_state_dict["topk_weights"]
|
||||
curr_topk_ids = moe_state_dict["topk_ids"]
|
||||
|
||||
expert_map = moe_state_dict["expert_map"]
|
||||
|
||||
config_dtype = _get_config_dtype_str(
|
||||
dtype=hidden_states.dtype,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
)
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
num_tokens = hidden_states.size(0)
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
max_lora_rank = self.w13_lora_a_stacked[0].shape[-2]
|
||||
shrink_config, expand_config = self._get_lora_moe_configs(
|
||||
op_prefix="w13",
|
||||
num_loras=self.max_loras,
|
||||
rank=max_lora_rank,
|
||||
num_slices=self._w13_slices,
|
||||
M=M,
|
||||
layer=layer,
|
||||
top_k=top_k,
|
||||
config_dtype=config_dtype,
|
||||
)
|
||||
|
||||
# get the block size of m from customized config or default config
|
||||
(
|
||||
sorted_token_ids_lora,
|
||||
expert_ids_lora,
|
||||
num_tokens_post_padded_lora,
|
||||
) = self.punica_wrapper.moe_lora_align_block_size(
|
||||
curr_topk_ids,
|
||||
num_tokens,
|
||||
shrink_config["BLOCK_SIZE_M"],
|
||||
self.base_layer.local_num_experts,
|
||||
self.max_loras,
|
||||
self.adapter_enabled,
|
||||
expert_map,
|
||||
)
|
||||
|
||||
moe_state_dict["sorted_token_ids_lora"] = sorted_token_ids_lora
|
||||
moe_state_dict["expert_ids_lora"] = expert_ids_lora
|
||||
moe_state_dict["num_tokens_post_padded_lora"] = (
|
||||
num_tokens_post_padded_lora
|
||||
)
|
||||
|
||||
expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
|
||||
sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
|
||||
#
|
||||
|
||||
self.punica_wrapper.add_lora_fused_moe(
|
||||
input.view(-1, top_k, input.shape[-1]),
|
||||
hidden_states,
|
||||
self.w13_lora_a_stacked,
|
||||
self.w13_lora_b_stacked,
|
||||
topk_weights,
|
||||
sorted_token_ids_lora,
|
||||
expert_ids_lora,
|
||||
num_tokens_post_padded_lora,
|
||||
max_lora_rank,
|
||||
top_k,
|
||||
shrink_config, ## pass the shrink config
|
||||
expand_config, ## pass the expand config
|
||||
self.adapter_enabled,
|
||||
fully_sharded=self.fully_sharded,
|
||||
)
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
|
||||
moe_state_dict["intermediate_cache2"] = output
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
def moe_sum_decorator(layer, func):
|
||||
def wrapper(*args, **kwargs):
|
||||
hidden_states = moe_state_dict["hidden_states"]
|
||||
topk_weights = moe_state_dict["topk_weights"]
|
||||
|
||||
config_dtype = _get_config_dtype_str(
|
||||
dtype=hidden_states.dtype,
|
||||
use_fp8_w8a8=False,
|
||||
use_int8_w8a16=False,
|
||||
use_int4_w4a16=False,
|
||||
)
|
||||
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
|
||||
num_tokens = hidden_states.size(0)
|
||||
M = min(num_tokens, CHUNK_SIZE)
|
||||
max_lora_rank = self.w2_lora_a_stacked[0].shape[-2]
|
||||
shrink_config, expand_config = self._get_lora_moe_configs(
|
||||
op_prefix="w2",
|
||||
num_loras=self.max_loras,
|
||||
rank=max_lora_rank,
|
||||
num_slices=1,
|
||||
M=M,
|
||||
layer=layer,
|
||||
top_k=top_k,
|
||||
config_dtype=config_dtype,
|
||||
)
|
||||
|
||||
sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
|
||||
expert_ids_lora = moe_state_dict["expert_ids_lora"]
|
||||
num_tokens_post_padded_lora = moe_state_dict[
|
||||
"num_tokens_post_padded_lora"
|
||||
]
|
||||
|
||||
expert_ids_lora = expert_ids_lora.view(self.max_loras, -1)
|
||||
sorted_token_ids_lora = sorted_token_ids_lora.view(self.max_loras, -1)
|
||||
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
|
||||
intermediate_cache3 = args[0]
|
||||
|
||||
shard_size_w2 = divide(self.base_layer.hidden_size, self.tp_size)
|
||||
|
||||
self.punica_wrapper.add_lora_fused_moe(
|
||||
intermediate_cache3,
|
||||
intermediate_cache2,
|
||||
self.w2_lora_a_stacked,
|
||||
self.w2_lora_b_stacked,
|
||||
topk_weights,
|
||||
sorted_token_ids_lora,
|
||||
expert_ids_lora,
|
||||
num_tokens_post_padded_lora,
|
||||
max_lora_rank,
|
||||
top_k,
|
||||
shrink_config, ## pass the shrink config
|
||||
expand_config, ## pass the expand config
|
||||
self.adapter_enabled,
|
||||
True,
|
||||
fully_sharded=self.fully_sharded,
|
||||
offset=shard_size_w2 * self.tp_rank if self.fully_sharded else 0,
|
||||
)
|
||||
|
||||
result = func(*args, **kwargs)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
fused_experts = m_fused_moe_fn.fused_experts
|
||||
|
||||
m_fused_moe_fn.forward = fwd_decorator(self.base_layer, m_fused_moe_fn.forward)
|
||||
fused_experts.activation = act_decorator(
|
||||
self.base_layer, fused_experts.activation
|
||||
)
|
||||
fused_experts.moe_sum = moe_sum_decorator(
|
||||
self.base_layer, fused_experts.moe_sum
|
||||
)
|
||||
self.base_layer.quant_method = FusedMoEModularMethod(
|
||||
self.base_layer.quant_method, m_fused_moe_fn
|
||||
)
|
||||
|
||||
def _create_lora_a_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
):
|
||||
self.w13_lora_a_stacked: tuple[torch.Tensor, ...] = tuple(
|
||||
torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.local_num_experts,
|
||||
lora_config.max_lora_rank
|
||||
if not self.fully_sharded
|
||||
else divide(lora_config.max_lora_rank, self.tp_size),
|
||||
self.base_layer.hidden_size,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self._w13_slices)
|
||||
)
|
||||
self.w2_lora_a_stacked: tuple[torch.Tensor, ...] = (
|
||||
torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.local_num_experts,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.intermediate_size_per_partition,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
def _create_lora_b_weights(self, max_loras: int, lora_config: LoRAConfig):
|
||||
self.w13_lora_b_stacked: tuple[torch.Tensor, ...] = tuple(
|
||||
torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.local_num_experts,
|
||||
self.base_layer.intermediate_size_per_partition,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self._w13_slices)
|
||||
)
|
||||
self.w2_lora_b_stacked: tuple[torch.Tensor, ...] = (
|
||||
torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.local_num_experts,
|
||||
self.base_layer.hidden_size
|
||||
if not self.fully_sharded
|
||||
else divide(self.base_layer.hidden_size, self.tp_size),
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> None:
|
||||
"""Initializes lora matrices."""
|
||||
self.max_loras = lora_config.max_loras
|
||||
self.fully_sharded = lora_config.fully_sharded_loras
|
||||
|
||||
self.adapter_enabled = torch.tensor(
|
||||
[0] * (max_loras + 1), dtype=torch.int, device=self.device
|
||||
)
|
||||
|
||||
self._create_lora_a_weights(max_loras, lora_config)
|
||||
self._create_lora_b_weights(max_loras, lora_config)
|
||||
# They will be used by 'LoRALayerWeights.create_dummy_lora_weights'
|
||||
# to create a dummy LoRA weights.
|
||||
# TODO Optimize this section
|
||||
self.lora_a_stacked = []
|
||||
self.lora_b_stacked = []
|
||||
for lora_id in range(max_loras):
|
||||
for experts_id in range(self.base_layer.local_num_experts):
|
||||
# gate_proj,down_proj,up_proj
|
||||
self.lora_a_stacked.append(
|
||||
self.w13_lora_a_stacked[0][lora_id][experts_id]
|
||||
)
|
||||
self.lora_a_stacked.append(
|
||||
self.w2_lora_a_stacked[0][lora_id][experts_id]
|
||||
)
|
||||
|
||||
self.lora_b_stacked.append(
|
||||
self.w13_lora_b_stacked[0][lora_id][experts_id]
|
||||
)
|
||||
self.lora_b_stacked.append(
|
||||
self.w2_lora_b_stacked[0][lora_id][experts_id]
|
||||
)
|
||||
|
||||
self.lora_a_stacked.append(
|
||||
self.w13_lora_a_stacked[1][lora_id][experts_id]
|
||||
)
|
||||
self.lora_b_stacked.append(
|
||||
self.w13_lora_b_stacked[1][lora_id][experts_id]
|
||||
)
|
||||
|
||||
def _slice_w13_a(self, w13_lora_a: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
|
||||
"""
|
||||
if self.tp_size == 1 or not self.fully_sharded:
|
||||
return w13_lora_a
|
||||
|
||||
# w13_lora_a shape (num_experts,rank,input_size)
|
||||
current_lora_rank = w13_lora_a.shape[1]
|
||||
assert current_lora_rank % self.tp_size == 0
|
||||
# Based on S-LoRA, we slice W13/W1/W3 A along the rank dim.
|
||||
sliced_rank = current_lora_rank // self.tp_size
|
||||
start_idx = self.tp_rank * sliced_rank
|
||||
end_idx = (self.tp_rank + 1) * sliced_rank
|
||||
return w13_lora_a[:, start_idx:end_idx, :]
|
||||
|
||||
def _slice_w13_b(self, w13_lora_b: torch.Tensor):
|
||||
if self.tp_size == 1:
|
||||
return w13_lora_b
|
||||
|
||||
# w13_lora_b shape (num_experts,output_size,rank)
|
||||
shard_size = self.base_layer.intermediate_size_per_partition
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
|
||||
return w13_lora_b[:, start_idx:end_idx, :]
|
||||
|
||||
def _slice_w2_a(self, w2_lora_a: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
|
||||
"""
|
||||
if self.tp_size == 1:
|
||||
return w2_lora_a
|
||||
# w2_lora_a shape (num_experts,rank,input_size)
|
||||
shard_size = self.base_layer.intermediate_size_per_partition
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
|
||||
return w2_lora_a[:, :, start_idx:end_idx]
|
||||
|
||||
def _slice_w2_b(self, w2_lora_b: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Applies to FusedMoEWithLoRA and FusedMoE3DWithLoRA
|
||||
"""
|
||||
if self.tp_size == 1 or not self.fully_sharded:
|
||||
return w2_lora_b
|
||||
# Based on S-LoRA, we slice W2 B along the hidden_size dim.
|
||||
# w2_lora_b shape (num_experts,output_size,rank)
|
||||
current_lora_size = w2_lora_b.shape[1]
|
||||
|
||||
sliced_size = current_lora_size // self.tp_size
|
||||
start_idx = self.tp_rank * sliced_size
|
||||
end_idx = (self.tp_rank + 1) * sliced_size
|
||||
return w2_lora_b[:, start_idx:end_idx, :]
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
"""Resets the lora weights at index back to 0."""
|
||||
for pos in range(self._w13_slices):
|
||||
self.w13_lora_a_stacked[pos][index] = 0
|
||||
self.w13_lora_b_stacked[pos][index] = 0
|
||||
|
||||
self.w2_lora_a_stacked[0][index] = 0
|
||||
self.w2_lora_b_stacked[0][index] = 0
|
||||
self.adapter_enabled[index] = 0
|
||||
|
||||
#
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor | list[torch.Tensor],
|
||||
lora_b: torch.Tensor | list[torch.Tensor],
|
||||
):
|
||||
"""Overwrites lora tensors at index."""
|
||||
# Make mypy happy
|
||||
assert isinstance(lora_a, list)
|
||||
assert isinstance(lora_b, list)
|
||||
|
||||
self.reset_lora(index)
|
||||
self.adapter_enabled[index] = 1
|
||||
|
||||
num_experts = self.w13_lora_a_stacked[0].shape[1]
|
||||
|
||||
w1_lora_a, w2_lora_a, w3_lora_a = lora_a
|
||||
w1_lora_b, w2_lora_b, w3_lora_b = lora_b
|
||||
assert (
|
||||
num_experts
|
||||
== w1_lora_a.shape[0]
|
||||
== w2_lora_a.shape[0]
|
||||
== w3_lora_a.shape[0]
|
||||
)
|
||||
|
||||
slliced_w1_lora_a = self._slice_w13_a(w1_lora_a)
|
||||
slliced_w1_lora_b = self._slice_w13_b(w1_lora_b)
|
||||
slliced_w3_lora_a = self._slice_w13_a(w3_lora_a)
|
||||
slliced_w3_lora_b = self._slice_w13_b(w3_lora_b)
|
||||
|
||||
sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
|
||||
sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
|
||||
|
||||
self.w13_lora_a_stacked[0][
|
||||
index, :, : slliced_w1_lora_a.shape[1], : slliced_w1_lora_a.shape[2]
|
||||
].copy_(slliced_w1_lora_a, non_blocking=True)
|
||||
|
||||
self.w13_lora_a_stacked[1][
|
||||
index, :, : slliced_w3_lora_a.shape[1], : slliced_w3_lora_a.shape[2]
|
||||
].copy_(slliced_w3_lora_a, non_blocking=True)
|
||||
|
||||
self.w13_lora_b_stacked[0][
|
||||
index, :, : slliced_w1_lora_b.shape[1], : slliced_w1_lora_b.shape[2]
|
||||
].copy_(slliced_w1_lora_b, non_blocking=True)
|
||||
|
||||
self.w13_lora_b_stacked[1][
|
||||
index, :, : slliced_w3_lora_b.shape[1], : slliced_w3_lora_b.shape[2]
|
||||
].copy_(slliced_w3_lora_b, non_blocking=True)
|
||||
|
||||
self.w2_lora_a_stacked[0][
|
||||
index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
|
||||
].copy_(sliced_w2_lora_a, non_blocking=True)
|
||||
|
||||
self.w2_lora_b_stacked[0][
|
||||
index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
|
||||
].copy_(sliced_w2_lora_b, non_blocking=True)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.base_layer.forward(*args, **kwargs)
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(self, *args, **kwargs):
|
||||
return self.base_layer.maybe_all_reduce_tensor_model_parallel(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def _shared_experts(self):
|
||||
return self.base_layer._shared_experts
|
||||
|
||||
@property
|
||||
def quant_method(self):
|
||||
return self.base_layer.quant_method
|
||||
|
||||
@property
|
||||
def is_internal_router(self) -> bool:
|
||||
return self.base_layer.is_internal_router
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
"""Returns True if the layer can be replaced by this LoRA layer."""
|
||||
|
||||
# source_layer is FusedMoE or SharedFusedMoE
|
||||
return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 2
|
||||
|
||||
|
||||
class FusedMoE3DWithLoRA(FusedMoEWithLoRA):
|
||||
def __init__(self, base_layer):
|
||||
super().__init__(base_layer)
|
||||
self._w13_slices = 1
|
||||
|
||||
def _create_lora_b_weights(self, max_loras, lora_config):
|
||||
self.w13_lora_b_stacked: tuple[torch.Tensor] = tuple(
|
||||
torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.local_num_experts,
|
||||
self.base_layer.intermediate_size_per_partition * 2,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self._w13_slices)
|
||||
)
|
||||
self.w2_lora_b_stacked: tuple[torch.Tensor] = (
|
||||
torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.local_num_experts,
|
||||
self.base_layer.hidden_size
|
||||
if not self.fully_sharded
|
||||
else divide(self.base_layer.hidden_size, self.tp_size),
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
),
|
||||
)
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> None:
|
||||
"""Initializes lora matrices."""
|
||||
|
||||
assert isinstance(model_config, PretrainedConfig)
|
||||
self._base_model = model_config.architectures[0]
|
||||
self.max_loras = lora_config.max_loras
|
||||
self.fully_sharded = lora_config.fully_sharded_loras
|
||||
|
||||
self.adapter_enabled = torch.tensor(
|
||||
[0] * (max_loras + 1), dtype=torch.int, device=self.device
|
||||
)
|
||||
|
||||
self._create_lora_a_weights(max_loras, lora_config)
|
||||
self._create_lora_b_weights(max_loras, lora_config)
|
||||
|
||||
def _slice_w13_b(self, w13_lora_b: torch.Tensor):
|
||||
if self.tp_size == 1:
|
||||
return w13_lora_b
|
||||
|
||||
# w13_lora_b shape (num_experts,output_size,rank)
|
||||
shard_size = self.base_layer.intermediate_size_per_partition
|
||||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
# HACK: Currently, only GPT-OSS is in interleaved order
|
||||
if self._base_model == "GptOssForCausalLM":
|
||||
# For models like GPT-OSS, the weights of w1 (gate_proj) and w3 (up_proj)
|
||||
# in the interleaved order, and corresponding LoRA need to be processed.
|
||||
w1_lora_b = w13_lora_b[:, ::2, :]
|
||||
w3_lora_b = w13_lora_b[:, 1::2, :]
|
||||
sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
|
||||
sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]
|
||||
|
||||
return torch.stack([sliced_w1_lora_b, sliced_w3_lora_b], dim=2).flatten(
|
||||
1, 2
|
||||
)
|
||||
else:
|
||||
slice_size = w13_lora_b.shape[1] // 2
|
||||
w1_lora_b = w13_lora_b[:, :slice_size, :]
|
||||
w3_lora_b = w13_lora_b[:, slice_size:, :]
|
||||
sliced_w1_lora_b = w1_lora_b[:, start_idx:end_idx, :]
|
||||
sliced_w3_lora_b = w3_lora_b[:, start_idx:end_idx, :]
|
||||
|
||||
return torch.cat([sliced_w1_lora_b, sliced_w3_lora_b], dim=1)
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor | list[torch.Tensor],
|
||||
lora_b: torch.Tensor | list[torch.Tensor],
|
||||
):
|
||||
"""Overwrites lora tensors at index."""
|
||||
# Make mypy happy
|
||||
assert isinstance(lora_a, list)
|
||||
assert isinstance(lora_b, list)
|
||||
assert len(lora_a) == len(lora_b) == 2
|
||||
|
||||
self.reset_lora(index)
|
||||
self.adapter_enabled[index] = 1
|
||||
|
||||
num_experts = self.w13_lora_a_stacked[0].shape[1]
|
||||
w13_lora_a, w2_lora_a = lora_a
|
||||
w13_lora_b, w2_lora_b = lora_b
|
||||
|
||||
# (num_experts,rank,input_size)
|
||||
w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1])
|
||||
w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1])
|
||||
# (output_size,num_experts,rank)
|
||||
w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], num_experts, -1)
|
||||
w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], num_experts, -1)
|
||||
# (num_experts,output_size,rank)
|
||||
w13_lora_b = w13_lora_b.permute(1, 0, 2)
|
||||
w2_lora_b = w2_lora_b.permute(1, 0, 2)
|
||||
|
||||
sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
|
||||
sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)
|
||||
|
||||
sliced_w2_lora_a = self._slice_w2_a(w2_lora_a)
|
||||
sliced_w2_lora_b = self._slice_w2_b(w2_lora_b)
|
||||
|
||||
self.w13_lora_a_stacked[0][
|
||||
index, :, : sliced_w13_lora_a.shape[1], : sliced_w13_lora_a.shape[2]
|
||||
].copy_(sliced_w13_lora_a, non_blocking=True)
|
||||
self.w2_lora_a_stacked[0][
|
||||
index, :, : sliced_w2_lora_a.shape[1], : sliced_w2_lora_a.shape[2]
|
||||
].copy_(sliced_w2_lora_a, non_blocking=True)
|
||||
|
||||
self.w13_lora_b_stacked[0][
|
||||
index, :, : sliced_w13_lora_b.shape[1], : sliced_w13_lora_b.shape[2]
|
||||
].copy_(sliced_w13_lora_b, non_blocking=True)
|
||||
self.w2_lora_b_stacked[0][
|
||||
index, :, : sliced_w2_lora_b.shape[1], : sliced_w2_lora_b.shape[2]
|
||||
].copy_(sliced_w2_lora_b, non_blocking=True)
|
||||
|
||||
@property
|
||||
def w13_input_size(self):
|
||||
"""
|
||||
Full size
|
||||
"""
|
||||
return self.w13_lora_a_stacked[0].shape[-1]
|
||||
|
||||
@property
|
||||
def w13_output_size(self):
|
||||
"""
|
||||
Full size
|
||||
"""
|
||||
return self.w13_lora_b_stacked[0].shape[-2] * self.tp_size
|
||||
|
||||
@property
|
||||
def w2_input_size(self):
|
||||
"""
|
||||
Full size
|
||||
"""
|
||||
return self.w2_lora_a_stacked[0].shape[-1] * self.tp_size
|
||||
|
||||
@property
|
||||
def w2_output_size(self):
|
||||
"""
|
||||
Full size
|
||||
"""
|
||||
return self.w2_lora_a_stacked[0].shape[-2]
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: list,
|
||||
model_config: PretrainedConfig | None = None,
|
||||
) -> bool:
|
||||
"""Returns True if the layer can be replaced by this LoRA layer."""
|
||||
# source_layer is FusedMoE or SharedFusedMoE
|
||||
return isinstance(source_layer, FusedMoE) and len(packed_modules_list) == 1
|
||||
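The LoRA injection above works by wrapping methods of the modular MoE kernel (forward, activation, moe_sum) with decorators that stash intermediate state and then add the LoRA contributions. A generic sketch of that wrapping pattern (names here are illustrative, not the vLLM API):

# Illustrative sketch of the method-wrapping pattern: intercept a call,
# stash its inputs for a later hook, then delegate to the original method.
import functools

def capture_inputs(state: dict, func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        state["last_kwargs"] = kwargs       # stash inputs for a later hook
        return func(*args, **kwargs)        # delegate to the original method
    return wrapper

class Kernel:
    def forward(self, *, hidden_states):
        return hidden_states * 2

state: dict = {}
kernel = Kernel()
kernel.forward = capture_inputs(state, kernel.forward)
print(kernel.forward(hidden_states=3))      # prints 6; state now holds the kwargs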
vllm/lora/layers/logits_processor.py (new file, 203 lines)
@@ -0,0 +1,203 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform

from .base import BaseLayerWithLoRA


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(
        self,
        base_layer: LogitsProcessor,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
        sharded_to_full_mapping: list[int] | None,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_all_gather(self):
        return self.base_layer.use_all_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        # TODO: Verify if this condition can be further relaxed
        if not (32000 <= self.base_layer.vocab_size <= 257024):
            raise ValueError(
                "When using LoRA, vocab size must satisfy "
                "32000 <= vocab_size <= 257024"
            )
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping, device=self.device, dtype=torch.long
            )
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        assert isinstance(lora_a, torch.Tensor)
        assert isinstance(lora_b, torch.Tensor)
        self.reset_lora(index)
        self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
            lora_a, non_blocking=True
        )
        self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
            lora_b, non_blocking=True
        )

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: torch.Tensor | None = None,
    ) -> torch.Tensor | None:
        # Get the logits for the next tokens.
        logits = lm_head.quant_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias

        # Gather logits for TP
        logits = self.base_layer._gather_logits(logits)

        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            # org_vocab_size = 4
            # added_vocab_size = 2
            # pad_to_size = 8
            # tp_size = 2

            # indices: [0, 1, 2, 3, 4, 5, 6, 7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the gather mapping is expected to be
            # [0, 1, 4, 5, 2, 6, 3, 7] so that when we reindex with
            # logits[:, mapping] we get:
            # indices: [0, 1, 2, 3, 4, 5, 6, 7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits(
            logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0
        )

        if not current_platform.can_update_inplace():
            logits = lora_output

        # Remove paddings in vocab (if any).
        logits = logits[:, : self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False
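The reindexing example in the _get_logits comments can be checked with a tiny standalone snippet (plain torch, toy numbers taken from the comment; not part of this diff):

import torch

# Toy setup: org_vocab_size=4, added_vocab_size=2, padded to 8, tp_size=2.
# After the TP gather the logits columns correspond to token ids
# [0, 1, 4, -1, 2, 3, 5, -1], where -1 marks padding.
token_ids_before = [0.0, 1.0, 4.0, -1.0, 2.0, 3.0, 5.0, -1.0]
mapping = torch.tensor([0, 1, 4, 5, 2, 6, 3, 7])

logits = torch.tensor([token_ids_before])
reindexed = logits[:, mapping]
# Columns now line up 1:1 with token ids 0..5, padding at the end.
assert reindexed.tolist() == [[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, -1.0, -1.0]]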
70
vllm/lora/layers/replicated_linear.py
Normal file
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
from vllm.model_executor.layers.linear import ReplicatedLinear

from .base_linear import BaseLinearLayerWithLoRA


class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
    def __init__(self, base_layer: ReplicatedLinear) -> None:
        super().__init__(
            base_layer,
        )
        # To ensure interface compatibility, set to 1 always.
        self.output_size = self.base_layer.output_size
        self.n_slices = 1

    def forward(
        self, input_: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Forward of ReplicatedLinearWithLoRA

        Args:
            input_: Tensor whose last dimension is `input_size`.

        Returns:
            - output
            - bias
        """
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None

        # Matrix multiply.
        output = self.apply(input_, bias)

        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None

        if not self.base_layer.return_bias:
            return output

        return output, output_bias

    # ReplicatedLinear should always be replaced, regardless of the fully
    # sharded LoRAs setting, because it is, by definition, copied per GPU.
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is ReplicatedLinear

    def slice_lora_a(
        self, lora_a: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora a if splitting for tensor parallelism."""
        return lora_a

    def slice_lora_b(
        self, lora_b: torch.Tensor | list[torch.Tensor | None]
    ) -> torch.Tensor | list[torch.Tensor | None]:
        """Slice lora b if splitting with tensor parallelism."""
        return lora_b
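A minimal sketch of the bias handling in ReplicatedLinearWithLoRA.forward above (plain torch, no LoRA or quantization; replicated_forward is an illustrative stand-in, not a vLLM API): with skip_bias_add the matmul is bias-free and the bias is handed back for the caller to apply later.

import torch

def replicated_forward(x, weight, bias, skip_bias_add, return_bias):
    # Mirrors the control flow above: the bias either goes into the matmul
    # or is returned separately, depending on skip_bias_add.
    matmul_bias = None if skip_bias_add else bias
    output = x @ weight.t()
    if matmul_bias is not None:
        output = output + matmul_bias
    output_bias = bias if skip_bias_add else None
    if not return_bias:
        return output
    return output, output_bias

x, w, b = torch.randn(2, 4), torch.randn(3, 4), torch.randn(3)
out, out_bias = replicated_forward(x, w, b, skip_bias_add=True, return_bias=True)
assert out_bias is b  # caller is expected to add the bias later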
176
vllm/lora/layers/row_parallel_linear.py
Normal file
@@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
from vllm.distributed import (
    split_tensor_along_last_dim,
    tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.platforms import current_platform

from .base_linear import BaseLinearLayerWithLoRA
from .utils import _fully_sharded_can_replace, _not_fully_sharded_can_replace


class RowParallelLinearWithLoRA(BaseLinearLayerWithLoRA):
    def __init__(self, base_layer: RowParallelLinear) -> None:
        super().__init__(base_layer)

        # reset input_size
        self.input_size = self.base_layer.input_size_per_partition
        self.output_size = self.base_layer.output_size
        # There is only one LoRA layer.
        self.n_slices = 1

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        shard_size = self.input_size
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_a = lora_a[:, start_idx:end_idx]
        return lora_a

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        return lora_b

    def forward(
        self, input_: torch.Tensor
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor | None]:
        """Forward of RowParallelLinear

        Args:
            input_: tensor whose last dimension is `input_size`. If
                `input_is_parallel` is set, then the last dimension
                is `input_size // tp_size`.

        Returns:
            - output
            - bias
        """
        # set up backprop all-reduce.
        if self.base_layer.input_is_parallel:
            input_parallel = input_
        else:
            # TODO: simplify code below
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size
            )
            input_parallel = splitted_input[self.tp_rank].contiguous()

        # Matrix multiply.
        bias_ = (
            None
            if (self.tp_rank > 0 or self.base_layer.skip_bias_add)
            else self.base_layer.bias
        )
        output_parallel = self.apply(input_parallel, bias_)
        if self.base_layer.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        if not self.base_layer.return_bias:
            return output

        return output, output_bias

    @classmethod
    @_not_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is RowParallelLinear


# The following layer is based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.


class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.

    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row parallel base
    layer and column partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        shard_size = self.lora_b_stacked[0].shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        lora_b = lora_b[start_idx:end_idx, :]
        return lora_b

    def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
        buffer = torch.zeros(
            (self.n_slices, x.shape[0], self.lora_a_stacked[0].shape[2]),
            dtype=torch.float32,
            device=x.device,
        )

        shrunk_buffer: torch.Tensor | None = self.punica_wrapper.add_shrink(
            buffer, x, self.lora_a_stacked, 1.0
        )
        if not current_platform.can_update_inplace():
            buffer = shrunk_buffer
        if self.tp_size > 1:
            buffer = tensor_model_parallel_all_reduce(buffer)

        # Following S-LoRA, this allows fusing the all_gather and all_reduce
        # by adding the column-partitioned LoRA output to a slice of the
        # output tensor, which is a partial sum due to row parallelism.
        # All that remains is a standard all_reduce. Note that the output is
        # not the same as that of a normal row-parallel layer; it still has
        # to be reduced before being used.
        # NOTE: the offset is based on the rank.
        shard_size = self.lora_b_stacked[0].shape[2]
        offset_start = self.tp_rank * shard_size
        lora_output: torch.Tensor | None = self.punica_wrapper.add_expand(
            output,
            buffer,
            self.lora_b_stacked,
            self.output_slices,
            offset_start=offset_start,
            add_input=True,
        )

        if not current_platform.can_update_inplace():
            output = lora_output

        output = output.view(*out_orig_shape)
        return output

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
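A standalone numeric check of the S-LoRA trick used above (plain torch, two simulated ranks, collectives replaced by explicit sums; the names are illustrative and not vLLM APIs): each rank adds its column slice of the LoRA output to its row-parallel partial sum at a rank-dependent offset, and a single final all_reduce recovers base output plus LoRA output.

import torch

torch.manual_seed(0)
tp_size, in_features, out_features, rank = 2, 8, 6, 4
x = torch.randn(3, in_features)
w = torch.randn(out_features, in_features)   # base weight
lora_a = torch.randn(rank, in_features)      # LoRA A (sliced along the input dim)
lora_b = torch.randn(out_features, rank)     # LoRA B (sliced along the output dim)

reference = x @ w.t() + (x @ lora_a.t()) @ lora_b.t()

in_shard = in_features // tp_size
out_shard = out_features // tp_size
# "all_reduce" of the per-rank shrink results: sum_r x_r @ A_r^T == x @ A^T
buffer = sum(
    x[:, r * in_shard:(r + 1) * in_shard]
    @ lora_a[:, r * in_shard:(r + 1) * in_shard].t()
    for r in range(tp_size)
)

partials = []
for r in range(tp_size):
    # row-parallel partial sum of the base layer on rank r
    partial = (
        x[:, r * in_shard:(r + 1) * in_shard]
        @ w[:, r * in_shard:(r + 1) * in_shard].t()
    )
    # add the column-partitioned LoRA output into this rank's output slice
    partial[:, r * out_shard:(r + 1) * out_shard] += (
        buffer @ lora_b[r * out_shard:(r + 1) * out_shard, :].t()
    )
    partials.append(partial)

# the final all_reduce (sum over ranks) gives base output + LoRA output
result = sum(partials)
assert torch.allclose(result, reference, atol=1e-5)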
74
vllm/lora/layers/utils.py
Normal file
@@ -0,0 +1,74 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass

import torch
import torch.nn as nn


@dataclass
class LoRAMapping:
    index_mapping: tuple[int, ...]
    prompt_mapping: tuple[int, ...]
    is_prefill: bool = False

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)


def _get_lora_device(base_layer: nn.Module) -> torch.device:
    # code borrowed from https://github.com/fmmoret/vllm/blob/fm-support-lora-on-quantized-models/vllm/lora/layers.py#L34
    """Returns the device for where to place the LoRA tensors."""
    # unquantizedLinear
    if hasattr(base_layer, "weight"):
        return base_layer.weight.device
    # Compressed Tensor
    elif hasattr(base_layer, "weight_packed"):
        return base_layer.weight_packed.device
    # GPTQ/AWQ
    elif hasattr(base_layer, "qweight"):
        return base_layer.qweight.device
    # HQQ marlin
    elif hasattr(base_layer, "W_q"):
        return base_layer.W_q.device
    # MoE layer
    elif hasattr(base_layer, "w2_weight"):
        return base_layer.w2_weight.device
    # MoE Compressed Tensor
    elif hasattr(base_layer, "w2_weight_packed"):
        return base_layer.w2_weight_packed.device
    # MoE GPTQ/AWQ/GGUF
    elif hasattr(base_layer, "w2_qweight"):
        return base_layer.w2_qweight.device
    else:
        raise ValueError(f"Unsupported base layer: {base_layer}")


def _not_fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of not using fully sharded loras
    intended to wrap can_replace_layer()
    """

    def dec(*args, **kwargs):
        decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
        condition = not kwargs["lora_config"].fully_sharded_loras if decorate else True
        return can_replace(*args, **kwargs) and condition

    return dec


def _fully_sharded_can_replace(can_replace):
    """
    decorator which adds the condition of fully sharded loras
    intended to wrap can_replace_layer()
    """

    def dec(*args, **kwargs):
        return (
            can_replace(*args, **kwargs) and kwargs["lora_config"].fully_sharded_loras
        )

    return dec
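A small sketch of how the two decorators above gate can_replace_layer on lora_config.fully_sharded_loras, and how passing decorate=False (as RowParallelLinearWithShardedLoRA does) bypasses the extra condition. The decorator body is copied from utils above so the snippet runs on its own; DummyLoRAConfig is a stand-in, not vLLM's LoRAConfig.

from dataclasses import dataclass


@dataclass
class DummyLoRAConfig:
    fully_sharded_loras: bool = False


def _not_fully_sharded_can_replace(can_replace):
    def dec(*args, **kwargs):
        decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
        condition = not kwargs["lora_config"].fully_sharded_loras if decorate else True
        return can_replace(*args, **kwargs) and condition

    return dec


@_not_fully_sharded_can_replace
def can_replace_layer(*, lora_config, **kwargs):
    return True  # stands in for the per-layer type check


# The decorator vetoes the replacement when fully sharded LoRA is enabled.
assert can_replace_layer(lora_config=DummyLoRAConfig(fully_sharded_loras=False)) is True
assert can_replace_layer(lora_config=DummyLoRAConfig(fully_sharded_loras=True)) is False
# decorate=False skips the extra condition, as the fully sharded subclasses do.
assert can_replace_layer(
    lora_config=DummyLoRAConfig(fully_sharded_loras=True), decorate=False
) is True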
140
vllm/lora/layers/vocal_parallel_embedding.py
Normal file
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform

from .base import BaseLayerWithLoRA


class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
    def __init__(self, base_layer: VocabParallelEmbedding) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.embeddings_slice: tuple[int, int] | None
        self.embeddings_weights: torch.Tensor | None

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        if self.base_layer.num_added_embeddings_per_partition > 0:
            # We can start adding lora weights
            self.embeddings_weights = self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition : self.base_layer.num_org_embeddings_per_partition  # noqa: E501
                + self.base_layer.num_added_embeddings_per_partition
            ]
            self.embeddings_slice = (
                self.base_layer.shard_indices.added_vocab_start_index
                - self.base_layer.org_vocab_size,
                self.base_layer.shard_indices.added_vocab_end_index
                - self.base_layer.org_vocab_size,
            )
            self.base_layer.weight.data[
                self.base_layer.num_org_embeddings_per_partition :
            ].fill_(0)
        else:
            self.embeddings_slice = None
            self.embeddings_weights = None

        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                self.base_layer.org_vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.embedding_dim,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.base_layer.weight.device,
        )
        self.lora_a_stacked_2d = self.lora_a_stacked.view(
            self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
            self.lora_a_stacked.shape[2],
        )

    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        assert isinstance(lora_a, torch.Tensor)
        assert isinstance(lora_b, torch.Tensor)
        self.reset_lora(index)
        # NOTE self.lora_a_stacked is row-major, and lora_a is col-major,
        # so we need transpose here

        self.lora_a_stacked[index, : lora_a.shape[1], : lora_a.shape[0]].copy_(
            lora_a.T, non_blocking=True
        )
        self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
            lora_b, non_blocking=True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NB: Don't use torch.narrow here. torch.narrow triggers some
        # Dynamic Shape specialization in torch.compile
        num_tokens = x.shape[0]
        indices_1 = self.punica_wrapper._embeddings_indices[1][:num_tokens]

        full_lora_a_embeddings = F.embedding(
            x + indices_1,
            self.lora_a_stacked_2d,
        )
        full_output = self.base_layer.forward(x)

        full_output_org = full_output
        if full_output.ndim == 3:
            full_output = full_output.view(
                full_output.shape[0] * full_output.shape[1], -1
            )
        if full_lora_a_embeddings.ndim == 3:
            full_lora_a_embeddings = full_lora_a_embeddings.view(
                full_lora_a_embeddings.shape[0] * full_lora_a_embeddings.shape[1],
                -1,
            )

        lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_embedding(
            full_output, full_lora_a_embeddings, self.lora_b_stacked, add_input=True
        )

        if not current_platform.can_update_inplace():
            full_output = lora_output

        return full_output.view_as(full_output_org)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        return type(source_layer) is VocabParallelEmbedding

    @property
    def weight(self):
        return self.base_layer.weight
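A standalone sketch of the F.embedding trick in forward above: all adapters' LoRA A matrices are stacked into one 2D table, and a per-token offset of adapter_index * org_vocab_size selects the right adapter's rows (toy sizes, plain torch; in the real layer the offsets come from the punica wrapper's _embeddings_indices, so the exact offset construction here is an assumption for illustration):

import torch
import torch.nn.functional as F

max_loras, org_vocab_size, lora_rank = 2, 10, 4
lora_a_stacked = torch.randn(max_loras, org_vocab_size, lora_rank)
# Same flattening as lora_a_stacked_2d in create_lora_weights above.
lora_a_stacked_2d = lora_a_stacked.view(max_loras * org_vocab_size, lora_rank)

x = torch.tensor([3, 7, 1])               # token ids
adapter_index = torch.tensor([0, 1, 1])   # which LoRA serves each token
offsets = adapter_index * org_vocab_size  # role played by the punica offsets

out = F.embedding(x + offsets, lora_a_stacked_2d)
expected = torch.stack([lora_a_stacked[a, t] for a, t in zip(adapter_index, x)])
assert torch.equal(out, expected)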