Files
enginex-mthreads-vllm/vllm/v1/attention/backends/mamba_attn.py
2026-01-19 10:38:50 +08:00

118 lines
4.1 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
from typing import ClassVar, TypeVar
import torch
from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
M = TypeVar("M")
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
reorder_batch_threshold: int = 1
_cudagraph_support: ClassVar[AttentionCGSupport] = (
AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
)
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
assert isinstance(kv_cache_spec, MambaSpec)
self.compilation_config = vllm_config.compilation_config
self.decode_cudagraph_max_bs = min(
self.vllm_config.scheduler_config.max_num_seqs,
self.compilation_config.max_cudagraph_capture_size,
)
if self.vllm_config.cache_config.enable_prefix_caching:
self.state_indices_tensor = torch.empty(
(
self.decode_cudagraph_max_bs,
cdiv(
self.vllm_config.model_config.max_model_len,
self.kv_cache_spec.block_size,
),
),
dtype=torch.int32,
device=device,
)
self.block_idx_last_scheduled_token = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
self.block_idx_last_computed_token = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
else:
self.state_indices_tensor = torch.empty(
(self.decode_cudagraph_max_bs,),
dtype=torch.int32,
device=device,
)
def build_for_cudagraph_capture(
self, common_attn_metadata: CommonAttentionMetadata
) -> M:
"""
This method builds the metadata for full cudagraph capture.
Currently, only decode is supported for full cudagraphs with Mamba.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, (
"Mamba only supports decode-only full CUDAGraph capture. "
"Make sure all cudagraph capture sizes <= max_num_seq."
)
m.max_query_len = 1 # decode-only
return self.build(0, m)
def _compute_prefix_caching_block_indices(
self,
common_attn_metadata: CommonAttentionMetadata,
mamba_block_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
self.device
)
# Block index of the last computed token
block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
# which is <= block index for the first scheduled token
block_idx_first_scheduled_token = (
cdiv(num_computed_tokens + 1, mamba_block_size) - 1
)
# which is <= block index of the last scheduled token
block_idx_last_scheduled_token = (
cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
)
# -1 in case it's non-computed and causes later issues with indexing
block_idx_last_computed_token = block_idx_last_computed_token.clamp(min=0)
# -1 in the case we have a padded request (0 seq-len)
block_idx_last_scheduled_token = block_idx_last_scheduled_token.clamp(min=0)
return (
block_idx_last_computed_token,
block_idx_first_scheduled_token,
block_idx_last_scheduled_token,
)