# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import abc import functools from abc import abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar import numpy as np import torch from vllm.utils import cdiv if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch import vllm.envs as envs from vllm.distributed.kv_transfer.kv_connector.utils import ( get_kv_connector_cache_layout) from vllm.logger import init_logger logger = init_logger(__name__) @dataclass class CommonAttentionMetadata: """ Per-batch attention metadata, shared across layers and backends. AttentionMetadataBuilder instances use it to construct per-layer metadata. """ query_start_loc: torch.Tensor """(batch_size + 1,), the start location of each request in query Tensor""" seq_lens: torch.Tensor """(batch_size,), the length of each request including both computed tokens and newly scheduled tokens""" num_reqs: int """Number of requests""" num_actual_tokens: int """Total number of tokens in batch""" max_query_len: int """Longest query in batch""" num_speculative_tokens: int = 0 """Number of speculative tokens""" slot_mapping: torch.Tensor = None """(batch_size, seq_len), slot mapping""" spec_layer_decoding: bool = False M = TypeVar("M") class AttentionMetadataBuilder(abc.ABC, Generic[M]): # Does this backend/builder support CUDA Graphs for attention. full_cudagraph_supported: ClassVar[bool] = False @abstractmethod def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata) -> M: """ Central method that builds attention metadata. Some builders (MLA) require reorder_batch to be called prior to build. """ raise NotImplementedError def can_run_in_cudagraph( self, common_attn_metadata: CommonAttentionMetadata) -> bool: """ Can this batch (with given metadata) use CUDA Graphs for attention. """ return False def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata) -> M: """ Build attention metadata for CUDA graph capture. Uses build by default. Subclasses that override this method should call self.build or super().build_for_cudagraph_capture. """ return self.build(common_prefix_len=0, common_attn_metadata=common_attn_metadata) def use_cascade_attention( self, common_prefix_len: int, query_lens: np.ndarray, num_query_heads: int, num_kv_heads: int, use_alibi: bool, use_sliding_window: bool, num_sms: int, ) -> bool: return False def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: """ This method can reorder the batch if desired by the backend. :return: Has the batch been reordered (default False). """ return False def validate_kv_sharing_target(current_layer_name, target_layer_name, static_forward_context): error_msg = (f"Specified KV sharing target layer for {current_layer_name} " f"is not valid: target layer {target_layer_name} ") if current_layer_name == target_layer_name: raise ValueError(error_msg + "cannot be the same as the current layer.") if target_layer_name not in static_forward_context: from vllm.model_executor.models.utils import extract_layer_index # If target layer name is not in the static fwd context, it means either # a) the target layer does not come BEFORE the current layer, or # b) the target layer is not an Attention layer that exists in the model current_layer_idx = extract_layer_index(current_layer_name) target_layer_idx = extract_layer_index(target_layer_name) if current_layer_idx <= target_layer_idx: raise ValueError(error_msg + "must come before the current layer.") else: raise ValueError(error_msg + "is not a valid Attention layer in the model.") # Currently KV sharing is only supported between layers of the same type target_layer_attn_type = static_forward_context[ target_layer_name].attn_type expected = static_forward_context[current_layer_name].attn_type if target_layer_attn_type != expected: raise ValueError( error_msg + f"must be the same type as the current layer ({expected}).") @functools.lru_cache def get_kv_cache_layout(): # Override with format specified by the user. cache_layout = envs.VLLM_KV_CACHE_LAYOUT if cache_layout is None: cache_layout = get_kv_connector_cache_layout() else: logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \ "detected. Setting KV cache layout to %s.", cache_layout) return cache_layout # # Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into # local attention blocks, where each block is passed to the attention kernel # as an independent local ("virtual") batch item. # # For example, if are performing a chunked prefill a batch of 3 sequences: # q_seqlens = [4, 10, 5] # kv_seqlens = [6, 17, 9] # Then normally for regular attention we would compute with an attention mask # for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like: # batch idx: 0 (q_seqlens = 4, kv_seqlens = 6) # k_toks > 0 1 2 3 4 5 # q_toks v _____________ # 0 | 1 1 1 # 1 | 1 1 1 1 # 2 | 1 1 1 1 1 # 3 | 1 1 1 1 1 1 # # for local attention (with attn_chunk_size = 4) we would compute with an # attention mask like: # batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4) # k_toks > 0 1 2 3 4 5 # q_toks v _____________ # 0 | 1 1 1 # 1 | 1 1 1 1 # 2 | 1 # 3 | 1 1 # # We can simulate this mask using standard flash-attention by breaking the # sequences into local ("virtual") batches, where each local batch item is a # local attention block, so in this case batch idx 0 would be broken up into: # # local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0) # k_toks > 0 1 2 3 # q_toks v _____________ # 0 | 1 1 1 # 1 | 1 1 1 1 # local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0) # k_toks > 4 5 # q_toks v _____________ # 2 | 1 # 3 | 1 1 # # e.g. if we have: # attn_chunk_size = 4 # query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5]) # Then this function would return: # __b0__ ______b1______ __b2__ < orig batch indices # q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1] # cu_seqlens_q_local = [0, 4, 6, 10, 14, 18, 19, 23, 24] # seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1] # block_table_local : shape[local_virtual_batches, pages_per_local_batch] def make_local_attention_virtual_batches( attn_chunk_size: int, query_start_loc_np: np.ndarray, seq_lens_np: np.ndarray, block_table: torch.Tensor, block_size: int = 0, ) -> tuple[np.ndarray, np.ndarray, np.ndarray, torch.Tensor]: q_seqlens = query_start_loc_np[1:] - query_start_loc_np[:-1] actual_batch_size = seq_lens_np.shape[0] # Handle if we are starting in the middle of a local attention block, # we assume q_seqlens > 0 (for all elements), for each batch idx we compute # the number of tokens that are not in the first local attention block and # then we can simply use a cdiv for the rest. # For example if we have: # attn_chunk_size = 4 # q_seqlens = [4, 10, 5] # k_seqlens = [6, 17, 9] # Then we would get: # new_tokens_in_first_block = [2, 1, 4] # local_blocks = [2, 4, 2] q_tokens_in_first_block = np.minimum( attn_chunk_size - ((seq_lens_np - q_seqlens) % attn_chunk_size), q_seqlens).astype(np.int32) tokens_in_last_block = attn_chunk_size + (seq_lens_np % -attn_chunk_size) local_blocks = 1 + cdiv(q_seqlens - q_tokens_in_first_block, attn_chunk_size) # Once we know the number of local blocks we can compute the request spans # for each batch idx, we can figure out the number of "virtual" requests we # have to make, # For the above example we would get: # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1] # # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]) # (TODO: max a utility to share this code with _prepare_inputs) # arange step 1. [2, 4, 2] -> [2, 6, 8] cu_num_blocks = np.cumsum(local_blocks) virtual_batches = cu_num_blocks[-1] # arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6] block_offsets = np.repeat(cu_num_blocks - local_blocks, local_blocks) # arange step 3. [0, 1, 0, 1, 2, 3, 0, 1] arange = np.arange(virtual_batches, dtype=np.int32) - block_offsets # also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0]) rarange = np.repeat(local_blocks, local_blocks) - arange - 1 # Then we can compute the seqlens_q_local, handling the fact that the # first and last blocks could be partial seqlens_q_local = \ np.repeat(q_seqlens - q_tokens_in_first_block, local_blocks) # set the first block since this may be a partial block seqlens_q_local[arange == 0] = q_tokens_in_first_block # set the remaining blocks seqlens_q_local[arange > 0] = np.minimum( seqlens_q_local - attn_chunk_size * (arange - 1), attn_chunk_size)[arange > 0] # convert from q_seqlens to cu_seqlens_q cu_seqlens_q_local = np.pad(np.cumsum(seqlens_q_local), (1, 0))\ .astype(np.int32) # compute the seqlens_k_local, # basically a full local attention block for all but the last block in each # batch # For our example this will be: # seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1] seqlens_k_local = np.full(cu_num_blocks[-1], attn_chunk_size, dtype=np.int32) seqlens_k_local[cu_num_blocks - 1] = tokens_in_last_block k_seqstarts_absolute = np.repeat(seq_lens_np, local_blocks) - \ (rarange * attn_chunk_size + \ np.repeat(tokens_in_last_block, local_blocks)) # For the example the local attention blocks start at: # _b0_ _____b1_____ _b2_ # k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8] block_starts = k_seqstarts_absolute // block_size assert attn_chunk_size % block_size == 0, \ f"attn_chunk_size {attn_chunk_size} is not " \ f"divisible by block_size {block_size}" pages_per_local_batch = attn_chunk_size // block_size # Create a block_table for the local attention blocks # For out example if we have a block-table like (assuming block_size=2): # block_table = [ # [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0 # [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1 # [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2 # ] # Then for the local batches we would want a block-table like # block_table_local = [ # [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0]) # [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4]) # [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4]) # [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8]) # [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12]) # [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16]) # [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4]) # [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8]) # ] block_indices= np.broadcast_to( np.arange(pages_per_local_batch, dtype=np.int32), (virtual_batches, pages_per_local_batch)) \ + np.expand_dims(block_starts, axis=1) block_indices = block_indices.flatten().clip(max=block_table.shape[1] - 1) batch_indices = np.repeat(np.arange(actual_batch_size, dtype=np.int32), local_blocks * pages_per_local_batch) block_table_local = block_table[batch_indices, block_indices]\ .view(virtual_batches, -1) return seqlens_q_local, cu_seqlens_q_local, seqlens_k_local, \ block_table_local