# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from typing import Optional

import torch

from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata,
                                              compute_causal_conv1d_metadata,
                                              split_decodes_and_prefills)
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
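
# This module wires short causal-convolution ("short conv") layers into
# vLLM's v1 attention-backend interface: a thin backend class, a per-batch
# metadata dataclass, and the builder that fills that dataclass in from
# CommonAttentionMetadata.
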
class ShortConvAttentionBackend(AttentionBackend):
    """Backend stub for short causal-convolution layers; it only exposes
    the corresponding metadata builder."""

    @staticmethod
    def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
        return ShortConvAttentionMetadataBuilder


@dataclass
class ShortConvAttentionMetadata:
    """Metadata built once per scheduling step for short-convolution layers."""

    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int

    query_start_loc: torch.Tensor
    has_initial_states: torch.Tensor
    state_indices_tensor: torch.Tensor  # shape: [batch,]

    # For causal_conv1d
    nums_dict: Optional[dict] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None

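
# The builder below assumes the batch has been reordered so that decode
# requests come first and prefill requests occupy the tail of the request
# list; requesting that reordering is the purpose of
# `reorder_batch_threshold = 1`.
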

class ShortConvAttentionMetadataBuilder(
        AttentionMetadataBuilder[ShortConvAttentionMetadata]):

    # Requests with at most this many query tokens are treated as decodes.
    reorder_batch_threshold: int = 1

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        # Short-conv state lives in Mamba-style cache blocks.
        assert isinstance(kv_cache_spec, MambaSpec)

    def build(self,
              common_prefix_len: int,
              common_attn_metadata: CommonAttentionMetadata,
              fast_build: bool = False) -> ShortConvAttentionMetadata:
        num_reqs = common_attn_metadata.num_reqs
        query_start_loc = common_attn_metadata.query_start_loc

        # The first block id of each request doubles as its conv-state slot.
        state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]

        # for causal_conv1d
        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold))
        has_initial_states = None
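        # Prefill-only bookkeeping: mark which prefill requests resume from an
        # existing conv state and build the causal_conv1d chunking metadata
        # over the prefill portion of the batch.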
        if num_prefills > 0:
            # [batch,]
            has_initial_states_cpu = (
                common_attn_metadata.
                num_computed_tokens_cpu[num_reqs - num_prefills:num_reqs] > 0)
            has_initial_states = has_initial_states_cpu.to(
                query_start_loc.device)

            # Query start offsets restricted to the prefill requests, shifted
            # so that they begin at zero.
            query_start_loc_p = common_attn_metadata.query_start_loc[
                -num_prefills - 1:] - num_decode_tokens

            nums_dict, batch_ptr, token_chunk_offset_ptr = \
                compute_causal_conv1d_metadata(query_start_loc_p)

        attn_metadata = ShortConvAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc=query_start_loc,
            has_initial_states=has_initial_states,
            state_indices_tensor=state_indices_tensor,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
        )
        return attn_metadata
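

# Illustrative usage sketch (comment only, not part of the module): a
# hypothetical caller would construct the builder once and rebuild the
# metadata every scheduling step. The names below (mamba_spec, layer_names,
# vllm_config, common_metadata) are placeholders normally produced by the
# vLLM engine rather than built by hand.
#
#     builder = ShortConvAttentionMetadataBuilder(
#         kv_cache_spec=mamba_spec,          # must be a MambaSpec
#         layer_names=layer_names,
#         vllm_config=vllm_config,
#         device=torch.device("cuda"),
#     )
#     attn_metadata = builder.build(
#         common_prefix_len=0,
#         common_attn_metadata=common_metadata,
#     )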