53 lines
1.9 KiB
Python
53 lines
1.9 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import abc
|
|
from typing import ClassVar, TypeVar
|
|
|
|
import torch
|
|
|
|
from vllm.config import VllmConfig
|
|
from vllm.v1.attention.backends.utils import (AttentionCGSupport,
|
|
AttentionMetadataBuilder,
|
|
CommonAttentionMetadata)
|
|
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
|
|
|
M = TypeVar("M")
|
|
|
|
|
|
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
    """Abstract base for attention-metadata builders that use a `MambaSpec`.

    On construction it checks that the KV-cache spec is a `MambaSpec`,
    derives the largest decode batch size considered for CUDA graph capture,
    and preallocates an int32 state-index buffer of that size on the target
    device.
    """

    # NOTE(review): the meaning of this threshold is defined by the parent
    # builder / its consumers, which are not visible here -- confirm there.
    reorder_batch_threshold: int = 1
    # CUDA graph support level advertised by this builder family.
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        """Initialize the builder and preallocate the state-index buffer.

        Args:
            kv_cache_spec: KV-cache spec; must be a `MambaSpec` instance.
            layer_names: Forwarded unchanged to the parent constructor.
            vllm_config: Source of the compilation and scheduler settings.
            device: Device on which `state_indices_tensor` is allocated.

        Raises:
            AssertionError: If `kv_cache_spec` is not a `MambaSpec`.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        assert isinstance(kv_cache_spec, MambaSpec)
        self.compilation_config = vllm_config.compilation_config
        # Upper bound on the decode batch size: limited by both the
        # scheduler's `max_num_seqs` and the largest capture size.
        # NOTE(review): `self.vllm_config` is presumably set by the parent
        # constructor -- confirm.
        self.decode_cudagraph_max_bs = min(
            self.vllm_config.scheduler_config.max_num_seqs,
            self.compilation_config.max_capture_size)
        # Uninitialized 1-D int32 buffer with one slot per request at the
        # maximum decode batch size computed above.
        self.state_indices_tensor = torch.empty(
            (self.decode_cudagraph_max_bs, ),
            dtype=torch.int32,
            device=device,
        )

    def build_for_cudagraph_capture(
            self, common_attn_metadata: CommonAttentionMetadata) -> M:
        """
        This method builds the metadata for full cudagraph capture.
        Currently, only decode is supported for full cudagraphs with Mamba.

        Note that `common_attn_metadata` is modified in place: its
        `max_query_len` is set to 1 before it is handed to `self.build`.

        Args:
            common_attn_metadata: Batch metadata whose `num_reqs` must equal
                its `num_actual_tokens`.

        Returns:
            Whatever `self.build(0, common_attn_metadata)` returns.

        Raises:
            AssertionError: If the number of requests differs from the
                number of actual tokens.
        """
        m = common_attn_metadata

        assert m.num_reqs == m.num_actual_tokens, \
            "Mamba only supports decode-only full CUDAGraph capture. " \
            "Make sure all cudagraph capture sizes <= max_num_seqs."

        m.max_query_len = 1  # decode-only

        # NOTE(review): the leading 0 is presumably the common prefix
        # length expected by `build` -- confirm against its signature.
        return self.build(0, m)
|