# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import TYPE_CHECKING, Optional, Union import torch if TYPE_CHECKING: # avoid circuit import from vllm.lora.layers import LoRAMapping from vllm.lora.models import LongContextLoRAContext def compute_meta( token_lora_tensor: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, int, bool]: """ Get the information required for the sgmv kernel. With the features: 1. If consecutive requests in the batch use the same LoRA, this function will combine them into a single request, improving sgmv kernel inference performance. 2. At the beginning of each prefill stage inference, recalculations are needed based on the input, but only once. """ lora_indices_tensor, seq_length_tensor = torch.unique_consecutive( token_lora_tensor, return_counts=True) cum_result = torch.cumsum(seq_length_tensor, dim=0) b_seq_start_tensor = torch.zeros_like(seq_length_tensor) b_seq_start_tensor[1:].copy_(cum_result[:-1]) max_length = seq_length_tensor.max().item() token_nums = seq_length_tensor.sum().item() batch_size = lora_indices_tensor.size(0) no_lora = False # -1 means no lora should be applied. Use `no_lora` to determine whether # the current step requires LoRA. If LoRA is not needed, the prefill stage # does not need to launch the triton kernel, which can improve performance if batch_size == 1 and lora_indices_tensor == -1: no_lora = True return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor, batch_size, max_length, token_nums, no_lora) # TODO see if this can be vectorized def convert_mapping( mapping: "LoRAMapping", lora_index_to_id: list[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int, device: torch.device, long_lora_context: Optional["LongContextLoRAContext"] = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor], list[int]]: """Converts LoRAMapping to index tensors. Args: mapping: LoRAMapping mapping rows in a batch to LoRA ids. lora_index_to_id: List mapping LoRA ids to LoRA indices. max_loras: Maximum number of LoRAs. vocab_size: Model vocab size. extra_vocab_size: Extra vocab size each LoRA can have. long_lora_context: Passed if there are long context lora in a batch. Returns: A tuple of tensors: base_indices: Tensor of shape [batch_size] mapping batch rows to LoRA indices. sampler_indices: Tensor of shape [batch_size] mapping requests to LoRA indices for sampler. For generation, this will be the same as base_indicies. For prefill, this will map requests to LoRA indices. sampler_indices_padded: Tensor of shape [batch_size] mapping requests to LoRA indices for sampler with padding. Same as sampler_indicies, but -1 is replaced with max_loras. embeddings_indices: Tensor of shape [2, batch_size] mapping requests to embedding indices. First row is for embeddings added by the LoRAs, second row is for the LoRA.lora_a embeddings. long_lora_indices: Tensor of shape [batch_size] mapping requests to RoPE offsets and rot dims for long LoRAs. None if long context lora doesn't exist. indices_len: List of lengths of the above tensors. It contains (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, long_lora_indices). """ index_mapping_indices: list[int] = list(mapping.index_mapping).copy() embedding_indices = index_mapping_indices.copy() lora_indices = index_mapping_indices.copy() long_lora_offsets: Optional[torch.Tensor] = None if long_lora_context: long_lora_offsets = torch.zeros(len(index_mapping_indices), device=device, dtype=torch.long) prompt_mapping: list[int] = [ lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping ] lora_idx = None for i in range(len(index_mapping_indices)): # TODO index can be slow. optimize lora_idx = (lora_index_to_id.index(index_mapping_indices[i]) if index_mapping_indices[i] > 0 else -1) embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0 lora_indices[i] = lora_idx if long_lora_context: assert long_lora_offsets is not None lora_offset: int = long_lora_context.offsets_by_lora_id.get( index_mapping_indices[i], 0) long_lora_offsets[i] = lora_offset indices_list: list[Union[list[int], torch.Tensor]] = [ index_mapping_indices, lora_indices, embedding_indices, ] if long_lora_context: assert long_lora_offsets is not None indices_list.append(long_lora_offsets) indices = torch.tensor(indices_list, dtype=torch.long, device=device) prompt_mapping_tensor = torch.tensor(prompt_mapping, dtype=torch.long, device=device) embeddings_indices = torch.stack([ indices[2] * extra_vocab_size, indices[2] * (vocab_size + extra_vocab_size), ]) embeddings_indices = torch.where(embeddings_indices == -1, max_loras - 1, embeddings_indices) base_indices = indices[1] sampler_indices = prompt_mapping_tensor sampler_indices_padded = sampler_indices.clone() sampler_indices_padded = torch.where(sampler_indices_padded == -1, max_loras - 1, sampler_indices_padded) sampler_indices_padded = torch.arange( 0, len(sampler_indices_padded), device=device, dtype=torch.long) + ( sampler_indices_padded * len(sampler_indices_padded)) long_lora_indices = None long_lora_indices_len: Optional[int] = None if long_lora_context: long_lora_indices = indices[3] long_lora_indices_len = long_lora_indices.shape[-1] # Contain length of indices tensors. Used to index into each tensor. indices_len = [ base_indices.shape[-1], sampler_indices.shape[-1], sampler_indices_padded.shape[-1], embeddings_indices.shape[-1], ] if long_lora_indices_len is not None: indices_len.append(long_lora_indices_len) else: # If long_lora doesn't exist,append None indices_len.append(None) return ( base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, long_lora_indices, indices_len, )