2026-04-18 10:56:22 +08:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
import itertools
|
|
|
|
|
from dataclasses import dataclass, replace
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from vllm.config import VllmConfig
|
|
|
|
|
from vllm.v1.attention.backend import (
|
|
|
|
|
AttentionBackend,
|
|
|
|
|
CommonAttentionMetadata,
|
|
|
|
|
)
|
|
|
|
|
from vllm.v1.attention.backends.mamba_attn import (
|
|
|
|
|
BaseMambaAttentionMetadata,
|
|
|
|
|
BaseMambaAttentionMetadataBuilder,
|
|
|
|
|
)
|
|
|
|
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_varlen_chunk_metadata(
    query_start_loc: torch.Tensor,
    chunk_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build chunk-aligned, variable-length metadata used by Mamba2 SSD kernels.

    Given per-sequence cumulative token starts ``query_start_loc`` of shape
    [B+1] and a physical ``chunk_size``, returns three int32 tensors on the
    same device as ``query_start_loc``:

    - cu_chunk_seqlens: (nchunks+1,) exclusive prefix sum of logical-chunk
      lengths (leading 0). A logical chunk never crosses a sequence boundary
      or a physical-chunk boundary; physical chunks are aligned to multiples
      of ``chunk_size`` in the concatenated token stream.
    - last_chunk_indices: (B,) index of the last logical chunk of each
      sequence (-1 for empty sequences).
    - seq_idx_chunks: (nchunks,) sequence index of each logical chunk, in
      order.

    The splitting is done on the host with Python lists (the boundaries are
    read back via ``.tolist()``), and the results are then copied to
    ``query_start_loc.device``. It mirrors the metadata produced by the V1
    Mamba2 metadata builder and is exported so tests (and other callers) can
    avoid duplicating the logic.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
        AssertionError: if ``query_start_loc`` is not 1-D or does not start
            at 0.
    """
    assert query_start_loc.ndim == 1, "query_start_loc must be 1-D [B+1]"
    assert int(query_start_loc[0].item()) == 0, "query_start_loc[0] must be 0"
    # chunk_size == 0 would divide by zero in ``pos % chunk_size``.
    # chunk_size < 0 makes ``take`` negative, so the split loop below would
    # never terminate. Reject both up front.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    device = query_start_loc.device

    # Read all boundaries back to the host once instead of three times.
    bounds: list[int] = query_start_loc.to(torch.int64).tolist()
    starts, ends = bounds[:-1], bounds[1:]
    total = bounds[-1]

    chunk_lens: list[int] = []
    seq_idx_chunks: list[int] = []
    last_chunk_indices: list[int] = [-1] * len(starts)

    for b, (s, e) in enumerate(zip(starts, ends)):
        if e <= s:
            # Empty sequence: no chunks, last-chunk index stays -1.
            continue
        pos = s
        while pos < e:
            # Split at both sequence boundaries and physical chunk boundaries.
            room = chunk_size - (pos % chunk_size)
            take = min(room, e - pos)
            chunk_lens.append(take)
            seq_idx_chunks.append(b)
            last_chunk_indices[b] = len(chunk_lens) - 1
            pos += take

    # Exclusive prefix sum over logical-chunk lengths.
    boundaries = [0, *itertools.accumulate(chunk_lens)]
    if chunk_lens:
        # The final boundary must equal the total number of tokens.
        assert boundaries[-1] == total

    cu_chunk_seqlens = torch.tensor(boundaries, device=device, dtype=torch.int32)
    # torch.tensor([]) with an explicit dtype already yields shape (0,), so an
    # empty batch needs no special case.
    last_chunk_indices_t = torch.tensor(
        last_chunk_indices, device=device, dtype=torch.int32
    )
    seq_idx_chunks_t = torch.tensor(seq_idx_chunks, device=device, dtype=torch.int32)
    return cu_chunk_seqlens, last_chunk_indices_t, seq_idx_chunks_t
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Mamba2AttentionBackend(AttentionBackend):
    """Attention backend descriptor for Mamba2 layers.

    Provides the backend's identifier string and the metadata builder class
    it uses.
    """

    @staticmethod
    def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return Mamba2AttentionMetadataBuilder

    @staticmethod
    def get_name() -> str:
        """Return this backend's identifier string."""
        return "MAMBA2_ATTN"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class Mamba2AttentionMetadata(BaseMambaAttentionMetadata):
    """Per-batch attention metadata for Mamba2 layers.

    Extends ``BaseMambaAttentionMetadata`` with Mamba2-specific fields, which
    ``Mamba2AttentionMetadataBuilder.build`` fills in via
    ``dataclasses.replace``. Field order matters: it defines the positional
    order of the generated ``__init__``.
    """

    # True when the batch has prefills and at least one of them has an
    # initial state (the builder computes ``torch.any(has_initial_states_p)``).
    prep_initial_states: bool = False
    # Default 0 is a placeholder; the builder always overwrites it with the
    # model config's Mamba chunk size.
    chunk_size: int = 0

    # Chunk-related metadata (only for prefill)
    # None when the batch has no prefill requests.
    # NOTE(review): presumably an int32 per-logical-chunk sequence index like
    # the ``seq_idx_chunks`` output of ``compute_varlen_chunk_metadata``;
    # confirm against the base builder's ``_build_chunk_metadata_tensors``.
    seq_idx_p: torch.Tensor | None = None
    # NOTE(review): the builder also passes ``cu_chunk_seqlen_p`` and
    # ``last_chunk_indices_p`` to ``replace``. They are not declared here, so
    # presumably they are inherited from ``BaseMambaAttentionMetadata``;
    # confirm.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Mamba2AttentionMetadataBuilder(
    BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
):
    """Builds ``Mamba2AttentionMetadata`` for each batch.

    The shared fields come from the base builder's
    ``_compute_common_metadata``. This subclass adds the configured chunk
    size and, when the batch contains prefills, the prefill chunk metadata.
    """

    metadata_cls = Mamba2AttentionMetadata

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        # Mamba2 needs an explicit chunk size from the model config; refuse to
        # construct the builder without one.
        configured_chunk_size = vllm_config.model_config.get_mamba_chunk_size()
        assert configured_chunk_size is not None, (
            "chunk_size needs to be set in the model config for Mamba2 models"
        )
        self.chunk_size: int = configured_chunk_size

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
        **kwargs: Any,
    ) -> Mamba2AttentionMetadata:
        """Assemble the Mamba2 metadata for one batch.

        ``common_prefix_len`` and ``fast_build`` are part of the builder
        interface but are not read here. The optional ``num_accepted_tokens``
        keyword is forwarded to ``_compute_common_metadata`` (``None`` when
        absent).
        """
        base = self._compute_common_metadata(
            common_attn_metadata,
            num_accepted_tokens=kwargs.get("num_accepted_tokens"),
        )

        if base.num_prefills > 0:
            # Prefill path: check for initial states, then build chunk metadata.
            initial_flags = base.has_initial_states_p
            if initial_flags is None:
                any_initial_state = False
            else:
                any_initial_state = torch.any(initial_flags).item()
            # NOTE(review): the base helper's tuple order is
            # (cu_chunk_seqlen, seq_idx, last_chunk_indices), which differs from
            # compute_varlen_chunk_metadata's (cu, last, seq) order. Confirm
            # this unpacking against the helper's definition.
            cu_chunk_seqlen_p, seq_idx_p, last_chunk_indices_p = (
                self._build_chunk_metadata_tensors(
                    self.chunk_size,
                    base,
                    common_attn_metadata,
                )
            )
        else:
            # Decode-only batch: no prefill chunk metadata.
            any_initial_state = False
            cu_chunk_seqlen_p = None
            seq_idx_p = None
            last_chunk_indices_p = None

        return replace(
            base,
            prep_initial_states=any_initial_state,
            chunk_size=self.chunk_size,
            seq_idx_p=seq_idx_p,
            cu_chunk_seqlen_p=cu_chunk_seqlen_p,
            last_chunk_indices_p=last_chunk_indices_p,
        )
|