Sync from v0.13
This commit is contained in:
159
vllm/v1/attention/backends/mamba1_attn.py
Normal file
159
vllm/v1/attention/backends/mamba1_attn.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
class Mamba1AttentionBackend(AttentionBackend):
|
||||
@staticmethod
|
||||
def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
|
||||
return Mamba1AttentionMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mamba1AttentionMetadata:
|
||||
query_start_loc_p: torch.Tensor
|
||||
state_indices_tensor: torch.Tensor
|
||||
has_initial_states_p: torch.Tensor | None
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
|
||||
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
|
||||
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
|
||||
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
|
||||
num_computed_tokens_p: torch.Tensor # shape: [batch,]
|
||||
|
||||
|
||||
class Mamba1AttentionMetadataBuilder(
|
||||
BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: AttentionSpec,
|
||||
layer_names: list[str],
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
):
|
||||
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
|
||||
assert isinstance(kv_cache_spec, MambaSpec)
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> Mamba1AttentionMetadata:
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
)
|
||||
|
||||
has_initial_states_p = None
|
||||
query_start_loc_p = None
|
||||
num_computed_tokens, num_computed_tokens_p = None, None
|
||||
block_idx_first_scheduled_token = None
|
||||
block_idx_first_scheduled_token_p = None
|
||||
|
||||
# TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here.
|
||||
# We should consolidate this code
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
# Return a tensor of shape (#requests, #max blocks)
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor
|
||||
mamba_block_size = self.kv_cache_spec.block_size
|
||||
num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
|
||||
self.device
|
||||
)
|
||||
(
|
||||
block_idx_last_computed_token,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
) = self._compute_prefix_caching_block_indices(
|
||||
common_attn_metadata, mamba_block_size
|
||||
)
|
||||
else:
|
||||
# Always return just a single block per each request:
|
||||
state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
|
||||
block_idx_last_scheduled_token = None
|
||||
block_idx_last_computed_token = None
|
||||
|
||||
if num_prefills > 0:
|
||||
query_start_loc_p = (
|
||||
common_attn_metadata.query_start_loc[-num_prefills - 1 :]
|
||||
- num_decode_tokens
|
||||
)
|
||||
has_initial_states_cpu = (
|
||||
common_attn_metadata.num_computed_tokens_cpu[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
> 0
|
||||
)
|
||||
has_initial_states_p = has_initial_states_cpu.to(
|
||||
common_attn_metadata.query_start_loc.device
|
||||
)
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
assert num_computed_tokens is not None
|
||||
num_computed_tokens_p = num_computed_tokens[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
assert block_idx_first_scheduled_token is not None
|
||||
block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
|
||||
num_reqs - num_prefills : num_reqs
|
||||
]
|
||||
|
||||
elif (
|
||||
num_decodes > 0
|
||||
and num_decodes <= self.decode_cudagraph_max_bs
|
||||
and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
|
||||
):
|
||||
self.state_indices_tensor[:num_decodes].copy_(
|
||||
state_indices_tensor, non_blocking=True
|
||||
)
|
||||
state_indices_tensor = self.state_indices_tensor[:num_decode_tokens]
|
||||
state_indices_tensor[num_decodes:] = PAD_SLOT_ID
|
||||
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
self.block_idx_last_scheduled_token[:num_decodes].copy_(
|
||||
block_idx_last_scheduled_token, non_blocking=True
|
||||
)
|
||||
block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
|
||||
:num_decode_tokens
|
||||
]
|
||||
|
||||
self.block_idx_last_computed_token[:num_decodes].copy_(
|
||||
block_idx_last_computed_token, non_blocking=True
|
||||
)
|
||||
block_idx_last_computed_token = self.block_idx_last_computed_token[
|
||||
:num_decode_tokens
|
||||
]
|
||||
|
||||
return Mamba1AttentionMetadata(
|
||||
query_start_loc_p=query_start_loc_p,
|
||||
has_initial_states_p=has_initial_states_p,
|
||||
state_indices_tensor=state_indices_tensor,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token,
|
||||
block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
|
||||
block_idx_last_computed_token=block_idx_last_computed_token,
|
||||
num_computed_tokens_p=num_computed_tokens_p,
|
||||
)
|
||||
Reference in New Issue
Block a user