# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import math
from typing import Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config.lora import LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.platforms import current_platform

from .base import BaseLayerWithLoRA


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
    """
    LoRA wrapper for LogitsProcessor, with extra logic to handle the
    application of the LoRA adapter and added LoRA vocabulary.

    Args:
        base_layer: LogitsProcessor layer
        hidden_size: hidden size of the model
        dtype: data type of the model
        device: device of the model
        sharded_to_full_mapping: index mapping from sharded vocab to full vocab
            received from base_layer.get_sharded_to_full_mapping(). If None,
            no reindexing will be done.
    """

    def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
                 dtype: torch.dtype, device: torch.device,
                 sharded_to_full_mapping: Optional[list[int]]) -> None:
        super().__init__()
        self.base_layer = base_layer
        self.hidden_size = hidden_size
        self.dtype = dtype
        self.device = device
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.sharded_to_full_mapping = sharded_to_full_mapping

    # The read-only properties below forward to the wrapped base layer so
    # that this wrapper exposes the same attributes as the layer it replaces.

    @property
    def logits_as_input(self):
        return self.base_layer.logits_as_input

    @property
    def vocab_size(self):
        return self.base_layer.vocab_size

    @property
    def scale(self):
        return self.base_layer.scale

    @property
    def soft_cap(self):
        return self.base_layer.soft_cap

    @property
    def use_all_gather(self):
        return self.base_layer.use_all_gather

    @property
    def org_vocab_size(self):
        return self.base_layer.org_vocab_size

    @property
    def include_gpu_probs_tensor(self):
        return self.base_layer.include_gpu_probs_tensor

    @property
    def should_modify_greedy_probs_inplace(self):
        return self.base_layer.should_modify_greedy_probs_inplace

    def create_lora_weights(
        self,
        max_loras: int,
        lora_config: LoRAConfig,
        model_config: Optional[PretrainedConfig] = None,
    ) -> None:
        """Allocate the stacked LoRA buffers for up to ``max_loras`` adapters.

        Creates ``lora_a_stacked``, ``lora_b_stacked`` and
        ``embeddings_tensors`` on ``self.device``, and materializes
        ``sharded_to_full_mapping`` as a long tensor (or ``None``).

        Args:
            max_loras: number of adapter slots to allocate.
            lora_config: supplies the max rank, LoRA dtype, vocab padding
                size and extra vocab size used to shape the buffers.
            model_config: unused here.

        Raises:
            ValueError: if the base layer's vocab size exceeds 257024.
        """
        # TODO: Verify if this condition can be further relaxed
        # Only an upper bound is enforced on the vocab size.
        if self.base_layer.vocab_size > 257024:
            raise ValueError("When using LoRA, vocab size must be "
                             "<= 257024")
        self.lora_a_stacked = torch.zeros(
            (
                max_loras,
                1,
                lora_config.max_lora_rank,
                self.hidden_size,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        self.lora_b_stacked = torch.zeros(
            (
                max_loras,
                1,
                # Pad for kernel compatibility
                math.ceil(self.base_layer.vocab_size /
                          lora_config.lora_vocab_padding_size) *
                lora_config.lora_vocab_padding_size,
                lora_config.max_lora_rank,
            ),
            dtype=lora_config.lora_dtype,
            device=self.device,
        )
        # Filled with -inf so that unused extra-vocab rows stay at -inf
        # until an adapter's embeddings are copied in by set_lora().
        self.embeddings_tensors = torch.full(
            (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
            fill_value=float("-inf"),
            dtype=self.dtype,
            device=self.device,
        )
        if self.sharded_to_full_mapping is not None:
            self.sharded_to_full_mapping_gpu = torch.tensor(
                self.sharded_to_full_mapping,
                device=self.device,
                dtype=torch.long)
        else:
            self.sharded_to_full_mapping_gpu = None

    def reset_lora(self, index: int):
        """Clear adapter slot ``index``.

        Zeroes the A/B weights for the slot and refills its extra-vocab
        embeddings with -inf.
        """
        self.lora_a_stacked[index] = 0
        self.lora_b_stacked[index] = 0
        self.embeddings_tensors[index] = float("-inf")

    def set_lora(
        self,
        index: int,
        lora_a: torch.Tensor,
        lora_b: torch.Tensor,
        embeddings_tensor: Optional[torch.Tensor],
        bias: Optional[torch.Tensor] = None,
    ):
        """Load one adapter's weights into slot ``index``.

        The slot is reset first, then ``lora_a`` / ``lora_b`` are copied into
        the top-left corner of their stacked buffers (the remainder stays
        zero). If ``embeddings_tensor`` is given, it is written into the
        slot's extra-vocab embeddings in the same way.

        Args:
            index: adapter slot to fill.
            lora_a: 2-D LoRA A matrix (indexed by ``shape[0]``, ``shape[1]``).
            lora_b: 2-D LoRA B matrix (indexed by ``shape[0]``, ``shape[1]``).
            embeddings_tensor: optional extra-vocab embeddings for the adapter.
            bias: accepted but not used by this layer.
        """
        self.reset_lora(index)
        self.lora_a_stacked[index,
                            0, :lora_a.shape[0], :lora_a.shape[1]].copy_(
                                lora_a, non_blocking=True)
        self.lora_b_stacked[index,
                            0, :lora_b.shape[0], :lora_b.shape[1]].copy_(
                                lora_b, non_blocking=True)
        if embeddings_tensor is not None:
            self.embeddings_tensors[
                index,
                :embeddings_tensor.shape[0],
                :embeddings_tensor.shape[1],
            ] = embeddings_tensor

    def _get_logits(
        self,
        hidden_states: torch.Tensor,
        lm_head: VocabParallelEmbedding,
        embedding_bias: Optional[torch.Tensor] = None,
    ) -> Optional[torch.Tensor]:
        """Compute logits with the LoRA contribution applied.

        Steps: compute the base logits through ``lm_head``, gather them via
        the base layer, optionally reindex them to full-vocab order, write
        the extra-vocab logits into the columns starting at
        ``org_vocab_size``, add the LoRA A/B contribution through the punica
        wrapper, and finally trim vocab padding.

        Relies on ``self.punica_wrapper``, which is not set in this class
        (presumably provided by the base class or the LoRA manager -- verify).

        Returns:
            The logits sliced to ``base_layer.vocab_size`` columns, or
            ``None`` when the base layer's gather returns ``None``.
        """
        # Get the logits for the next tokens.
        logits = lm_head.quant_method.apply(lm_head, hidden_states)
        if embedding_bias is not None:
            logits += embedding_bias

        # Gather logits for TP
        logits = self.base_layer._gather_logits(logits)

        if logits is None:
            return None

        if self.sharded_to_full_mapping_gpu is not None:
            # Reindex full logits tensor to ensure 1:1 mapping between
            # index and token_id
            # Example for:
            #   org_vocab_size = 4
            #   added_vocab_size = 2
            #   pad_to_size = 8
            #   tp_size = 2

            # indices:  [0, 1, 2,  3, 4, 5, 6,  7]
            # token_id: [0, 1, 4, -1, 2, 3, 5, -1]

            # Therefore, the mapping is expected to be:
            # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex,
            # we get:
            # indices:  [0, 1, 2, 3, 4, 5,  6,  7]
            # token_id: [0, 1, 2, 3, 4, 5, -1, -1]
            logits = logits[:, self.sharded_to_full_mapping_gpu]

        # Buffer of shape (max_loras + 1, extra_vocab, hidden_states rows).
        # The first max_loras entries receive the per-adapter extra-vocab
        # logits; the additional last entry is filled with neg_inf below.
        lora_logits = torch.empty(
            self.embeddings_tensors.shape[0] + 1,
            self.embeddings_tensors.shape[1],
            hidden_states.shape[0],
            dtype=self.embeddings_tensors.dtype,
            device=self.embeddings_tensors.device,
        )
        torch.matmul(self.embeddings_tensors,
                     hidden_states.T,
                     out=lora_logits[:-1])

        neg_inf, pos_inf = current_platform.get_infinity_values(
            lora_logits.dtype)

        # NOTE(review): the neg_inf entry is presumably what rows without an
        # active adapter select through sampler_indices_padded -- confirm.
        lora_logits[-1] = neg_inf
        lora_logits = lora_logits.mT
        indices_padded = self.punica_wrapper.sampler_indices_padded

        if current_platform.is_tpu() or current_platform.is_xpu():
            indices_padded = indices_padded[:logits.size(0)]

        # Flatten the first two dims, pick one row per index in
        # indices_padded, and replace NaN/inf values in place with the
        # platform's infinity values.
        lora_logits = (lora_logits.reshape(
            lora_logits.shape[0] * lora_logits.shape[1],
            lora_logits.shape[2],
        ).index_select(0, indices_padded).nan_to_num_(nan=neg_inf,
                                                      posinf=pos_inf,
                                                      neginf=neg_inf))

        # Overwrite the columns right after the original vocabulary with the
        # extra-vocab logits.
        logits[:,
               self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
               lora_logits.shape[1]] = lora_logits

        lora_output: Optional[
            torch.Tensor] = self.punica_wrapper.add_lora_logits(
                logits, hidden_states, self.lora_a_stacked,
                self.lora_b_stacked, 1.0)

        # On platforms that cannot update in place, use the returned tensor
        # instead of relying on `logits` having been modified.
        if not current_platform.can_update_inplace():
            logits = lora_output

        # Remove paddings in vocab (if any).
        logits = logits[:, :self.base_layer.vocab_size]
        return logits

    def forward(self, *args, **kwargs):
        """Run the base layer class's ``forward`` with this wrapper as self.

        Calling the unbound ``forward`` of ``type(self.base_layer)`` on this
        instance means attribute and method lookups inside it resolve against
        the wrapper (presumably so that its call to ``_get_logits`` reaches
        the override defined here -- verify against the base layer).
        """
        return type(self.base_layer).forward(self, *args, **kwargs)

    @classmethod
    def can_replace_layer(
        cls,
        source_layer: nn.Module,
        lora_config: LoRAConfig,
        packed_modules_list: list,
        model_config: Optional[PretrainedConfig],
    ) -> bool:
        """Always return False; this wrapper is never selected through
        ``can_replace_layer``."""
        # Special handling for the LogitsProcessor.
        return False