# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import math import torch import torch.nn as nn from transformers import PretrainedConfig from vllm.config.lora import LoRAConfig from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from vllm.platforms import current_platform from .base import BaseLayerWithLoRA class LogitsProcessorWithLoRA(BaseLayerWithLoRA): """ LoRA wrapper for LogitsProcessor, with extra logic to handle the application of the LoRA adapter and added LoRA vocabulary. Args: base_layer: LogitsProcessor layer hidden_size: hidden size of the model dtype: data type of the model device: device of the model sharded_to_full_mapping: index mapping from sharded vocab to full vocab received from base_layer.get_sharded_to_full_mapping(). If None, no reindexing will be done. """ def __init__( self, base_layer: LogitsProcessor, hidden_size: int, dtype: torch.dtype, device: torch.device, sharded_to_full_mapping: list[int] | None, ) -> None: super().__init__() self.base_layer = base_layer self.hidden_size = hidden_size self.dtype = dtype self.device = device self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() self.sharded_to_full_mapping = sharded_to_full_mapping @property def logits_as_input(self): return self.base_layer.logits_as_input @property def vocab_size(self): return self.base_layer.vocab_size @property def scale(self): return self.base_layer.scale @property def soft_cap(self): return self.base_layer.soft_cap @property def use_all_gather(self): return self.base_layer.use_all_gather @property def org_vocab_size(self): return self.base_layer.org_vocab_size @property def include_gpu_probs_tensor(self): return self.base_layer.include_gpu_probs_tensor @property def should_modify_greedy_probs_inplace(self): return self.base_layer.should_modify_greedy_probs_inplace def create_lora_weights( self, max_loras: int, lora_config: LoRAConfig, model_config: PretrainedConfig | None = None, ) -> None: # TODO: Verify if this condition can be further relaxed if 32000 < self.base_layer.vocab_size > 257024: raise ValueError( "When using LoRA, vocab size must be 32000 >= vocab_size <= 257024" ) self.lora_a_stacked = torch.zeros( ( max_loras, 1, lora_config.max_lora_rank, self.hidden_size, ), dtype=lora_config.lora_dtype, device=self.device, ) self.lora_b_stacked = torch.zeros( ( max_loras, 1, # Pad for kernel compatibility math.ceil( self.base_layer.vocab_size / lora_config.lora_vocab_padding_size ) * lora_config.lora_vocab_padding_size, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, device=self.device, ) self.embeddings_tensors = torch.full( (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), fill_value=float("-inf"), dtype=self.dtype, device=self.device, ) if self.sharded_to_full_mapping is not None: self.sharded_to_full_mapping_gpu = torch.tensor( self.sharded_to_full_mapping, device=self.device, dtype=torch.long ) else: self.sharded_to_full_mapping_gpu = None def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 self.embeddings_tensors[index] = float("-inf") def set_lora( self, index: int, lora_a: torch.Tensor, lora_b: torch.Tensor, embeddings_tensor: torch.Tensor | None, ): self.reset_lora(index) self.lora_a_stacked[index, 0, : lora_a.shape[0], : lora_a.shape[1]].copy_( lora_a, non_blocking=True ) self.lora_b_stacked[index, 0, : lora_b.shape[0], : lora_b.shape[1]].copy_( lora_b, non_blocking=True ) if embeddings_tensor is not None: self.embeddings_tensors[ index, : embeddings_tensor.shape[0], : embeddings_tensor.shape[1], ] = embeddings_tensor def _get_logits( self, hidden_states: torch.Tensor, lm_head: VocabParallelEmbedding, embedding_bias: torch.Tensor | None = None, ) -> torch.Tensor | None: # Get the logits for the next tokens. logits = lm_head.quant_method.apply(lm_head, hidden_states) if embedding_bias is not None: logits += embedding_bias # Gather logits for TP logits = self.base_layer._gather_logits(logits) if logits is None: return None if self.sharded_to_full_mapping_gpu is not None: # Reindex full logits tensor to ensure 1:1 mapping between # index and token_id # Example for: # org_vocab_size = 4 # added_vocab_size = 2 # pad_to_size = 8 # tp_size = 2 # indices: [0, 1, 2, 3, 4, 5, 6, 7] # token_id: [0, 1, 4, -1, 2, 3, 5, -1] # Therefore, the mapping is expected to be: # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, # we get: # indices: [0, 1, 2, 3, 4, 5, 6, 7] # token_id: [0, 1, 2, 3, 4, 5, -1, -1] logits = logits[:, self.sharded_to_full_mapping_gpu] lora_logits = torch.empty( self.embeddings_tensors.shape[0] + 1, self.embeddings_tensors.shape[1], hidden_states.shape[0], dtype=self.embeddings_tensors.dtype, device=self.embeddings_tensors.device, ) torch.matmul(self.embeddings_tensors, hidden_states.T, out=lora_logits[:-1]) neg_inf, pos_inf = current_platform.get_infinity_values(lora_logits.dtype) lora_logits[-1] = neg_inf lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded if current_platform.is_tpu() or current_platform.is_xpu(): indices_padded = indices_padded[: logits.size(0)] lora_logits = ( lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[2], ) .index_select(0, indices_padded) .nan_to_num_(nan=neg_inf, posinf=pos_inf, neginf=neg_inf) ) logits[ :, self.base_layer.org_vocab_size : self.base_layer.org_vocab_size + lora_logits.shape[1], ] = lora_logits lora_output: torch.Tensor | None = self.punica_wrapper.add_lora_logits( logits, hidden_states, self.lora_a_stacked, self.lora_b_stacked, 1.0 ) if not current_platform.can_update_inplace(): logits = lora_output # Remove paddings in vocab (if any). logits = logits[:, : self.base_layer.vocab_size] return logits def forward(self, *args, **kwargs): return type(self.base_layer).forward(self, *args, **kwargs) @classmethod def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, model_config: PretrainedConfig | None, ) -> bool: # Special handling for the LogitsProcessor. return False