# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import abc from dataclasses import dataclass, replace from typing import Any, ClassVar, TypeVar import torch from vllm.config import VllmConfig from vllm.utils.math_utils import cdiv from vllm.v1.attention.backend import ( AttentionCGSupport, AttentionMetadataBuilder, CommonAttentionMetadata, ) from vllm.v1.attention.backends.utils import ( PAD_SLOT_ID, compute_causal_conv1d_metadata, mamba_get_block_table_tensor, split_decodes_and_prefills, ) from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec M = TypeVar("M", bound="BaseMambaAttentionMetadata") @dataclass class BaseMambaAttentionMetadata: num_prefills: int num_prefill_tokens: int num_decodes: int num_decode_tokens: int num_reqs: int # The following tensors only contain prefill requests and will be None if # the batch has no prefill requests. has_initial_states_p: torch.Tensor | None query_start_loc_p: torch.Tensor | None num_computed_tokens_p: torch.Tensor | None state_indices_tensor_p: torch.Tensor | None # The following tensors are used for decode requests and # speculative decoding compatibility, and will be None if the batch # has no decode requests. state_indices_tensor_d: torch.Tensor | None query_start_loc_d: torch.Tensor | None # shape: [num_decodes + 1,] # Number of accepted tokens for each spec sequence (for loading correct checkpoint) # Includes the bonus token (so minimum is 1) num_accepted_tokens: torch.Tensor | None # shape: [batch,] # The following tensors are only used for prefix caching in all mode and # are None if disabled block_idx_last_scheduled_token: torch.Tensor | None block_idx_first_scheduled_token_p: torch.Tensor | None block_idx_last_computed_token: torch.Tensor | None # The following tensor is only used for prefix caching in align mode seq_lens: torch.Tensor # cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for # each chunk, its offsets into the varlen sequence dimension. It is defined # such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to # cu_chunk_seqlen_p[i+1]. cu_chunk_seqlen_p: torch.Tensor | None = None # last_chunk_indices_p is a tensor of shape (batch,) that contains the # index of the last chunk for every sequence in the (prefill) batch. last_chunk_indices_p: torch.Tensor | None = None # The following attributes are for triton implementation of causal_conv1d nums_dict: dict | None = None batch_ptr: torch.Tensor | None = None token_chunk_offset_ptr: torch.Tensor | None = None class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC): metadata_cls: type[M] reorder_batch_threshold: int = 1 _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH # Will be disabled if speculative decoding is used supports_update_block_table: bool = True def __init__( self, kv_cache_spec: AttentionSpec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device, ): super().__init__(kv_cache_spec, layer_names, vllm_config, device) # Enable speculative decoding support self.speculative_config = vllm_config.speculative_config self.compilation_config = vllm_config.compilation_config self.num_spec_tokens: int = vllm_config.num_speculative_tokens self.use_spec_decode = self.num_spec_tokens > 0 assert isinstance(kv_cache_spec, MambaSpec) self.compilation_config = vllm_config.compilation_config self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs if self.compilation_config.max_cudagraph_capture_size is not None: self.decode_cudagraph_max_bs = min( self.decode_cudagraph_max_bs, self.compilation_config.max_cudagraph_capture_size, ) if self.vllm_config.cache_config.mamba_cache_mode == "all": max_num_blocks = cdiv( self.vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size, ) # Speculative decoding not supported with prefix caching, # so keep shape consistent with prefill buffer # TODO: reduce this size as needed for decode-only cudagraph capture self.state_indices_tensor_d = torch.empty( ( self.decode_cudagraph_max_bs, max_num_blocks, ), dtype=torch.int32, device=device, ) self.block_idx_last_scheduled_token = torch.empty( (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) self.block_idx_last_computed_token = torch.empty( (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) else: self.state_indices_tensor_d = torch.empty( (self.decode_cudagraph_max_bs, 1 + self.num_spec_tokens), dtype=torch.int32, device=device, ) # For speculative decoding, we need to store the following buffers # for CUDA graph capture during decode if self.num_spec_tokens > 0: self.decode_num_accepted_tokens = torch.empty( (self.decode_cudagraph_max_bs,), dtype=torch.int32, device=device, ) self._init_reorder_batch_threshold(1, self.use_spec_decode) if self.use_spec_decode: self.supports_update_block_table = False def build_for_cudagraph_capture( self, common_attn_metadata: CommonAttentionMetadata ) -> M: """ This method builds the metadata for full cudagraph capture. Currently, only decode is supported for full cudagraphs with Mamba. """ m = common_attn_metadata assert ( m.max_query_len <= 1 + self.num_spec_tokens and m.num_reqs <= self.decode_cudagraph_max_bs ), ( "Mamba only supports decode-only full CUDAGraph capture. " "Make sure all cudagraph capture sizes <= max_num_seq." ) assert m.max_query_len == 1 + self.num_spec_tokens # decode-only num_accepted_tokens = None if self.num_spec_tokens > 0: num_accepted_tokens = torch.diff(m.query_start_loc) return self.build(0, m, num_accepted_tokens=num_accepted_tokens) def build( self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, fast_build: bool = False, *, num_accepted_tokens: torch.Tensor | None = None, **kwargs: Any, ) -> M: """ Default build implementation for Mamba-like attention backends. Subclasses (e.g., Mamba2) can override to add additional metadata. """ return self._compute_common_metadata( common_attn_metadata, num_accepted_tokens=num_accepted_tokens ) def _compute_chunk_metadata( self, chunk_size: int, num_prefills: int, num_computed_tokens_p_cpu: torch.Tensor, query_start_loc_p_cpu: torch.Tensor, ) -> tuple[list[int], list[int], list[int]]: """ Compute chunk-specific metadata for Mamba models. The code below carefully constructs the chunks such that: 1. Chunks contain tokens from a *single* sequence only. 2. For every sequence, we are guaranteed that we can retrieve the mamba state *every* chunk_size tokens. Constraint (1) dramatically simplifies the mamba kernels. Constraint (2) dramatically simplifies the implementation of prefix caching for mamba (wip). We need to take care of the interaction with chunked prefill in order to satisfy constraint (2). """ # TODO (tdoublep): This code could probably be optimized. cu_chunk_seqlen = [] seq_idx = [] last_chunk_indices = [] seqlen_pos = 0 for req_idx in range(num_prefills): this_num_computed = num_computed_tokens_p_cpu[req_idx].item() this_new_tokens = ( query_start_loc_p_cpu[req_idx + 1].item() - query_start_loc_p_cpu[req_idx].item() ) # if computed tokens are not chunk-aligned, use the first # chunk to finish it off if this_num_computed % chunk_size != 0: seq_idx.append(req_idx) cu_chunk_seqlen.append(seqlen_pos) # how many tokens to finish the chunk? chunk_len = ( cdiv(this_num_computed, chunk_size) * chunk_size - this_num_computed ) # we can only use at most this_new_tokens chunk_len = min(chunk_len, this_new_tokens) seqlen_pos += chunk_len this_new_tokens -= chunk_len n_chunks = cdiv(this_new_tokens, chunk_size) for chunk in range(n_chunks): seq_idx.append(req_idx) cu_chunk_seqlen.append(seqlen_pos) chunk_len = min(chunk_size, this_new_tokens) seqlen_pos += chunk_len this_new_tokens -= chunk_len assert this_new_tokens == 0 last_chunk_indices.append(len(cu_chunk_seqlen) - 1) cu_chunk_seqlen.append(seqlen_pos) return cu_chunk_seqlen, seq_idx, last_chunk_indices def _build_chunk_metadata_tensors( self, chunk_size: int, common: M, common_attn_metadata: CommonAttentionMetadata, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Compute chunk metadata and return as device tensors. Returns (cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p). """ num_reqs = common.num_reqs num_prefills = common.num_prefills num_decode_tokens = common.num_decode_tokens num_computed_tokens_cpu = ( common_attn_metadata.compute_num_computed_tokens().cpu() ) num_computed_tokens_p_cpu = num_computed_tokens_cpu[ num_reqs - num_prefills : num_reqs ] query_start_loc_p_cpu = ( common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :] - num_decode_tokens ) cu_chunk_seqlen, seq_idx, last_chunk_indices = self._compute_chunk_metadata( chunk_size, num_prefills, num_computed_tokens_p_cpu, query_start_loc_p_cpu, ) device = common_attn_metadata.query_start_loc.device cu_chunk_seqlen_p = torch.as_tensor( cu_chunk_seqlen, device=device, dtype=torch.int32, ) seq_idx_p = torch.as_tensor( seq_idx, device=device, dtype=torch.int32, ) last_chunk_indices_p = torch.as_tensor( last_chunk_indices, device=device, dtype=torch.int32, ) return cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p def _compute_prefix_caching_block_indices( self, common_attn_metadata: CommonAttentionMetadata, mamba_block_size: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: num_computed_tokens = common_attn_metadata.compute_num_computed_tokens() # Block index of the last computed token block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1 # which is <= block index for the first scheduled token block_idx_first_scheduled_token = ( cdiv(num_computed_tokens + 1, mamba_block_size) - 1 ) # which is <= block index of the last scheduled token block_idx_last_scheduled_token = ( cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1 ) # -1 in case it's non-computed and causes later issues with indexing block_idx_last_computed_token = torch.clamp( block_idx_last_computed_token, min=0 ) # -1 in the case we have a padded request (0 seq-len) block_idx_last_scheduled_token = torch.clamp( block_idx_last_scheduled_token, min=0 ) return ( block_idx_last_computed_token, block_idx_first_scheduled_token, block_idx_last_scheduled_token, ) def _compute_common_metadata( self, common_attn_metadata: CommonAttentionMetadata, *, num_accepted_tokens: torch.Tensor | None = None, ) -> M: """ Compute metadata common to both Mamba1 and Mamba2. """ num_reqs = common_attn_metadata.num_reqs # Treat multi-token queries as decode requests when # speculative decoding is enabled. Otherwise, use the # default decode threshold to prevent misclassification # of prefill queries as decode requests. decode_threshold = ( self.reorder_batch_threshold if num_accepted_tokens is not None else 1 ) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( split_decodes_and_prefills( common_attn_metadata, decode_threshold=decode_threshold ) ) # Need flags to indicate if there are initial states has_initial_states_p = None query_start_loc_p = None query_start_loc_d = None num_computed_tokens = None num_computed_tokens_p = None # for prefix caching block_idx_first_scheduled_token = None block_idx_first_scheduled_token_p = None block_idx_last_computed_token = None block_idx_last_scheduled_token = None # for causal_conv1d nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None if self.vllm_config.cache_config.mamba_cache_mode == "all": num_computed_tokens = common_attn_metadata.compute_num_computed_tokens() # Return a tensor of shape (#requests, #max blocks) state_indices_tensor = common_attn_metadata.block_table_tensor # Additional cache-related varaiables: mamba_block_size = self.kv_cache_spec.block_size ( block_idx_last_computed_token, block_idx_first_scheduled_token, block_idx_last_scheduled_token, ) = self._compute_prefix_caching_block_indices( common_attn_metadata, mamba_block_size ) else: state_indices_tensor = mamba_get_block_table_tensor( common_attn_metadata.block_table_tensor, common_attn_metadata.seq_lens, self.kv_cache_spec, self.vllm_config.cache_config.mamba_cache_mode, ) if state_indices_tensor.dim() == 1: state_indices_tensor = state_indices_tensor.unsqueeze(-1) state_indices_tensor_d, state_indices_tensor_p = torch.split( state_indices_tensor, [num_decodes, num_prefills], dim=0, ) if self.vllm_config.cache_config.mamba_cache_mode != "all": state_indices_tensor_d = state_indices_tensor_d[ :, : 1 + self.num_spec_tokens ] state_indices_tensor_p = state_indices_tensor_p[:, 0] if num_decodes > 0 and self.use_spec_decode: assert num_accepted_tokens is not None query_start_loc_d = common_attn_metadata.query_start_loc[: num_decodes + 1] num_accepted_tokens = num_accepted_tokens[:num_decodes] if num_prefills > 0: if num_computed_tokens is None: num_computed_tokens = common_attn_metadata.compute_num_computed_tokens() query_start_loc_p_cpu = ( common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :] - num_decode_tokens ) query_start_loc_p = ( common_attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decode_tokens ) has_initial_states_p = ( num_computed_tokens[num_reqs - num_prefills : num_reqs] > 0 ) nums_dict, batch_ptr, token_chunk_offset_ptr = ( compute_causal_conv1d_metadata( query_start_loc_p_cpu, device=common_attn_metadata.query_start_loc.device, ) ) if self.vllm_config.cache_config.mamba_cache_mode == "all": assert num_computed_tokens is not None num_computed_tokens_p = num_computed_tokens[ num_reqs - num_prefills : num_reqs ] assert block_idx_first_scheduled_token is not None block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[ num_reqs - num_prefills : num_reqs ] metadata = self.metadata_cls( num_prefills=num_prefills, num_prefill_tokens=num_prefill_tokens, num_decodes=num_decodes, num_decode_tokens=num_decode_tokens, query_start_loc_p=query_start_loc_p, has_initial_states_p=has_initial_states_p, state_indices_tensor_p=state_indices_tensor_p, state_indices_tensor_d=state_indices_tensor_d, num_accepted_tokens=num_accepted_tokens, query_start_loc_d=query_start_loc_d, block_idx_last_scheduled_token=block_idx_last_scheduled_token, block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p, block_idx_last_computed_token=block_idx_last_computed_token, num_computed_tokens_p=num_computed_tokens_p, num_reqs=num_reqs, seq_lens=common_attn_metadata.seq_lens, nums_dict=nums_dict, batch_ptr=batch_ptr, token_chunk_offset_ptr=token_chunk_offset_ptr, ) return self._update_metadata_for_cudagraph_capture(metadata) def _update_metadata_for_cudagraph_capture( self, metadata: M, ) -> M: """ Update the metadata for cudagraph capture. Currently, only decode is supported for full cudagraphs with Mamba. """ state_indices_tensor_d = metadata.state_indices_tensor_d query_start_loc_d = metadata.query_start_loc_d num_accepted_tokens = metadata.num_accepted_tokens block_idx_last_scheduled_token = metadata.block_idx_last_scheduled_token block_idx_last_computed_token = metadata.block_idx_last_computed_token if ( metadata.num_prefills == 0 and metadata.num_decodes <= self.decode_cudagraph_max_bs and self.compilation_config.cudagraph_mode.has_full_cudagraphs() ): padded_bs = metadata.num_reqs self.state_indices_tensor_d[: metadata.num_decodes].copy_( state_indices_tensor_d, non_blocking=True ) state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs] state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID if self.use_spec_decode: assert query_start_loc_d is not None assert num_accepted_tokens is not None query_start_loc_d = query_start_loc_d[: padded_bs + 1] self.decode_num_accepted_tokens[: metadata.num_decodes].copy_( num_accepted_tokens, non_blocking=True ) num_accepted_tokens = self.decode_num_accepted_tokens[:padded_bs] num_accepted_tokens[metadata.num_decodes :] = ( 1 # pad with 1st slot index ) if self.vllm_config.cache_config.mamba_cache_mode == "all": assert block_idx_last_scheduled_token is not None assert block_idx_last_computed_token is not None self.block_idx_last_scheduled_token[: metadata.num_decodes].copy_( block_idx_last_scheduled_token[: metadata.num_decodes], non_blocking=True, ) block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[ : metadata.num_decode_tokens ] self.block_idx_last_computed_token[: metadata.num_decodes].copy_( block_idx_last_computed_token[: metadata.num_decodes], non_blocking=True, ) block_idx_last_computed_token = self.block_idx_last_computed_token[ : metadata.num_decode_tokens ] return replace( metadata, state_indices_tensor_d=state_indices_tensor_d, query_start_loc_d=query_start_loc_d, num_accepted_tokens=num_accepted_tokens, block_idx_last_scheduled_token=block_idx_last_scheduled_token, block_idx_last_computed_token=block_idx_last_computed_token, ) def update_block_table( self, metadata: M, blk_table: torch.Tensor, slot_mapping: torch.Tensor, ) -> M: state_indices_tensor = mamba_get_block_table_tensor( blk_table, metadata.seq_lens, self.kv_cache_spec, self.vllm_config.cache_config.mamba_cache_mode, ) if state_indices_tensor.dim() == 1: state_indices_tensor = state_indices_tensor.unsqueeze(-1) assert ( metadata.num_prefills + metadata.num_decodes == state_indices_tensor.shape[0] ), ( "Mismatch in number of requests when updating block table." f" Expected {metadata.num_prefills + metadata.num_decodes}, " f"got {state_indices_tensor.shape[0]}." ) state_indices_tensor_d, state_indices_tensor_p = torch.split( state_indices_tensor, [metadata.num_decodes, metadata.num_prefills], dim=0, ) if self.vllm_config.cache_config.mamba_cache_mode != "all": state_indices_tensor_d = state_indices_tensor_d[ :, : 1 + self.num_spec_tokens ] state_indices_tensor_p = state_indices_tensor_p[:, 0] new_metadata = replace( metadata, state_indices_tensor_d=state_indices_tensor_d, state_indices_tensor_p=state_indices_tensor_p, ) return self._update_metadata_for_cudagraph_capture(new_metadata)