init
This commit is contained in:
979
vllm/lora/layers.py
Normal file
979
vllm/lora/layers.py
Normal file
@@ -0,0 +1,979 @@
|
||||
# pylint: disable=unused-argument
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_gather,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
QKVParallelLinear,
|
||||
MergedColumnParallelLinear)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
def _apply_lora(
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
"""Applies lora to each input.
|
||||
|
||||
This method applies all loras to each input. It uses the
|
||||
indices vector to determine which lora yields the
|
||||
correct output. An index of -1 means no lora should be
|
||||
applied. This method adds the final lora results to the
|
||||
output.
|
||||
|
||||
Input shapes:
|
||||
x: (batch_size, hidden_dim)
|
||||
lora_a_stacked: (num_loras, lora_rank, hidden_dim)
|
||||
lora_b_stacked: (num_loras, output_dim, lora_rank)
|
||||
indices: (batch_size)
|
||||
output: (batch_size, output_dim)
|
||||
"""
|
||||
org_output = output
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = output.view(-1, output.shape[-1])
|
||||
indices = indices.view(-1)
|
||||
add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
|
||||
return output.view_as(org_output)
|
||||
|
||||
|
||||
def _apply_lora_packed_nslice(
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
indices: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
output_slices: Tuple[int, ...],
|
||||
):
|
||||
"""Applies lora to each input.
|
||||
|
||||
This method applies all loras to each input. It uses the
|
||||
indices vector to determine which lora yields the
|
||||
correct output. An index of -1 means no lora should be
|
||||
applied. This method adds the final lora results to the
|
||||
output.
|
||||
|
||||
This method is used for layers that are composed of multiple sublayers
|
||||
(slices) packed together.
|
||||
|
||||
Input shapes:
|
||||
x: (batch_size, hidden_dim)
|
||||
lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
|
||||
lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
|
||||
indices: (batch_size)
|
||||
output: (batch_size, q_slice_size + 2*kv_slice_size)
|
||||
output_slices: n-1 element tuple of (slice_size...), where n is number of slices
|
||||
"""
|
||||
org_output = output
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = output.view(-1, output.shape[-1])
|
||||
indices = indices.view(-1)
|
||||
offset_left = 0
|
||||
for slice_idx in range(len(output_slices)):
|
||||
add_lora_slice(output, x, lora_a_stacked[slice_idx],
|
||||
lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
|
||||
output_slices[slice_idx])
|
||||
offset_left += output_slices[slice_idx]
|
||||
return output.view_as(org_output)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAMapping:
|
||||
# Per every token in input_ids:
|
||||
index_mapping: Tuple[int, ...]
|
||||
# Per sampled token:
|
||||
prompt_mapping: Tuple[int, ...]
|
||||
|
||||
def __post_init__(self):
|
||||
self.index_mapping = tuple(self.index_mapping)
|
||||
self.prompt_mapping = tuple(self.prompt_mapping)
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
|
||||
|
||||
def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
|
||||
model_config: PretrainedConfig) -> None:
|
||||
"""Initializes lora matrices."""
|
||||
...
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
"""Resets the lora weights at index back to 0."""
|
||||
...
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
):
|
||||
"""Overwrites lora tensors at index."""
|
||||
...
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
"""Sets the mapping indices."""
|
||||
...
|
||||
|
||||
|
||||
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
|
||||
lora_vocab_start_idx = self.base_layer.org_vocab_size
|
||||
weights_idx = None
|
||||
if self.base_layer.vocab_end_index > lora_vocab_start_idx:
|
||||
# We can start adding lora weights
|
||||
weights_idx = max(
|
||||
lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
|
||||
self.embeddings_slice = (self.base_layer.vocab_start_index -
|
||||
self.base_layer.org_vocab_size +
|
||||
weights_idx,
|
||||
self.base_layer.vocab_end_index -
|
||||
self.base_layer.org_vocab_size)
|
||||
self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
|
||||
self.embeddings_weights.fill_(0)
|
||||
else:
|
||||
self.embeddings_slice = None
|
||||
self.embeddings_weights = None
|
||||
|
||||
self.embeddings_tensors = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
self.base_layer.embedding_dim,
|
||||
),
|
||||
dtype=self.base_layer.weight.dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.org_vocab_size +
|
||||
lora_config.lora_extra_vocab_size,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
self.base_layer.embedding_dim,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.lora_a_stacked_2d = self.lora_a_stacked.view(
|
||||
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
|
||||
self.lora_a_stacked.shape[2],
|
||||
)
|
||||
self.indices: Optional[torch.Tensor] = None
|
||||
self.indices_len: Optional[List[int]] = None
|
||||
self.embeddings_indices = None
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
self.lora_b_stacked[index] = 0
|
||||
self.embeddings_tensors[index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
):
|
||||
self.reset_lora(index)
|
||||
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
|
||||
lora_a, non_blocking=True)
|
||||
self.lora_b_stacked[index,
|
||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||
lora_b.T, non_blocking=True)
|
||||
if embeddings_tensor is not None:
|
||||
self.embeddings_tensors[
|
||||
index, :embeddings_tensor.shape[0], :embeddings_tensor.
|
||||
shape[1]].copy_(embeddings_tensor, non_blocking=True)
|
||||
if self.embeddings_slice is not None:
|
||||
# TODO(yard1): Optimize this copy, we don't need to copy
|
||||
# everything, just the modified part
|
||||
embeddings = self.embeddings_tensors.view(
|
||||
self.embeddings_tensors.shape[0] *
|
||||
self.embeddings_tensors.shape[1],
|
||||
self.embeddings_tensors.shape[2]
|
||||
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
|
||||
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
self.indices = base_indices
|
||||
self.embeddings_indices = embeddings_indices
|
||||
self.indices_len = indices_len
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
|
||||
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
|
||||
full_lora_a_embeddings = F.embedding(
|
||||
x + indices,
|
||||
self.lora_a_stacked_2d,
|
||||
)
|
||||
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
|
||||
full_output = self.base_layer.forward(
|
||||
x.add_(indices * added_tokens_mask))
|
||||
|
||||
full_output_org = full_output
|
||||
if full_output.ndim == 3:
|
||||
full_output = full_output.view(
|
||||
full_output.shape[0] * full_output.shape[1], -1)
|
||||
if full_lora_a_embeddings.ndim == 3:
|
||||
full_lora_a_embeddings = full_lora_a_embeddings.view(
|
||||
full_lora_a_embeddings.shape[0] *
|
||||
full_lora_a_embeddings.shape[1], -1)
|
||||
bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]], 0, 1.0)
|
||||
return full_output.view_as(full_output_org)
|
||||
|
||||
|
||||
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
def __init__(self, base_layer: ColumnParallelLinear) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
self.lora_a_stacked = torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.lora_b_stacked = torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
self.base_layer.weight.shape[0],
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
|
||||
self.indices: Optional[torch.Tensor] = None
|
||||
self.indices_len: Optional[List[int]] = None
|
||||
self.output_dim = self.lora_b_stacked.shape[1]
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
self.lora_b_stacked[index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
):
|
||||
self.reset_lora(index)
|
||||
|
||||
self.lora_a_stacked[index,
|
||||
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
||||
lora_a.T, non_blocking=True)
|
||||
self.lora_b_stacked[index,
|
||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||
lora_b.T, non_blocking=True)
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
self.indices = base_indices
|
||||
self.indices_len = indices_len
|
||||
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x, bias)
|
||||
_apply_lora(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]],
|
||||
output,
|
||||
)
|
||||
return output
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of ColumnParallelLinear
|
||||
|
||||
Args:
|
||||
input_: Tensor whose last dimension is `input_size`.
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
bias = (self.base_layer.bias
|
||||
if not self.base_layer.skip_bias_add else None)
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply_weights(input_, bias)
|
||||
if self.base_layer.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = (self.base_layer.bias
|
||||
if self.base_layer.skip_bias_add else None)
|
||||
return output, output_bias
|
||||
|
||||
@property
|
||||
def linear_weights(self):
|
||||
return self.base_layer.linear_weights
|
||||
|
||||
|
||||
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
|
||||
packed together (eg. gate_proj + up_proj -> gate_up_proj).
|
||||
|
||||
This means we have 2 LoRAs, each applied to one half of the layer.
|
||||
|
||||
Both slices must have the same size.
|
||||
"""
|
||||
|
||||
def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
|
||||
super().__init__(base_layer)
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
n_slices = 2
|
||||
if not (len(self.base_layer.output_sizes) == n_slices
|
||||
and self.base_layer.output_sizes[0]
|
||||
== self.base_layer.output_sizes[1]):
|
||||
raise ValueError(
|
||||
"LoRAColumnParallelLinear2Slice requires 2 slices with "
|
||||
"the same size.")
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.lora_a_stacked = tuple(
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
) for _ in range(n_slices))
|
||||
self.lora_b_stacked = tuple(
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
self.base_layer.weight.shape[0] // 2,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
) for _ in range(n_slices))
|
||||
|
||||
self.indices: Optional[torch.Tensor] = None
|
||||
self.output_dim = self.lora_b_stacked[0].shape[2]
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[0][index] = 0
|
||||
self.lora_a_stacked[1][index] = 0
|
||||
self.lora_b_stacked[0][index] = 0
|
||||
self.lora_b_stacked[1][index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
):
|
||||
self.reset_lora(index)
|
||||
|
||||
if self.tp_size > 1:
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.output_dim
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
lora_b = lora_b[0][:,
|
||||
start_idx:end_idx], lora_b[1][:,
|
||||
start_idx:end_idx]
|
||||
|
||||
if lora_a[0] is not None:
|
||||
self.lora_a_stacked[0][
|
||||
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
|
||||
lora_a[0].T, non_blocking=True)
|
||||
self.lora_b_stacked[0][
|
||||
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
|
||||
lora_b[0].T, non_blocking=True)
|
||||
if lora_a[1] is not None:
|
||||
self.lora_a_stacked[1][
|
||||
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
|
||||
lora_a[1].T, non_blocking=True)
|
||||
self.lora_b_stacked[1][
|
||||
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
|
||||
lora_b[1].T, non_blocking=True)
|
||||
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x, bias)
|
||||
_apply_lora_packed_nslice(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]],
|
||||
output,
|
||||
(self.output_dim, self.output_dim),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
||||
"""ColumnParallelLinear layer that is composed of 3 sublayers (slices)
|
||||
packed together in qkv proj fashion
|
||||
(q_proj + k_proj + v_proj -> qkv_proj).
|
||||
|
||||
This means we have 3 LoRAs, each applied to one slice of the layer.
|
||||
|
||||
Q slice may have different shape than K and V slices (which both have
|
||||
the same shape).
|
||||
"""
|
||||
|
||||
def __init__(self, base_layer: QKVParallelLinear) -> None:
|
||||
super().__init__(base_layer)
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
self.q_proj_shard_size = (self.base_layer.num_heads *
|
||||
self.base_layer.head_size)
|
||||
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
|
||||
self.base_layer.head_size)
|
||||
self.q_shard_id = tp_rank
|
||||
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
|
||||
|
||||
# q, k, v
|
||||
self.lora_a_stacked = (
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
),
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
),
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
),
|
||||
)
|
||||
self.lora_b_stacked = (
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
self.q_proj_shard_size,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
),
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
self.kv_proj_shard_size,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
),
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
self.kv_proj_shard_size,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
),
|
||||
)
|
||||
|
||||
self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
|
||||
self.kv_proj_shard_size)
|
||||
self.packed_indices: Optional[torch.Tensor] = None
|
||||
self.standard_indices: Optional[torch.Tensor] = None
|
||||
self.indices_len: Optional[List[int]] = None
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[0][index] = 0
|
||||
self.lora_b_stacked[0][index] = 0
|
||||
self.lora_a_stacked[1][index] = 0
|
||||
self.lora_b_stacked[1][index] = 0
|
||||
self.lora_a_stacked[2][index] = 0
|
||||
self.lora_b_stacked[2][index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
):
|
||||
self.reset_lora(index)
|
||||
|
||||
if self.tp_size > 1:
|
||||
if lora_b[0] is not None:
|
||||
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
|
||||
self.q_shard_id:self.q_proj_shard_size *
|
||||
(self.q_shard_id + 1)]
|
||||
self.lora_b_stacked[0][
|
||||
index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
|
||||
lora_b_q.T, non_blocking=True)
|
||||
if lora_b[1] is not None:
|
||||
lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
|
||||
self.kv_shard_id:self.kv_proj_shard_size *
|
||||
(self.kv_shard_id + 1)]
|
||||
self.lora_b_stacked[1][
|
||||
index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
|
||||
lora_b_k.T, non_blocking=True)
|
||||
if lora_b[2] is not None:
|
||||
lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
|
||||
self.kv_shard_id:self.kv_proj_shard_size *
|
||||
(self.kv_shard_id + 1)]
|
||||
self.lora_b_stacked[2][
|
||||
index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
|
||||
lora_b_v.T, non_blocking=True)
|
||||
else:
|
||||
if lora_b[0] is not None:
|
||||
self.lora_b_stacked[0][
|
||||
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
|
||||
lora_b[0].T, non_blocking=True)
|
||||
if lora_b[1] is not None:
|
||||
self.lora_b_stacked[1][
|
||||
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
|
||||
lora_b[1].T, non_blocking=True)
|
||||
if lora_b[2] is not None:
|
||||
self.lora_b_stacked[2][
|
||||
index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
|
||||
lora_b[2].T, non_blocking=True)
|
||||
|
||||
if lora_a[0] is not None:
|
||||
self.lora_a_stacked[0][
|
||||
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
|
||||
lora_a[0].T, non_blocking=True)
|
||||
if lora_a[1] is not None:
|
||||
self.lora_a_stacked[1][
|
||||
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
|
||||
lora_a[1].T, non_blocking=True)
|
||||
if lora_a[2] is not None:
|
||||
self.lora_a_stacked[2][
|
||||
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
|
||||
lora_a[2].T, non_blocking=True)
|
||||
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x, bias)
|
||||
_apply_lora_packed_nslice(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]],
|
||||
output,
|
||||
self.output_slices,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
def __init__(self, base_layer: RowParallelLinear) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
self.lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
self.base_layer.weight.shape[0],
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.indices: Optional[torch.Tensor] = None
|
||||
self.indices_len: Optional[List[int]] = None
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
self.lora_b_stacked[index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
):
|
||||
self.reset_lora(index)
|
||||
if self.base_layer.tp_size > 1:
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.base_layer.weight.shape[1]
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
lora_a = lora_a[start_idx:end_idx, :]
|
||||
|
||||
self.lora_a_stacked[index,
|
||||
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
||||
lora_a.T, non_blocking=True)
|
||||
self.lora_b_stacked[index,
|
||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||
lora_b.T, non_blocking=True)
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
self.indices = base_indices
|
||||
self.indices_len = indices_len
|
||||
|
||||
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x)
|
||||
_apply_lora(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]],
|
||||
output,
|
||||
)
|
||||
return output
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of RowParallelLinear
|
||||
|
||||
Args:
|
||||
input_: tensor whose last dimension is `input_size`. If
|
||||
`input_is_parallel` is set, then the last dimension
|
||||
is `input_size // tp_size`.
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
# Set up backprop all-reduce.
|
||||
if self.base_layer.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
# TODO: simplify code below
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.base_layer.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply_weights(input_parallel)
|
||||
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output_ = output_parallel
|
||||
|
||||
if not self.base_layer.skip_bias_add:
|
||||
output = (output_ + self.base_layer.bias
|
||||
if self.base_layer.bias is not None else output_)
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.base_layer.bias
|
||||
return output, output_bias
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.base_layer.weight
|
||||
|
||||
|
||||
class SamplerWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: Sampler,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
@property
|
||||
def logits_as_hidden_states(self):
|
||||
return self.base_layer.logits_as_hidden_states
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.base_layer.vocab_size
|
||||
|
||||
@property
|
||||
def org_vocab_size(self):
|
||||
return self.base_layer.org_vocab_size
|
||||
|
||||
@property
|
||||
def include_gpu_probs_tensor(self):
|
||||
return self.base_layer.include_gpu_probs_tensor
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None,
|
||||
) -> None:
|
||||
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
|
||||
if 32000 < self.base_layer.vocab_size > 33024:
|
||||
raise ValueError(
|
||||
"When using LoRA, vocab size must be 32000 >= vocab_size <= 33024"
|
||||
)
|
||||
self.lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.hidden_size,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
# Pad for kernel compatibility
|
||||
math.ceil(self.base_layer.vocab_size /
|
||||
lora_config.lora_vocab_padding_size) *
|
||||
lora_config.lora_vocab_padding_size,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.embeddings_tensors = torch.full(
|
||||
(max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
|
||||
fill_value=float("-inf"),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.indices = None
|
||||
self.indices_padded = None
|
||||
self.indices_len = None
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
self.lora_b_stacked[index] = 0
|
||||
self.embeddings_tensors[index] = float("-inf")
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
):
|
||||
self.reset_lora(index)
|
||||
self.lora_a_stacked[index,
|
||||
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
||||
lora_a.T, non_blocking=True)
|
||||
self.lora_b_stacked[index,
|
||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||
lora_b.T, non_blocking=True)
|
||||
if embeddings_tensor is not None:
|
||||
self.embeddings_tensors[
|
||||
index, :embeddings_tensor.shape[0], :embeddings_tensor.
|
||||
shape[1], ] = embeddings_tensor
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
self.indices = sampler_indices
|
||||
self.indices_padded = sampler_indices_padded
|
||||
self.indices_len = indices_len
|
||||
|
||||
def _get_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
embedding: torch.Tensor,
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# Get the logits for the next tokens.
|
||||
logits = torch.matmul(hidden_states, embedding.t())
|
||||
if embedding_bias is not None:
|
||||
logits += embedding_bias
|
||||
logits = tensor_model_parallel_gather(logits)
|
||||
if logits is None:
|
||||
return None
|
||||
|
||||
lora_logits = torch.empty(
|
||||
self.embeddings_tensors.shape[0] + 1,
|
||||
self.embeddings_tensors.shape[1],
|
||||
hidden_states.shape[0],
|
||||
dtype=self.embeddings_tensors.dtype,
|
||||
device=self.embeddings_tensors.device,
|
||||
)
|
||||
torch.matmul(self.embeddings_tensors,
|
||||
hidden_states.T,
|
||||
out=lora_logits[:-1])
|
||||
lora_logits[-1] = float("-inf")
|
||||
lora_logits = lora_logits.mT
|
||||
lora_logits = (lora_logits.reshape(
|
||||
lora_logits.shape[0] * lora_logits.shape[1],
|
||||
lora_logits.shape[2],
|
||||
).index_select(0,
|
||||
self.indices_padded[:self.indices_len[2]]).nan_to_num_(
|
||||
nan=float("-inf"),
|
||||
posinf=float("inf"),
|
||||
neginf=float("-inf")))
|
||||
logits[:,
|
||||
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
|
||||
lora_logits.shape[1]] = lora_logits
|
||||
|
||||
_apply_lora(
|
||||
hidden_states,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[1]],
|
||||
logits,
|
||||
)
|
||||
|
||||
# Remove paddings in vocab (if any).
|
||||
logits = logits[:, :self.base_layer.vocab_size]
|
||||
|
||||
return logits
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return type(self.base_layer).forward(self, *args, **kwargs)
|
||||
|
||||
|
||||
def from_layer(
|
||||
layer: nn.Module,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
|
||||
supported_layer_types = {
|
||||
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
|
||||
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
|
||||
QKVParallelLinear: QKVParallelLinearWithLora,
|
||||
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
|
||||
RowParallelLinear: RowParallelLinearWithLoRA,
|
||||
}
|
||||
for src_layer_type, lora_layer_type in supported_layer_types.items():
|
||||
if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck
|
||||
ret = lora_layer_type(layer)
|
||||
ret.create_lora_weights(max_loras, lora_config, model_config)
|
||||
return ret
|
||||
return layer
|
||||
|
||||
|
||||
def from_layer_sampler(
|
||||
layer: Sampler,
|
||||
lm_head: ParallelLMHead,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None,
|
||||
) -> SamplerWithLoRA:
|
||||
ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
|
||||
lm_head.weight.device)
|
||||
ret.create_lora_weights(max_loras, lora_config, model_config)
|
||||
return ret
|
||||
Reference in New Issue
Block a user