# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform

from .base import BaseLayerWithLoRA


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(
        self,
        base_layer: LogitsProcessor,
        hidden_size: int,
        dtype: torch.dtype,
        device: torch.device,
        sharded_to_full_mapping: list[int] | None,
    ) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping
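    # Illustrative construction sketch (not part of this module): `logits_processor`,
    # `lm_head`, and `hidden_size` below are placeholder names for objects the caller
    # already has; the mapping argument comes from the parallel lm_head, as described
    # in the class docstring.
    #
    #     lora_logits_processor = LogitsProcessorWithLoRA(
    #         logits_processor,
    #         hidden_size,
    #         dtype=torch.float16,
    #         device=torch.device("cuda"),
    #         sharded_to_full_mapping=lm_head.get_sharded_to_full_mapping(),
    #     )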
    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_all_gather(self):
        return self.base_layer.use_all_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace
    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: PretrainedConfig | None = None,
    ) -> None:
        # TODO: Verify if this condition can be further relaxed
        if not (32000 <= self.base_layer.vocab_size <= 257024):
            raise ValueError(
                "When using LoRA, vocab size must be 32000 <= vocab_size <= 257024"
            )
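        # Shape sketch for the stacked buffers allocated below (the concrete
        # numbers are illustrative only, e.g. max_loras=2, max_lora_rank=8,
        # hidden_size=4096, vocab_size=32000):
        #   lora_a_stacked: (2, 1, 8, 4096), one (rank, hidden_size) A slot per LoRA
        #   lora_b_stacked: (2, 1, 32000, 8), one (vocab_size, rank) B slot per LoRA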
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                self.base_layer.vocab_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )

        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping, device=self.device, dtype=torch.long
            )
        else:
            self.sharded_to_full_mapping_gpu = None
    def reset_lora(self, index: int):
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor | list[torch.Tensor],
        lora_b: torch.Tensor | list[torch.Tensor],
    ):
        assert isinstance(lora_a, torch.Tensor)
        assert isinstance(lora_b, torch.Tensor)
        self.reset_lora(index)
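        # As implied by the slot shapes allocated in create_lora_weights, a rank-r
        # adapter is expected to provide lora_a as (r, hidden_size) and lora_b with
        # up to vocab_size rows and r columns; both are written into the top-left
        # corner of their slots and the remainder stays zero.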
        self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_(
            lora_a, non_blocking=True
        )
        self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_(
            lora_b, non_blocking=True
        )
    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: torch.Tensor | None = None,
    ) -> torch.Tensor | None:
        # Get the logits for the next tokens.
        logits = lm_head.quant_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias

        # Gather logits for TP
        logits = self.base_layer._gather_logits(logits)

        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex the full logits tensor to ensure a 1:1 mapping between
            # index and token_id.
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be
            # [0, 1, 4, 5, 2, 6, 3, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]
        lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits(
            logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0
        )
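        # Sketch of the update the punica call above applies for a token routed to
        # adapter i (with A_i = lora_a_stacked[i, 0] of shape (rank, hidden_size)
        # and B_i = lora_b_stacked[i, 0] of shape (vocab_size, rank)):
        #   logits[token] += (hidden_states[token] @ A_i.T) @ B_i.T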
        # On platforms that cannot update `logits` in place, the punica wrapper
        # returns a new tensor instead, so use that result.
        if not current_platform.can_update_inplace():
            logits = lora_output

        # Remove paddings in vocab (if any).
        logits = logits[:, : self.base_layer.vocab_size]
        return logits
    def forward(self, *args, **kwargs):
        # Call the base LogitsProcessor's forward as an unbound function with
        # `self`, so the base implementation runs but dispatches to this class's
        # _get_logits override.
        return type(self.base_layer).forward(self, *args, **kwargs)
    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: PretrainedConfig | None = None,
    ) -> bool:
        # Special handling for the LogitsProcessor.
        return False