update

vllm/v1/attention/backends/mamba_attn.py (new file, 464 lines)

@@ -0,0 +1,464 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import abc
from dataclasses import dataclass, replace
from typing import Any, ClassVar, TypeVar

import torch

from vllm.config import VllmConfig
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
)
from vllm.v1.attention.backends.utils import (
    PAD_SLOT_ID,
    compute_causal_conv1d_metadata,
    mamba_get_block_table_tensor,
    split_decodes_and_prefills,
)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec

M = TypeVar("M", bound="BaseMambaAttentionMetadata")

@dataclass
class BaseMambaAttentionMetadata:
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int
    num_reqs: int

    # The following tensors only contain prefill requests and will be None if
    # the batch has no prefill requests.
    has_initial_states_p: torch.Tensor | None
    query_start_loc_p: torch.Tensor | None
    num_computed_tokens_p: torch.Tensor | None
    state_indices_tensor_p: torch.Tensor | None

    # The following tensors are used for decode requests and for speculative
    # decoding compatibility, and will be None if the batch has no decode
    # requests.
    state_indices_tensor_d: torch.Tensor | None
    query_start_loc_d: torch.Tensor | None  # shape: [num_decodes + 1,]

    # Number of accepted tokens for each spec-decode sequence (used to load the
    # correct state checkpoint). Includes the bonus token, so the minimum is 1.
    num_accepted_tokens: torch.Tensor | None  # shape: [batch,]

    # The following tensors are only used for prefix caching in "all" mode and
    # are None if it is disabled.
    block_idx_last_scheduled_token: torch.Tensor | None
    block_idx_first_scheduled_token_p: torch.Tensor | None
    block_idx_last_computed_token: torch.Tensor | None

    # The following tensor is only used for prefix caching in "align" mode.
    seq_lens: torch.Tensor

    # The following attributes are for the Triton implementation of
    # causal_conv1d.
    nums_dict: dict | None = None
    batch_ptr: torch.Tensor | None = None
    token_chunk_offset_ptr: torch.Tensor | None = None

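# Illustrative sketch only (hypothetical field names; the actual Mamba1/Mamba2
# metadata classes may differ): a backend-specific metadata dataclass would
# extend the base class with its own fields, e.g.
#
#     @dataclass
#     class HypotheticalMamba2AttentionMetadata(BaseMambaAttentionMetadata):
#         chunk_size: int = 256
#         seq_idx_p: torch.Tensor | None = None
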
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
    metadata_cls: type[M]
    reorder_batch_threshold: int = 1
    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

    # Will be disabled if speculative decoding is used.
    supports_update_block_table: bool = True

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        # Enable speculative decoding support.
        self.speculative_config = vllm_config.speculative_config
        self.compilation_config = vllm_config.compilation_config
        self.num_spec_tokens: int = vllm_config.num_speculative_tokens
        self.use_spec_decode = self.num_spec_tokens > 0

        assert isinstance(kv_cache_spec, MambaSpec)
        self.decode_cudagraph_max_bs = self.vllm_config.scheduler_config.max_num_seqs
        if self.compilation_config.max_cudagraph_capture_size is not None:
            self.decode_cudagraph_max_bs = min(
                self.decode_cudagraph_max_bs,
                self.compilation_config.max_cudagraph_capture_size,
            )

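        # The buffers below are allocated once and reused across steps: full
        # CUDA graph replay requires tensors with stable device addresses, so
        # per-step metadata is copied into these persistent buffers in
        # _update_metadata_for_cudagraph_capture() rather than reallocated.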
        if self.vllm_config.cache_config.mamba_cache_mode == "all":
            max_num_blocks = cdiv(
                self.vllm_config.model_config.max_model_len,
                self.kv_cache_spec.block_size,
            )
            # Speculative decoding is not supported with prefix caching, so
            # keep the shape consistent with the prefill buffer.
            # TODO: reduce this size as needed for decode-only cudagraph capture
            self.state_indices_tensor_d = torch.empty(
                (
                    self.decode_cudagraph_max_bs,
                    max_num_blocks,
                ),
                dtype=torch.int32,
                device=device,
            )
            self.block_idx_last_scheduled_token = torch.empty(
                (self.decode_cudagraph_max_bs,),
                dtype=torch.int32,
                device=device,
            )
            self.block_idx_last_computed_token = torch.empty(
                (self.decode_cudagraph_max_bs,),
                dtype=torch.int32,
                device=device,
            )
        else:
            self.state_indices_tensor_d = torch.empty(
                (self.decode_cudagraph_max_bs, 1 + self.num_spec_tokens),
                dtype=torch.int32,
                device=device,
            )

        # For speculative decoding, we need to store the following buffer
        # for CUDA graph capture during decode.
        if self.num_spec_tokens > 0:
            self.decode_num_accepted_tokens = torch.empty(
                (self.decode_cudagraph_max_bs,),
                dtype=torch.int32,
                device=device,
            )

        self._init_reorder_batch_threshold(1, self.use_spec_decode)
        if self.use_spec_decode:
            self.supports_update_block_table = False

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> M:
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with Mamba.
        """
        m = common_attn_metadata

        assert (
            m.max_query_len <= 1 + self.num_spec_tokens
            and m.num_reqs <= self.decode_cudagraph_max_bs
        ), (
            "Mamba only supports decode-only full CUDAGraph capture. "
            "Make sure all cudagraph capture sizes are <= max_num_seqs."
        )

        assert m.max_query_len == 1 + self.num_spec_tokens  # decode-only

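        # For capture, query_start_loc describes a dummy decode batch where
        # each request carries 1 + num_spec_tokens query tokens, so the
        # per-request counts obtained via torch.diff below serve as a
        # placeholder num_accepted_tokens.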
        num_accepted_tokens = None
        if self.num_spec_tokens > 0:
            num_accepted_tokens = torch.diff(m.query_start_loc)

        return self.build(0, m, num_accepted_tokens=num_accepted_tokens)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
        *,
        num_accepted_tokens: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> M:
        """
        Default build implementation for Mamba-like attention backends.
        Subclasses (e.g., Mamba2) can override this to add additional metadata.
        """
        return self._compute_common_metadata(
            common_attn_metadata, num_accepted_tokens=num_accepted_tokens
        )

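    # Illustrative sketch only (hypothetical names; the real backend builders
    # may differ): a subclass typically pins `metadata_cls` and layers extra
    # fields on top of the common metadata, e.g.
    #
    #     class HypotheticalMamba2MetadataBuilder(
    #         BaseMambaAttentionMetadataBuilder[HypotheticalMamba2AttentionMetadata]
    #     ):
    #         metadata_cls = HypotheticalMamba2AttentionMetadata
    #
    #         def build(self, common_prefix_len, common_attn_metadata,
    #                   fast_build=False, *, num_accepted_tokens=None, **kwargs):
    #             metadata = super().build(
    #                 common_prefix_len,
    #                 common_attn_metadata,
    #                 fast_build,
    #                 num_accepted_tokens=num_accepted_tokens,
    #             )
    #             # attach backend-specific fields here
    #             return replace(metadata, seq_idx_p=None)
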
    def _compute_prefix_caching_block_indices(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        mamba_block_size: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
        # Block index of the last computed token,
        block_idx_last_computed_token = cdiv(num_computed_tokens, mamba_block_size) - 1
        # which is <= the block index of the first scheduled token,
        block_idx_first_scheduled_token = (
            cdiv(num_computed_tokens + 1, mamba_block_size) - 1
        )
        # which is <= the block index of the last scheduled token.
        block_idx_last_scheduled_token = (
            cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1
        )
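        # Worked example (illustrative): with mamba_block_size = 8, a request
        # with 8 computed tokens and seq_len = 13 gives
        #   last computed block   = cdiv(8, 8)  - 1 = 0
        #   first scheduled block = cdiv(9, 8)  - 1 = 1
        #   last scheduled block  = cdiv(13, 8) - 1 = 1
        # With 0 computed tokens the first expression yields -1, hence the
        # clamping below.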
        # Clamp to 0: the index is -1 when nothing has been computed yet, which
        # would cause indexing issues later on.
        block_idx_last_computed_token = torch.clamp(
            block_idx_last_computed_token, min=0
        )
        # Clamp to 0: the index is -1 for a padded request (zero seq-len).
        block_idx_last_scheduled_token = torch.clamp(
            block_idx_last_scheduled_token, min=0
        )

        return (
            block_idx_last_computed_token,
            block_idx_first_scheduled_token,
            block_idx_last_scheduled_token,
        )

    def _compute_common_metadata(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        *,
        num_accepted_tokens: torch.Tensor | None = None,
    ) -> M:
        """
        Compute metadata common to both Mamba1 and Mamba2.
        """
        num_reqs = common_attn_metadata.num_reqs

        # Treat multi-token queries as decode requests when speculative
        # decoding is enabled. Otherwise, use the default decode threshold to
        # prevent misclassifying prefill queries as decode requests.
        decode_threshold = (
            self.reorder_batch_threshold if num_accepted_tokens is not None else 1
        )
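        # For example, with speculative decoding a decode step carries the
        # bonus token plus the draft tokens in a single query, so such
        # multi-token queries must still be classified as decodes here;
        # without it, only single-token queries count as decodes.
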
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata, decode_threshold=decode_threshold
            )
        )
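        # The batch has been reordered so that decode requests come first and
        # prefill requests last (see reorder_batch_threshold); the slicing
        # below (first num_decodes rows for decodes, last num_prefills rows for
        # prefills) relies on this ordering.
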
        # Need flags to indicate if there are initial states.
        has_initial_states_p = None
        query_start_loc_p = None
        query_start_loc_d = None
        num_computed_tokens = None
        num_computed_tokens_p = None

        # For prefix caching.
        block_idx_first_scheduled_token = None
        block_idx_first_scheduled_token_p = None
        block_idx_last_computed_token = None
        block_idx_last_scheduled_token = None

        # For causal_conv1d.
        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

        if self.vllm_config.cache_config.mamba_cache_mode == "all":
            num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()

            # The full block table: a tensor of shape (#requests, #max blocks).
            state_indices_tensor = common_attn_metadata.block_table_tensor
            # Additional cache-related variables:
            mamba_block_size = self.kv_cache_spec.block_size
            (
                block_idx_last_computed_token,
                block_idx_first_scheduled_token,
                block_idx_last_scheduled_token,
            ) = self._compute_prefix_caching_block_indices(
                common_attn_metadata, mamba_block_size
            )
        else:
            state_indices_tensor = mamba_get_block_table_tensor(
                common_attn_metadata.block_table_tensor,
                common_attn_metadata.seq_lens,
                self.kv_cache_spec,
                self.vllm_config.cache_config.mamba_cache_mode,
            )

        if state_indices_tensor.dim() == 1:
            state_indices_tensor = state_indices_tensor.unsqueeze(-1)
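        # At this point state_indices_tensor maps each request to its state
        # slot(s): in "all" mode it is the full per-request block table, while
        # in the other modes it typically holds a single slot per request
        # (with extra columns for speculative tokens on the decode side),
        # normalized to 2-D above.
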
        state_indices_tensor_d, state_indices_tensor_p = torch.split(
            state_indices_tensor,
            [num_decodes, num_prefills],
            dim=0,
        )
        if self.vllm_config.cache_config.mamba_cache_mode != "all":
            state_indices_tensor_d = state_indices_tensor_d[
                :, : 1 + self.num_spec_tokens
            ]
            state_indices_tensor_p = state_indices_tensor_p[:, 0]

        if num_decodes > 0 and self.use_spec_decode:
            assert num_accepted_tokens is not None
            query_start_loc_d = common_attn_metadata.query_start_loc[: num_decodes + 1]
            num_accepted_tokens = num_accepted_tokens[:num_decodes]

        if num_prefills > 0:
            if num_computed_tokens is None:
                num_computed_tokens = common_attn_metadata.compute_num_computed_tokens()
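            # query_start_loc covers the whole batch (decodes first), so
            # subtracting num_decode_tokens re-bases the prefill offsets to
            # start at zero for the prefill-only kernels.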
            query_start_loc_p_cpu = (
                common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
                - num_decode_tokens
            )
            query_start_loc_p = (
                common_attn_metadata.query_start_loc[-num_prefills - 1 :]
                - num_decode_tokens
            )
            has_initial_states_p = (
                num_computed_tokens[num_reqs - num_prefills : num_reqs] > 0
            )

            nums_dict, batch_ptr, token_chunk_offset_ptr = (
                compute_causal_conv1d_metadata(
                    query_start_loc_p_cpu,
                    device=common_attn_metadata.query_start_loc.device,
                )
            )

            if self.vllm_config.cache_config.mamba_cache_mode == "all":
                assert num_computed_tokens is not None
                num_computed_tokens_p = num_computed_tokens[
                    num_reqs - num_prefills : num_reqs
                ]
                assert block_idx_first_scheduled_token is not None
                block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
                    num_reqs - num_prefills : num_reqs
                ]

        metadata = self.metadata_cls(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc_p=query_start_loc_p,
            has_initial_states_p=has_initial_states_p,
            state_indices_tensor_p=state_indices_tensor_p,
            state_indices_tensor_d=state_indices_tensor_d,
            num_accepted_tokens=num_accepted_tokens,
            query_start_loc_d=query_start_loc_d,
            block_idx_last_scheduled_token=block_idx_last_scheduled_token,
            block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
            block_idx_last_computed_token=block_idx_last_computed_token,
            num_computed_tokens_p=num_computed_tokens_p,
            num_reqs=num_reqs,
            seq_lens=common_attn_metadata.seq_lens,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
        )

        return self._update_metadata_for_cudagraph_capture(metadata)

    def _update_metadata_for_cudagraph_capture(
        self,
        metadata: M,
    ) -> M:
        """
        Update the metadata for cudagraph capture.
        Currently, only decode is supported for full cudagraphs with Mamba.
        """
        state_indices_tensor_d = metadata.state_indices_tensor_d
        query_start_loc_d = metadata.query_start_loc_d
        num_accepted_tokens = metadata.num_accepted_tokens
        block_idx_last_scheduled_token = metadata.block_idx_last_scheduled_token
        block_idx_last_computed_token = metadata.block_idx_last_computed_token
        if (
            metadata.num_prefills == 0
            and metadata.num_decodes <= self.decode_cudagraph_max_bs
            and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        ):
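            # For full cudagraph replay the batch is padded up to num_reqs: the
            # real per-step values are copied into the persistent buffers
            # allocated in __init__ (stable device addresses), and the padded
            # tail is filled with PAD_SLOT_ID so the kernels ignore those
            # slots.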
            padded_bs = metadata.num_reqs
            self.state_indices_tensor_d[: metadata.num_decodes].copy_(
                state_indices_tensor_d, non_blocking=True
            )
            state_indices_tensor_d = self.state_indices_tensor_d[:padded_bs]
            state_indices_tensor_d[metadata.num_decodes :] = PAD_SLOT_ID

            if self.use_spec_decode:
                assert query_start_loc_d is not None
                assert num_accepted_tokens is not None
                query_start_loc_d = query_start_loc_d[: padded_bs + 1]
                self.decode_num_accepted_tokens[: metadata.num_decodes].copy_(
                    num_accepted_tokens, non_blocking=True
                )
                num_accepted_tokens = self.decode_num_accepted_tokens[:padded_bs]
                num_accepted_tokens[metadata.num_decodes :] = (
                    1  # pad with the minimum of one accepted (bonus) token
                )

            if self.vllm_config.cache_config.mamba_cache_mode == "all":
                assert block_idx_last_scheduled_token is not None
                assert block_idx_last_computed_token is not None
                self.block_idx_last_scheduled_token[: metadata.num_decodes].copy_(
                    block_idx_last_scheduled_token[: metadata.num_decodes],
                    non_blocking=True,
                )
                block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
                    : metadata.num_decode_tokens
                ]

                self.block_idx_last_computed_token[: metadata.num_decodes].copy_(
                    block_idx_last_computed_token[: metadata.num_decodes],
                    non_blocking=True,
                )
                block_idx_last_computed_token = self.block_idx_last_computed_token[
                    : metadata.num_decode_tokens
                ]

        return replace(
            metadata,
            state_indices_tensor_d=state_indices_tensor_d,
            query_start_loc_d=query_start_loc_d,
            num_accepted_tokens=num_accepted_tokens,
            block_idx_last_scheduled_token=block_idx_last_scheduled_token,
            block_idx_last_computed_token=block_idx_last_computed_token,
        )

    def update_block_table(
        self,
        metadata: M,
        blk_table: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> M:
        state_indices_tensor = mamba_get_block_table_tensor(
            blk_table,
            metadata.seq_lens,
            self.kv_cache_spec,
            self.vllm_config.cache_config.mamba_cache_mode,
        )
        if state_indices_tensor.dim() == 1:
            state_indices_tensor = state_indices_tensor.unsqueeze(-1)

        assert (
            metadata.num_prefills + metadata.num_decodes
            == state_indices_tensor.shape[0]
        ), (
            "Mismatch in number of requests when updating block table."
            f" Expected {metadata.num_prefills + metadata.num_decodes}, "
            f"got {state_indices_tensor.shape[0]}."
        )

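        # Re-split and re-slice the refreshed state indices exactly as in
        # _compute_common_metadata, then rebuild the cudagraph-padded view.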
        state_indices_tensor_d, state_indices_tensor_p = torch.split(
            state_indices_tensor,
            [metadata.num_decodes, metadata.num_prefills],
            dim=0,
        )
        if self.vllm_config.cache_config.mamba_cache_mode != "all":
            state_indices_tensor_d = state_indices_tensor_d[
                :, : 1 + self.num_spec_tokens
            ]
            state_indices_tensor_p = state_indices_tensor_p[:, 0]

        new_metadata = replace(
            metadata,
            state_indices_tensor_d=state_indices_tensor_d,
            state_indices_tensor_p=state_indices_tensor_p,
        )

        return self._update_metadata_for_cudagraph_capture(new_metadata)