v1.0
This commit is contained in:
0
v1/__init__.py
Normal file
0
v1/__init__.py
Normal file
BIN
v1/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
v1/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/__pycache__/cudagraph_dispatcher.cpython-312.pyc
Normal file
BIN
v1/__pycache__/cudagraph_dispatcher.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/__pycache__/kv_cache_interface.cpython-312.pyc
Normal file
BIN
v1/__pycache__/kv_cache_interface.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/__pycache__/outputs.cpython-312.pyc
Normal file
BIN
v1/__pycache__/outputs.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/__pycache__/request.cpython-312.pyc
Normal file
BIN
v1/__pycache__/request.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/__pycache__/serial_utils.cpython-312.pyc
Normal file
BIN
v1/__pycache__/serial_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/__pycache__/utils.cpython-312.pyc
Normal file
BIN
v1/__pycache__/utils.cpython-312.pyc
Normal file
Binary file not shown.
0
v1/attention/__init__.py
Normal file
0
v1/attention/__init__.py
Normal file
BIN
v1/attention/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
v1/attention/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
0
v1/attention/backends/__init__.py
Normal file
0
v1/attention/backends/__init__.py
Normal file
BIN
v1/attention/backends/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/__pycache__/cpu_attn.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/cpu_attn.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/__pycache__/flash_attn.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/flash_attn.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/__pycache__/flashinfer.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/flashinfer.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/__pycache__/flex_attention.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/flex_attention.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/__pycache__/gdn_attn.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/gdn_attn.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/__pycache__/linear_attn.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/linear_attn.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/__pycache__/mamba1_attn.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/mamba1_attn.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/__pycache__/mamba2_attn.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/mamba2_attn.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/__pycache__/mamba_attn.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/mamba_attn.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/__pycache__/pallas.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/pallas.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/__pycache__/rocm_aiter_fa.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/rocm_aiter_fa.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
v1/attention/backends/__pycache__/rocm_attn.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/rocm_attn.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
v1/attention/backends/__pycache__/tree_attn.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/tree_attn.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/__pycache__/triton_attn.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/triton_attn.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/__pycache__/utils.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/__pycache__/xformers.cpython-312.pyc
Normal file
BIN
v1/attention/backends/__pycache__/xformers.cpython-312.pyc
Normal file
Binary file not shown.
496
v1/attention/backends/cpu_attn.py
Normal file
496
v1/attention/backends/cpu_attn.py
Normal file
@@ -0,0 +1,496 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
# Module-level logger for the CPU attention backend.
logger = init_logger(__name__)

# CPU architectures that process decode and prefill requests as one mixed
# batch. On any architecture NOT listed here, the metadata builder in this
# file sets a reorder threshold of 1 and routes prefill through SDPA.
_CPU_ARCH_PREFER_MIXED_BATCH = (CpuArchEnum.X86,)
|
||||
|
||||
|
||||
class CPUAttentionBackend(AttentionBackend):
    """Backend descriptor for the CPU attention implementation.

    Exposes the implementation class, the metadata builder class, the KV
    cache layout and the capability queries (dtypes, head sizes, attention
    types) for the "CPU_ATTN" backend.
    """

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]

    @classmethod
    def get_supported_dtypes(cls) -> list[torch.dtype]:
        """Return the dtypes listed in ``supported_dtypes``.

        The result is a new list on every call, so callers may mutate it
        without affecting the class attribute.
        """
        # Derived from the class attribute instead of repeating the literal,
        # so the two declarations cannot drift apart.
        return list(cls.supported_dtypes)

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend accepts."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier string."""
        return "CPU_ATTN"

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Return True for decoder, encoder and encoder-only attention."""
        # Uses the module-level AttentionType, like the rest of this file;
        # the former function-local import of the same name was redundant.
        return attn_type in (
            AttentionType.DECODER,
            AttentionType.ENCODER,
            AttentionType.ENCODER_ONLY,
        )

    @staticmethod
    def get_impl_cls() -> type["CPUAttentionBackendImpl"]:
        """Return the attention implementation class."""
        return CPUAttentionBackendImpl

    @staticmethod
    def get_builder_cls() -> type["CPUAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return CPUAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache tensor shape.

        The layout is ``(2, num_blocks, num_kv_heads, block_size,
        head_size)``; the leading axis of size 2 is split into key and
        value caches by the implementation's ``forward``.
        ``cache_dtype_str`` is accepted but not used here.
        """
        return 2, num_blocks, num_kv_heads, block_size, head_size

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used by this backend."""
        return False
|
||||
|
||||
|
||||
@dataclass
class CPUAttentionMetadata:
    """Per-batch metadata consumed by the CPU attention implementation.

    Instances are created by ``CPUAttentionMetadataBuilder.build``. When
    SDPA prefill is active for a causal batch, the builder truncates
    ``query_start_loc``, ``seq_lens`` and ``block_table`` to the decode
    requests only; the prefill requests are described by ``sdpa_start_loc``.
    """

    # ISA tag chosen by _get_attn_isa ("amx", "vec" or "vec16"); forwarded to
    # the CPU attention ops.
    isa: str
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    # Query start offsets per request, as passed to the KV-cache kernel.
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    # Passed to ops.cpu_attn_reshape_and_cache when writing new key/value
    # entries into the cache.
    slot_mapping: torch.Tensor
    # Output of ops.cpu_attn_get_scheduler_metadata; None for non-causal
    # batches.
    scheduler_metadata: torch.Tensor | None
    causal: bool = True

    # The fields below support the SDPA prefill path and can be removed once
    # SDPA is deprecated.
    use_sdpa_prefill: bool = False
    # Number of leading tokens handled by the KV-cache kernel rather than by
    # SDPA.
    num_decode_tokens: int = 0
    # Per-sequence SDPA masks; filled in lazily by the first
    # _run_sdpa_forward call and then reused.
    sdpa_attn_masks: list[torch.Tensor | None] | None = None
    # Start offsets of the sequences handled by SDPA, relative to the first
    # SDPA token.
    sdpa_start_loc: torch.Tensor | None = None
|
||||
|
||||
|
||||
class CPUAttentionMetadataBuilder(AttentionMetadataBuilder[CPUAttentionMetadata]):
    """Turns the common per-batch metadata into ``CPUAttentionMetadata``."""

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> None:
        """Capture model geometry and decide whether prefill goes through SDPA."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        # Architectures outside _CPU_ARCH_PREFER_MIXED_BATCH get their decode
        # sequences reordered in front of the prefill sequences so the two
        # groups can be split: SDPA handles prefill and
        # cpu_attention_with_kv_cache handles decode.
        prefers_mixed_batch = (
            current_platform.get_cpu_architecture() in _CPU_ARCH_PREFER_MIXED_BATCH
        )
        self.use_sdpa_prefill = not prefers_mixed_batch
        self._init_reorder_batch_threshold(
            None if prefers_mixed_batch else 1, False
        )

        self.kv_cache_spec = kv_cache_spec
        self.vllm_config = vllm_config

        parallel_config = vllm_config.parallel_config
        model_config = vllm_config.model_config
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.num_heads = model_config.get_num_attention_heads(parallel_config)
        self.head_dim = kv_cache_spec.head_size
        self.dtype = model_config.dtype

        # A missing or None sliding window is normalised to -1.
        window = getattr(kv_cache_spec, "sliding_window", -1)
        self.window_size = -1 if window is None else window

        self.block_size = vllm_config.cache_config.block_size
        self.isa = _get_attn_isa(self.dtype, self.block_size)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> CPUAttentionMetadata:
        """Build the metadata for one batch.

        For causal batches with SDPA prefill enabled, the per-request tensors
        handed to the KV-cache kernel are cut down to the decode requests,
        and ``sdpa_start_loc`` is rebased so it starts at the first prefill
        token. ``common_prefix_len`` and ``fast_build`` are not used here.
        """
        causal = common_attn_metadata.causal
        num_reqs = common_attn_metadata.num_reqs
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor

        sdpa_start_loc = query_start_loc
        num_decode_tokens = 0
        if self.use_sdpa_prefill and causal:
            # Decoder batch: split off the decodes and truncate to them.
            assert self.reorder_batch_threshold
            num_decodes, _, num_decode_tokens, _ = split_decodes_and_prefills(
                common_attn_metadata,
                decode_threshold=self.reorder_batch_threshold,
                require_uniform=True,
            )
            num_reqs = num_decodes
            # Rebase before truncating query_start_loc below.
            sdpa_start_loc = query_start_loc[num_decodes:] - num_decode_tokens
            query_start_loc = query_start_loc[: num_decodes + 1]
            seq_lens = seq_lens[:num_decodes]
            block_table_tensor = block_table_tensor[:num_decodes]

        scheduler_metadata = None
        if causal:
            # The custom kernel serves the decode part of the batch.
            scheduler_metadata = ops.cpu_attn_get_scheduler_metadata(
                num_reqs=num_reqs,
                num_heads=self.num_heads,
                num_kv_heads=self.num_kv_heads,
                head_dim=self.head_dim,
                seq_lens=seq_lens,
                dtype=self.dtype,
                query_start_loc=query_start_loc,
                causal=causal,
                sliding_window_size=self.window_size,
                isa=self.isa,
                enable_kv_split=True,
            )

        return CPUAttentionMetadata(
            isa=self.isa,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            max_query_len=common_attn_metadata.max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=common_attn_metadata.max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping,
            scheduler_metadata=scheduler_metadata,
            causal=causal,
            use_sdpa_prefill=self.use_sdpa_prefill,
            num_decode_tokens=num_decode_tokens,
            sdpa_start_loc=sdpa_start_loc,
        )
|
||||
|
||||
|
||||
class CPUAttentionBackendImpl(AttentionImpl):
    """Attention implementation for the CPU backend.

    Encoder and encoder-only attention run PyTorch
    ``scaled_dot_product_attention`` on the un-cached q/k/v. Decoder
    attention writes key/value into the paged KV cache and then uses
    ``ops.cpu_attention_with_kv_cache``; when the metadata requests SDPA
    prefill, the prefill tokens are handled by SDPA and only the leading
    decode tokens go through the KV-cache kernel.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        """Store the layer configuration.

        Raises:
            NotImplementedError: if ``kv_cache_dtype`` is a quantized KV
                cache dtype.
        """
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        # The encoder path goes through SDPA, which is not given the soft
        # cap; warn instead of failing.
        if logits_soft_cap is not None and attn_type in (
            AttentionType.ENCODER,
            AttentionType.ENCODER_ONLY,
        ):
            logger.warning_once(
                "CPU_ATTN does not support logits softcap for"
                " ENCODER and ENCODER_ONLY, outputs may be slightly off"
            )
        # A missing soft cap is stored as 0 before being handed to the kernel.
        if logits_soft_cap is None:
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        # sliding_window is stored as a (left, right) pair; -1 means no limit
        # on that side. Encoder-only attention gets a symmetric window, all
        # other types a left-only window.
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        elif attn_type == AttentionType.ENCODER_ONLY:
            self.sliding_window = (sliding_window - 1, sliding_window - 1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if is_quantized_kv_cache(kv_cache_dtype):
            raise NotImplementedError("FP8 KV cache is unsupported in CPU_ATTN")
        self.attn_type = attn_type

        self.sinks = sinks
        if self.sinks is not None:
            assert self.sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                "heads in the layer"
            )

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: CPUAttentionMetadata | None,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass for CPU attention backend.

        Args:
            layer: The attention layer; not used in this method.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, num_kv_heads, block_size, head_size]
            attn_metadata: Metadata for attention. None during warm-up, in
                which case ``output`` is returned untouched.
            output: Pre-allocated output buffer; required.
            output_scale: Must be None (fused output quantization is
                unsupported).
            output_block_scale: Must be None (fused output quantization is
                unsupported).
        Returns:
            shape = [num_tokens, num_heads * head_size]
            NOTE(review): ``_run_sdpa_forward`` writes
            ``output[start:end, :, :]``, i.e. it indexes the buffer with
            three axes -- confirm the documented shape.
        Raises:
            NotImplementedError: if ``output_scale`` or
                ``output_block_scale`` is given.
        """
        assert output is not None, "Output tensor must be provided."
        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for CPUAttentionBackendImpl"
            )

        # For warming-up
        if attn_metadata is None:
            return output

        num_actual_tokens = attn_metadata.num_actual_tokens

        # Handle encoder attention differently - no KV cache needed
        if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
            # For encoder attention, run SDPA directly on the non-padded
            # query/key/value tokens.
            return self._run_sdpa_forward(
                query[:num_actual_tokens],
                key[:num_actual_tokens],
                value[:num_actual_tokens],
                output[:num_actual_tokens],
                attn_metadata,
                self.attn_type,
            )

        # For decoder and cross-attention, use KV cache, size are
        # [num_blocks, num_kv_heads, block_size, head_size]
        key_cache, value_cache = kv_cache.unbind(0)

        # key and value may be None in the case of cross attention. They are
        # calculated once based on the output from the encoder and then cached
        # in KV cache.
        # Layers that share another layer's KV cache skip the write as well.
        if (
            self.kv_sharing_target_layer_name is None
            and key is not None
            and value is not None
        ):
            ops.cpu_attn_reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                attn_metadata.isa,
            )

        if attn_metadata.use_sdpa_prefill:
            assert self.sinks is None, "Attention sink is unsupported in SDPA prefill"
            num_decode_tokens = attn_metadata.num_decode_tokens
            # Tokens after the decode prefix are prefill tokens: run them
            # through SDPA, writing into the matching slice of output.
            self._run_sdpa_forward(
                query[num_decode_tokens:num_actual_tokens],
                key[num_decode_tokens:num_actual_tokens],
                value[num_decode_tokens:num_actual_tokens],
                output[num_decode_tokens:num_actual_tokens],
                attn_metadata,
                self.attn_type,
            )
            # Only the decode prefix is left for the KV-cache kernel below.
            num_actual_tokens = num_decode_tokens

        if num_actual_tokens > 0:
            ops.cpu_attention_with_kv_cache(
                query=query[:num_actual_tokens],
                key_cache=key_cache,
                value_cache=value_cache,
                output=output[:num_actual_tokens],  # type: ignore
                query_start_loc=attn_metadata.query_start_loc,
                seq_lens=attn_metadata.seq_lens,
                scale=self.scale,
                causal=attn_metadata.causal,
                alibi_slopes=self.alibi_slopes,  # type: ignore
                sliding_window=self.sliding_window,
                block_table=attn_metadata.block_table,
                softcap=self.logits_soft_cap,
                scheduler_metadata=attn_metadata.scheduler_metadata,
                s_aux=self.sinks,
            )

        return output

    def _run_sdpa_forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output: torch.Tensor,
        attn_metadata: CPUAttentionMetadata,
        attn_type: str,
    ) -> torch.Tensor:
        """Run SDPA one sequence at a time and write results into ``output``.

        Sequence boundaries come from ``attn_metadata.sdpa_start_loc``. The
        per-sequence masks are built on first use and stored back on
        ``attn_metadata.sdpa_attn_masks`` so later calls with the same
        metadata object reuse them.

        Returns:
            ``output``, after the per-sequence results have been written.
        """
        attn_masks = attn_metadata.sdpa_attn_masks
        if attn_masks is None:
            # Mask priority: ALiBi bias first, then sliding window, else no
            # explicit mask (one None entry per sequence).
            if self.alibi_slopes is not None:
                attn_masks = _make_alibi_bias(
                    self.alibi_slopes,
                    query.dtype,
                    attn_metadata.sdpa_start_loc,
                )
            elif self.sliding_window[0] != -1 or self.sliding_window[1] != -1:
                assert attn_metadata.seq_lens is not None
                attn_masks = _make_sliding_window_bias(
                    attn_metadata.sdpa_start_loc,
                    self.sliding_window[0],
                    self.sliding_window[1],
                    query.dtype,
                )
            else:
                attn_masks = [None] * (attn_metadata.sdpa_start_loc.size(0) - 1)  # type: ignore
            attn_metadata.sdpa_attn_masks = attn_masks

        # Move the token axis (dim 0) to the second-to-last position so the
        # head axis comes before the token axis, as SDPA is called below.
        query = query.movedim(0, query.dim() - 2)
        key = key.movedim(0, key.dim() - 2)
        value = value.movedim(0, value.dim() - 2)

        # Repeat each KV head so key/value have as many heads as query.
        if self.num_kv_heads != self.num_heads:
            key = key.repeat_interleave(self.num_queries_per_kv, dim=-3)
            value = value.repeat_interleave(self.num_queries_per_kv, dim=-3)

        causal_attn = attn_type == AttentionType.DECODER

        sdpa_start_loc = attn_metadata.sdpa_start_loc.numpy()  # type: ignore
        for i in range(len(attn_masks)):
            mask = attn_masks[i]
            start_q = sdpa_start_loc[i]
            end_q = sdpa_start_loc[i + 1]
            sub_out = (
                torch.nn.functional.scaled_dot_product_attention(
                    query[None, :, start_q:end_q, :],
                    key[None, :, start_q:end_q, :],
                    value[None, :, start_q:end_q, :],
                    attn_mask=mask,
                    dropout_p=0.0,
                    # is_causal is only set when there is no explicit mask.
                    is_causal=causal_attn and mask is None,
                    scale=self.scale,
                )
                .squeeze(0)
                # Move the token axis back to the front to match output.
                .movedim(query.dim() - 2, 0)
            )
            output[start_q:end_q, :, :] = sub_out
        return output
|
||||
|
||||
|
||||
def _make_alibi_bias(
|
||||
alibi_slopes: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
sdpa_start_loc: torch.Tensor,
|
||||
) -> list[torch.Tensor]:
|
||||
attn_biases: list[torch.Tensor] = []
|
||||
seq_num = sdpa_start_loc.size(0) - 1
|
||||
sdpa_start_loc = sdpa_start_loc.numpy() # type: ignore
|
||||
for i in range(seq_num):
|
||||
seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
|
||||
bias = torch.arange(seq_len, dtype=dtype) # type: ignore
|
||||
# NOTE(zhuohan): HF uses
|
||||
# `bias = bias[None, :].repeat(seq_len, 1)`
|
||||
# here. We find that both biases give the same results, but
|
||||
# the bias below more accurately follows the original ALiBi
|
||||
# paper.
|
||||
bias = bias[None, :] - bias[:, None]
|
||||
|
||||
num_heads = alibi_slopes.shape[0]
|
||||
bias = bias[None, :].repeat((num_heads, 1, 1))
|
||||
bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
|
||||
inf_mask = (
|
||||
torch.empty((1, seq_len, seq_len), dtype=bias.dtype) # type: ignore
|
||||
.fill_(-torch.inf)
|
||||
.triu_(diagonal=1)
|
||||
)
|
||||
attn_biases.append((bias + inf_mask).to(dtype))
|
||||
|
||||
return attn_biases
|
||||
|
||||
|
||||
def _make_sliding_window_bias(
|
||||
sdpa_start_loc: torch.Tensor,
|
||||
left_window_size: int,
|
||||
right_window_size: int,
|
||||
dtype: torch.dtype,
|
||||
) -> list[torch.Tensor]:
|
||||
attn_biases: list[torch.Tensor] = []
|
||||
seq_num = sdpa_start_loc.size(0) - 1
|
||||
sdpa_start_loc = sdpa_start_loc.numpy() # type: ignore
|
||||
for i in range(seq_num):
|
||||
seq_len = sdpa_start_loc[i + 1] - sdpa_start_loc[i]
|
||||
mask = torch.full( # type: ignore
|
||||
(1, seq_len, seq_len), # type: ignore
|
||||
fill_value=1,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
if right_window_size != -1:
|
||||
mask = torch.tril(mask, diagonal=right_window_size)
|
||||
if left_window_size != -1:
|
||||
mask = torch.triu(mask, diagonal=-left_window_size)
|
||||
mask = torch.log(mask)
|
||||
attn_biases.append(mask)
|
||||
|
||||
return attn_biases
|
||||
|
||||
|
||||
def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str:
|
||||
supports_amx = torch._C._cpu._is_amx_tile_supported()
|
||||
if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
|
||||
return "amx"
|
||||
elif block_size % 32 == 0:
|
||||
return "vec"
|
||||
else:
|
||||
return "vec16"
|
||||
1215
v1/attention/backends/flash_attn.py
Normal file
1215
v1/attention/backends/flash_attn.py
Normal file
File diff suppressed because it is too large
Load Diff
1572
v1/attention/backends/flashinfer.py
Normal file
1572
v1/attention/backends/flashinfer.py
Normal file
File diff suppressed because it is too large
Load Diff
926
v1/attention/backends/flex_attention.py
Normal file
926
v1/attention/backends/flex_attention.py
Normal file
@@ -0,0 +1,926 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with FlexAttention."""
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
import torch._dynamo.decorators
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.attention.flex_attention import (
|
||||
BlockMask,
|
||||
_mask_mod_signature,
|
||||
_score_mod_signature,
|
||||
and_masks,
|
||||
create_block_mask,
|
||||
flex_attention,
|
||||
)
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
# Module-level logger for the FlexAttention backend.
logger = init_logger(__name__)

# torch.compile wrappers around the FlexAttention entry points, created once
# when this module is imported. Both request fullgraph=True; the block-mask
# builder additionally uses the "reduce-overhead" compile mode.
create_block_mask_compiled = torch.compile(
    create_block_mask, fullgraph=True, mode="reduce-overhead"
)
flex_attention_compiled = torch.compile(flex_attention, fullgraph=True)
|
||||
|
||||
|
||||
def _offsets_to_doc_ids_tensor(offsets: torch.Tensor) -> torch.Tensor:
|
||||
device = offsets.device
|
||||
counts = offsets[1:] - offsets[:-1]
|
||||
return torch.repeat_interleave(
|
||||
torch.arange(len(counts), device=device, dtype=torch.int32), counts
|
||||
)
|
||||
|
||||
|
||||
def pad_to_multiple(x: torch.Tensor, multiple: int, dim: int):
    """Zero-pad ``x`` at the end of ``dim`` up to a multiple of ``multiple``.

    ``x`` itself is returned when its size along ``dim`` is already a
    multiple. ``dim`` may be negative.
    """
    remainder = x.shape[dim] % multiple
    if remainder == 0:
        return x
    extra = multiple - remainder

    axis = dim if dim >= 0 else x.ndim + dim
    # F.pad takes (before, after) pairs starting from the LAST dimension, so
    # emit untouched pairs for every dimension after `axis`, then the pair
    # that grows `axis` itself.
    trailing_dims = x.ndim - 1 - axis
    pad_spec = [0, 0] * trailing_dims + [0, extra]

    return F.pad(x, pad_spec, mode="constant", value=0)
|
||||
|
||||
|
||||
class FlexAttentionBackend(AttentionBackend):
    """Backend descriptor for the FlexAttention implementation.

    Names the implementation and metadata-builder classes, the KV cache
    layout and the capability queries for the "FLEX_ATTENTION" backend.
    """

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier string."""
        return "FLEX_ATTENTION"

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        """Return True for decoder and encoder-only attention."""
        from vllm.attention import AttentionType

        accepted = (AttentionType.DECODER, AttentionType.ENCODER_ONLY)
        return attn_type in accepted

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return an empty list.

        NOTE(review): presumably an empty list means "no head-size
        restriction" -- confirm against the base-class contract.
        """
        return []

    @staticmethod
    def get_impl_cls() -> type["FlexAttentionImpl"]:
        """Return the attention implementation class."""
        return FlexAttentionImpl

    @staticmethod
    def get_builder_cls() -> type["FlexAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return FlexAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        ``cache_dtype_str`` is accepted but not used here.
        """
        return 2, num_blocks, block_size, num_kv_heads, head_size

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used by this backend."""
        return False
|
||||
|
||||
|
||||
# @torch.compile(fullgraph=True, mode="reduce-overhead")
|
||||
def physical_to_logical_mapping(
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    block_size: int,
    total_blocks: int,
) -> torch.Tensor:
    """
    Invert a block table: map physical block ids back to logical indices.

    ``block_table[req, logical]`` holds the physical block that stores logical
    block ``logical`` of request ``req``. The result answers the opposite
    question. For a request whose table row is

        logical :  0  1  2  3  4  5  6  7
        physical:  3  5  1  7  4  2  0  6

    the returned row is

        physical:  0  1  2  3  4  5  6  7
        logical : -1  2  5  0  4  1  7  3

    (physical block 0 would map to logical 6, but column 0 is always reset to
    -1, see below). Physical blocks that no logical block refers to hold -1.

    Garbage protection: slots of ``block_table`` beyond the blocks a sequence
    actually needs may contain arbitrary values, which could coincide with
    valid physical block ids and would then be assigned logical indices.
    Only the first ``cdiv(seq_len, block_size)`` entries of each row are
    therefore trusted; all other entries are redirected to physical block 0
    with logical index 0 before scattering.

    NOTE(review): the previous documentation stated that when several logical
    blocks share a physical block the minimum logical index is kept; a plain
    ``scatter_`` does not obviously guarantee which write wins -- confirm
    before relying on it.

    Args:
        block_table: Tensor of shape [max_reqs, max_num_blocks] mapping
            logical blocks to physical locations. May contain garbage values
            in unused positions.
        seq_lens: Tensor of sequence lengths for each request.
        block_size: Size of each block in tokens.
        total_blocks: Total number of physical blocks available.

    Returns:
        A long tensor of shape [max_reqs, total_blocks] where entry
        [req_id, physical_block] is the logical block index for that physical
        block, or -1 if unused.
    """
    num_rows, num_slots = block_table.shape
    device = block_table.device
    logical_ids = torch.arange(num_slots, device=device)[None, :]

    # Trust only the slots a sequence really occupies.
    blocks_in_use = cdiv(seq_lens, block_size)
    slot_is_valid = logical_ids < blocks_in_use[:, None]

    scatter_index = torch.where(slot_is_valid, block_table, 0).to(torch.int64)
    scatter_value = torch.where(slot_is_valid, logical_ids, 0)

    inverse = torch.full(
        (num_rows, total_blocks), -1, dtype=torch.long, device=device
    )
    inverse.scatter_(-1, scatter_index, scatter_value)
    # NB - Seems like block 0 is always empty so we reset it manually
    inverse[:, 0] = -1
    return inverse
|
||||
|
||||
|
||||
def unique_static_unsorted(
    x: torch.Tensor,
    *,
    M: int,  # maximum positive value (0 is "skip me")
    dim: int = -1,  # axis along which to deduplicate
    ignored_val: int = 0,  # value to ignore
    pad_val: int = -1,  # sentinel for unused slots
) -> torch.Tensor:
    """
    Order-preserving, fixed-shape deduplication along ``dim``.

    - Keeps the first occurrence of every value other than ``ignored_val``,
      left-packs the survivors and fills the remaining slots with
      ``pad_val``.
    - Returns the packed tensor, which has the *same shape* as ``x``.
    - Requires that all values be in the range [0, M].

    No Python loops are used; the lookup table has shape [B, M + 1], where B
    is the number of rows after moving ``dim`` to the end.

    Example:
        x =[3, 1, 0, 1, 2], M=3, ignored_val=0 => [3, 1, 2, -1, -1]

    Raises:
        ValueError: if ``pad_val`` lies outside [-1, M].
    """
    if pad_val < -1 or pad_val > M:
        raise ValueError("`pad_val` must lie in [-1, M]")

    # View the input as a 2-D [B, N] problem with `dim` as the last axis.
    dim = dim % x.ndim
    moved = x.movedim(dim, -1)
    n_cols = moved.shape[-1]
    n_rows = moved.numel() // n_cols
    rows_2d = moved.reshape(n_rows, n_cols)

    device = x.device
    col_index = torch.arange(n_cols, device=device).expand(n_rows, n_cols)

    # earliest[b, v] = smallest column at which value v appears in row b;
    # n_cols acts as "never seen".
    earliest = torch.full((n_rows, M + 1), n_cols, device=device)
    earliest.scatter_reduce_(1, rows_2d, col_index, reduce="amin")

    # Keep an element iff it is a first occurrence and not the ignored value.
    is_first = col_index == earliest.gather(1, rows_2d)
    keep = (rows_2d != ignored_val) & is_first

    # The running count of kept elements gives each survivor its packed slot.
    target_col = torch.cumsum(keep.to(torch.long), dim=1) - 1
    packed_2d = torch.full_like(rows_2d, pad_val)
    kept_rows = torch.nonzero(keep, as_tuple=True)[0]
    packed_2d[kept_rows, target_col[keep]] = rows_2d[keep]

    # Undo the reshape / axis move.
    return packed_2d.reshape(moved.shape).movedim(-1, dim)
|
||||
|
||||
|
||||
def causal_mask_mod(
    b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
):
    """Causal mask predicate on logical indices.

    A (query, key) pair is kept when the key index does not exceed the query
    index. ``b`` and ``h`` are accepted to match the mask_mod call signature
    but are not used.
    """
    return kv_idx <= q_idx
|
||||
|
||||
|
||||
@dataclass
class FlexAttentionMetadata:
    """Per-batch metadata for the FlexAttention backend.

    Besides storing the fields below, constructing an instance runs
    ``__post_init__``, which derives ``doc_ids``, ``num_blocks``,
    ``mask_mod``, ``transformed_score_mod`` and ``block_mask``.
    """

    causal: bool
    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # Cascade-attention fields; __post_init__ asserts they are all unset.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: torch.Tensor | None
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None

    # Block info
    total_cache_tokens: int
    block_size: int
    max_possible_sequence_length: int
    num_reqs: int
    physical_to_logical: torch.Tensor
    decode_offset: torch.Tensor
    num_blocks_per_seq: torch.Tensor

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.

    # Flex Metadata
    # `num_blocks` carries no annotation, so it is a plain class attribute
    # rather than a dataclass field; __post_init__ sets the instance value.
    num_blocks = 0
    block_mask: BlockMask | None = None
    score_mod: _score_mod_signature | None = None
    logical_mask_mod: _mask_mod_signature = causal_mask_mod
    doc_ids: torch.Tensor | None = None
    direct_build: bool = True
    q_block_size: int = 16
    kv_block_size: int = 16
    transformed_score_mod: _score_mod_signature | None = None
    sliding_window: int | None = None

    def _convert_physical_to_logical(
        self,
        request_lookup: torch.Tensor,
        q_idx: torch.Tensor,
        physical_kv_idx: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Convert physical indices to logical indices for both query and kv.

        Args:
            request_lookup: Lookup from packed query index to request index.
            q_idx: Packed (physical) query indices.
            physical_kv_idx: Indices into the flattened paged KV cache.

        Returns:
            tuple of (is_valid, logical_q_idx, logical_kv_idx). ``is_valid``
            is true where the looked-up logical block is non-negative and the
            logical kv index lies in ``[0, seq_lens[request])``.
        """
        # Map query indices to corresponding request indices
        q_req = request_lookup[q_idx]

        # Convert physical KV indices to logical indices
        physical_kv_block = physical_kv_idx // self.block_size
        physical_kv_offset = physical_kv_idx % self.block_size
        logical_block_idx = self.physical_to_logical[q_req, physical_kv_block]
        logical_kv_idx = logical_block_idx * self.block_size + physical_kv_offset

        # Determine valid kv indices
        live_block = logical_block_idx >= 0
        within_upper_bound = logical_kv_idx < self.seq_lens[q_req]
        within_lower_bound = logical_kv_idx >= 0
        is_valid = live_block & within_upper_bound & within_lower_bound

        # Convert physical query indices to logical indices
        local_q_idx = q_idx - self.query_start_loc[q_req]
        logical_q_idx = local_q_idx + self.decode_offset[q_req]

        return is_valid, logical_q_idx, logical_kv_idx

    def get_causal_mask_mod(self) -> _mask_mod_signature:
        """Creates the mask_mod function for FlexAttention.

        This function creates the combined mask mod function that handles:
            1. The paged attention block mapping
            2. The mapping from packed query sequences to logical query entries

        It also by default adds the decoding offset to the query indices.
        With this info we create the "logical" indices that are passed to
        mask_mod functions. This allows mask mod functions to be agnostic to
        layout of the query and key/value tensors.
        """
        assert self.doc_ids is not None

        def final_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            (is_valid, logical_q_idx, logical_kv_idx) = (
                self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx)
            )
            # Apply mask modification only for valid indices
            return torch.where(
                is_valid,
                self.logical_mask_mod(b, h, logical_q_idx, logical_kv_idx),
                False,
            )

        return final_mask_mod

    def get_bidirectional_mask_mod(self) -> _mask_mod_signature:
        """Creates the encoder mask_mod function for FlexAttention.

        Since the encoder bidirectional attention doesn't run with
        KV cache, this function creates a mask based on the
        packed query sequences: a pair is kept when both indices map to
        the same request.
        """
        # Create a lookup mapping from query indices -> request number
        request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)

        def final_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            return request_lookup[q_idx] == request_lookup[kv_idx]

        return final_mask_mod

    def get_sliding_window_mask_mod(self) -> _mask_mod_signature:
        """Creates the sliding window mask_mod function for FlexAttention.

        Note that the sliding window mask here is bidirectional, we need
        to mask it with the bidirectional/causal mask for encoder/decoder.

        Raises:
            ValueError: If ``sliding_window`` is not set.
        """

        if self.sliding_window is None:
            raise ValueError("sliding_window must be set for sliding window attention")

        def sliding_window_mask_mod(
            b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
        ):
            return torch.abs(q_idx - kv_idx) < self.sliding_window

        def final_mask_mod(
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            (is_valid, logical_q_idx, logical_kv_idx) = (
                self._convert_physical_to_logical(self.doc_ids, q_idx, physical_kv_idx)
            )
            return torch.where(
                is_valid,
                sliding_window_mask_mod(b, h, logical_q_idx, logical_kv_idx),
                False,
            )

        # The causal path receives physical kv indices and must translate
        # them first; the non-causal path applies the window directly.
        return final_mask_mod if self.causal else sliding_window_mask_mod

    def get_mask_mod(self):
        """Return the base mask_mod, AND-ed with the sliding window if set."""
        # Stage-1: initialize the base mask_mod
        # (causal mask for decoder or bidirectional mask for encoder)
        if self.causal:
            mask_mod = self.get_causal_mask_mod()
        else:
            mask_mod = self.get_bidirectional_mask_mod()
        # stage-2: add external mask_mod for special attention during
        # forwarding runtime to create the combined mask_mod.
        if self.sliding_window is not None:
            # Add sliding window mask for sliding window attention
            sliding_window_mask_mod = self.get_sliding_window_mask_mod()
            mask_mod = and_masks(mask_mod, sliding_window_mask_mod)
        return mask_mod

    def get_transformed_score_mod(self) -> _score_mod_signature | None:
        """Creates the transformed score_mod function for FlexAttention.

        This function wraps the user's score_mod to handle physical-to-logical
        index conversion, similar to how get_mask_mod works for mask functions.
        Returns None when no ``score_mod`` is configured. Positions that are
        not valid get a score of ``-inf``.
        """
        if self.score_mod is None:
            return None

        # Create a lookup mapping from query indices -> request number
        request_lookup = _offsets_to_doc_ids_tensor(self.query_start_loc)
        user_score_mod = self.score_mod

        def transformed_score_mod(
            score: torch.Tensor,
            b: torch.Tensor,
            h: torch.Tensor,
            q_idx: torch.Tensor,
            physical_kv_idx: torch.Tensor,
        ) -> torch.Tensor:
            (is_valid, logical_q_idx, logical_kv_idx) = (
                self._convert_physical_to_logical(
                    request_lookup, q_idx, physical_kv_idx
                )
            )

            return torch.where(
                is_valid,
                user_score_mod(
                    score, b, h, logical_q_idx, logical_kv_idx, physical_q=q_idx
                ),
                -float("inf"),
            )

        return transformed_score_mod

    def _build_block_mask_direct(self) -> BlockMask:
        """Direct block mask construction for standard causal attention.

        This method constructs the block mask directly using
        BlockMask.from_kv_blocks which is much more efficient than the
        generic create_block_mask approach.

        The direct path works as follows:
        1. For each query token, fetch blocks from block_table using max_seq_len
           (this fetches more blocks than needed for shorter sequences)
        2. Group query tokens into chunks of q_block_size
        3. For each group, deduplicate the blocks using unique_static_unsorted
        4. Create BlockMask using the deduplicated block indices

        Over-estimation occurs when a group of q_block_size tokens contains
        multiple sequence IDs (doc_ids). In this case, we fetch ALL blocks for
        each sequence represented in the group, even though individual query
        tokens may only need a subset of those blocks based on causal masking
        and their position.

        Raises:
            ValueError: If ``kv_block_size`` differs from ``block_size``.
        """
        # Compare the sizes directly. The previous floor-division check
        # (`kv_block_size // block_size != 1`) let unequal sizes through,
        # e.g. kv_block_size=16 with block_size=12.
        if self.kv_block_size != self.block_size:
            raise ValueError(
                f"FlexAttention currently requires the cache block size "
                f"({self.block_size}) to be equal to the kv_block_size "
                f"({self.kv_block_size}). Please check your model's "
                f"configuration."
            )

        used_pages = self.block_table[
            self.doc_ids, : cdiv(self.max_seq_len, self.block_size)
        ]
        used_pages_padded = pad_to_multiple(
            used_pages, multiple=self.q_block_size, dim=0
        )
        # One row per group of q_block_size query tokens. Because the page
        # and kv block sizes are equal, page ids are already kv block ids.
        used_pages_padded = used_pages_padded.reshape(
            used_pages_padded.shape[0] // self.q_block_size, -1
        )
        kv_indices = unique_static_unsorted(
            used_pages_padded.long(), M=self.num_blocks
        ).to(torch.int32)

        # Non-negative entries are counted as real block ids.
        kv_num_blocks = (kv_indices >= 0).sum(dim=-1).to(torch.int32)
        block_mask_kwargs = {
            "seq_lengths": (self.num_actual_tokens, self.total_cache_tokens),
            "kv_num_blocks": kv_num_blocks[None, None],
            "kv_indices": kv_indices[None, None],
            "full_kv_num_blocks": None,
            "full_kv_indices": None,
            "BLOCK_SIZE": (self.q_block_size, self.kv_block_size),
            "mask_mod": self.mask_mod,
        }

        # compute_q_blocks parameter is available in PyTorch 2.9+
        if is_torch_equal_or_newer("2.9.0.dev0"):
            block_mask_kwargs["compute_q_blocks"] = False
        return BlockMask.from_kv_blocks(**block_mask_kwargs)

    def build_block_mask(self) -> BlockMask:
        """Build the block mask through the generic compiled builder.

        The kv length is the whole cache for causal attention and the number
        of actual tokens otherwise.
        """
        mask_mod = self.get_mask_mod()
        kv_len = self.total_cache_tokens if self.causal else self.num_actual_tokens
        return create_block_mask_compiled(
            mask_mod,
            None,
            None,
            self.num_actual_tokens,
            kv_len,
            device=self.block_table.device,
            BLOCK_SIZE=(self.q_block_size, self.kv_block_size),
        )

    def __post_init__(self):
        """Check unsupported options and derive masks and score mods."""
        assert self.use_cascade is False, "Not implemented yet."
        assert self.common_prefix_len == 0, "Not implemented yet."
        assert self.cu_prefix_query_lens is None, "Not implemented yet."
        assert self.prefix_kv_lens is None, "Not implemented yet."
        assert self.suffix_kv_lens is None, "Not implemented yet."
        # Create a lookup mapping from query indices -> request number
        self.doc_ids = _offsets_to_doc_ids_tensor(self.query_start_loc)
        self.num_blocks = self.total_cache_tokens // self.block_size

        self.mask_mod = self.get_mask_mod()
        self.transformed_score_mod = self.get_transformed_score_mod()

        if self.direct_build and self.causal:
            self.block_mask = self._build_block_mask_direct()
        else:
            self.block_mask = self.build_block_mask()
|
||||
|
||||
|
||||
class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadata]):
    """Assembles FlexAttentionMetadata from the common attention metadata."""

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache config handles, head geometry and block-size choices."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        model_cfg = vllm_config.model_config
        parallel_cfg = vllm_config.parallel_config
        self.model_config = model_cfg
        self.parallel_config = parallel_cfg
        self.cache_config = vllm_config.cache_config

        self.num_heads_q = model_cfg.get_num_attention_heads(parallel_cfg)
        self.num_heads_kv = model_cfg.get_num_kv_heads(parallel_cfg)
        self.headdim = model_cfg.get_head_size()
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec

        # Direct building and the small block sizes are both gated on the
        # same torch version check; otherwise fall back to 128 x 128.
        supports_small_blocks = is_torch_equal_or_newer("2.9.0.dev0")
        self.direct_build: bool = supports_small_blocks
        q_block, kv_block = (
            (16, self.block_size) if supports_small_blocks else (128, 128)
        )
        self.q_block_size: int = q_block
        self.kv_block_size: int = kv_block

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlexAttentionMetadata:
        """Create the FlexAttentionMetadata for one batch.

        Raises:
            NotImplementedError: If ``common_prefix_len`` is positive
                (cascade attention).
        """
        common = common_attn_metadata
        seq_lens = common.seq_lens
        block_table_tensor = common.block_table_tensor
        num_blocks_per_seq = cdiv(seq_lens, self.block_size)

        use_cascade = common_prefix_len > 0
        if use_cascade:
            raise NotImplementedError("Not yet my friend")

        block_size = self.kv_cache_spec.block_size
        num_gpu_blocks = self.cache_config.num_gpu_blocks
        assert num_gpu_blocks is not None, (
            "FlexAttention requires num_gpu_blocks to be set"
        )

        inverse_block_table = physical_to_logical_mapping(
            block_table_tensor, seq_lens, block_size, num_gpu_blocks
        )
        offset_tensor = common.num_computed_tokens_cpu.to(
            self.device, non_blocking=True
        )

        return FlexAttentionMetadata(
            causal=common.causal,
            num_actual_tokens=common.num_actual_tokens,
            max_query_len=common.max_query_len,
            query_start_loc=common.query_start_loc,
            max_seq_len=common.max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=common.slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=None,
            prefix_kv_lens=None,
            suffix_kv_lens=None,
            block_size=block_size,
            max_possible_sequence_length=self.model_config.max_model_len,
            num_reqs=common.num_reqs,
            physical_to_logical=inverse_block_table,
            total_cache_tokens=num_gpu_blocks * block_size,
            decode_offset=offset_tensor,
            num_blocks_per_seq=num_blocks_per_seq,
            # FIXME(Isotr0py): direct build has issue to build bidirectional
            # attention block mask for encoder-only models, disable it temporarily.
            # see: https://github.com/vllm-project/vllm/pull/27329#issuecomment-3431484053
            direct_build=(self.direct_build and common.causal),
            q_block_size=self.q_block_size,
            kv_block_size=self.kv_block_size,
        )

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Always False: cascade attention is not used by this builder."""
        return False
|
||||
|
||||
|
||||
class FlexAttentionImpl(AttentionImpl):
    """Attention implementation that dispatches to compiled FlexAttention."""

    sliding_window: int | None
    alibi_slopes: torch.Tensor | None
    logits_soft_cap: float | None

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        **kwargs,
    ) -> None:
        """Store the layer geometry and reject unsupported features.

        Raises:
            NotImplementedError: For attention types other than
                ENCODER_ONLY / DECODER, alibi slopes, logits soft cap,
                kv sharing, or a quantized kv-cache dtype.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.attn_type = attn_type

        supported_types = (AttentionType.ENCODER_ONLY, AttentionType.DECODER)
        if attn_type not in supported_types:
            raise NotImplementedError(
                f"FlexAttention does not support {attn_type} attention"
            )

        if alibi_slopes is not None:
            raise NotImplementedError(
                "FlexAttention does not support alibi slopes yet."
            )
        self.alibi_slopes = None

        self.sliding_window = sliding_window

        self.kv_cache_dtype = kv_cache_dtype
        self.logits_soft_cap = logits_soft_cap
        if self.logits_soft_cap is not None:
            raise NotImplementedError(
                "FlexAttention does not support logits soft cap yet."
            )

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError("FlexAttention does not support kv sharing yet.")

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlexAttention does not support quantized kv-cache. Yet"
            )

    @staticmethod
    def view_as_4d(tensor: torch.Tensor) -> torch.Tensor:
        """Return 4D input unchanged; give 3D input a leading unit dim."""
        if tensor.ndim != 4:
            assert tensor.ndim == 3
            tensor = tensor[None, :, :, :]
        return tensor

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlexAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with FlexAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."
        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for FlexAttentionImpl"
            )

        enable_gqa = self.num_kv_heads != self.num_heads

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        num_actual_tokens = attn_metadata.num_actual_tokens

        # The metadata does not know this layer's window when it is built;
        # if they disagree, adopt the layer's value and rebuild the mask.
        if attn_metadata.sliding_window != self.sliding_window:
            attn_metadata.sliding_window = self.sliding_window
            if not attn_metadata.direct_build:
                attn_metadata.block_mask = attn_metadata.build_block_mask()
            else:
                # TODO: Support skipping the computation of sliding window
                # in direct block mask building code path.
                logger.warning_once(
                    "Using direct block mask building with sliding window, "
                    "which is suboptimal now. Performance may be degraded."
                )
                # update mask mod in attention metadata
                attn_metadata.mask_mod = attn_metadata.get_mask_mod()
                attn_metadata.block_mask = attn_metadata._build_block_mask_direct()

        def heads_first(x: torch.Tensor) -> torch.Tensor:
            # 4D view, then swap the token and head dims.
            return self.view_as_4d(x).permute(0, 2, 1, 3)

        if attn_metadata.causal:
            assert self.attn_type == AttentionType.DECODER
            key_cache, value_cache = kv_cache.unbind(0)

            torch.ops._C_cache_ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

            # View out the block_size dim
            key_cache = key_cache.view(-1, self.num_kv_heads, self.head_size)
            value_cache = value_cache.view(-1, self.num_kv_heads, self.head_size)
            query, key_tensor, value_tensor = (
                heads_first(t) for t in (query, key_cache, value_cache)
            )

            query = query[:, :, :num_actual_tokens, :]
        else:
            assert self.attn_type == AttentionType.ENCODER_ONLY

            query, key_tensor, value_tensor = (
                heads_first(t) for t in (query, key, value)
            )

            query = query[:, :, :num_actual_tokens, :]
            kv_is_padded = (key_tensor.size(-2) > num_actual_tokens) or (
                value_tensor.size(-2) > num_actual_tokens
            )
            if kv_is_padded:
                # In the encoder-only model with torch.compile,
                # qkv might be padded, which might cause exception.
                # see: https://github.com/vllm-project/vllm/pull/24872#discussion_r2353252290
                key_tensor = key_tensor[:, :, :num_actual_tokens, :]
                value_tensor = value_tensor[:, :, :num_actual_tokens, :]

        # NOTE: marking the query's token dim as dynamic via torch._dynamo
        # doesn't work for now -> constraint violation.

        assert attn_metadata.block_mask is not None
        block_m, block_n = attn_metadata.block_mask.BLOCK_SIZE

        kernel_options = get_kernel_options(
            query, block_m, block_n, attn_metadata.direct_build
        )
        out = flex_attention_compiled(
            query,
            key_tensor,
            value_tensor,
            attn_metadata.transformed_score_mod,
            attn_metadata.block_mask,
            self.scale,
            enable_gqa=enable_gqa,
            kernel_options=kernel_options,
        )

        # Flex doesn't have an out variant today, rely on epilogue fusion
        out = out.permute(0, 2, 1, 3).squeeze(0)
        output[:num_actual_tokens, :, :].copy_(out)
        return output
|
||||
|
||||
|
||||
def get_kernel_options(
    query, block_m, block_n, use_direct_build: bool
) -> dict[str, int | bool]:
    """Choose the kernel options passed to the FlexAttention kernel.

    Args:
        query: Query tensor; only its dtype is consulted.
        block_m: Query-side block size of the block mask.
        block_n: KV-side block size of the block mask.
        use_direct_build: Whether the block mask was built directly; if so
            the mask's block sizes are used as the kernel block sizes.

    Returns:
        A dict that always has ``FORCE_USE_FLEX_ATTENTION`` plus
        ``BLOCK_M`` / ``BLOCK_N`` (and ``IS_DIVISIBLE`` in batch-invariant
        mode).
    """
    kernel_options: dict[str, int | bool] = {
        "FORCE_USE_FLEX_ATTENTION": True,
    }

    def ensure_divisible(candidate: int, block_size: int) -> int:
        """Pick a kernel block size that divides the logical block."""
        if block_size <= 0:
            return candidate
        capped = min(candidate, block_size)
        if capped <= 0:
            return block_size
        if block_size % capped == 0:
            return capped

        # Fall back to the common divisor; a trivial one means "use the
        # whole logical block".
        shared = math.gcd(capped, block_size)
        return shared if shared > 1 else block_size

    if vllm_is_batch_invariant():
        kernel_options.update(BLOCK_M=16, BLOCK_N=16, IS_DIVISIBLE=False)
        return kernel_options

    if use_direct_build:
        kernel_options.update(BLOCK_M=block_m, BLOCK_N=block_n)
        return kernel_options

    preferred_block = 32 if query.dtype == torch.float32 else 64
    block_lower_bound = 16

    m_choice = ensure_divisible(preferred_block, block_m)
    n_choice = ensure_divisible(preferred_block, block_n)

    if torch.cuda.is_available():
        max_shared_memory = (
            torch.cuda.get_device_properties().shared_memory_per_block_optin
        )
        # Halve the candidates on devices below the 144 KiB threshold.
        if max_shared_memory < 144 * 1024:
            m_choice = ensure_divisible(max(1, m_choice // 2), block_m)
            n_choice = ensure_divisible(max(1, n_choice // 2), block_n)

    kernel_options["BLOCK_M"] = max(m_choice, block_lower_bound)
    kernel_options["BLOCK_N"] = max(n_choice, block_lower_bound)
    return kernel_options
|
||||
387
v1/attention/backends/gdn_attn.py
Normal file
387
v1/attention/backends/gdn_attn.py
Normal file
@@ -0,0 +1,387 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Backend for GatedDeltaNet attention."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
class GDNAttentionBackend(AttentionBackend):
    """Attention backend entry point for GatedDeltaNet attention."""

    @staticmethod
    def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]:
        """Return the metadata builder class used with this backend."""
        builder_cls = GDNAttentionMetadataBuilder
        return builder_cls
|
||||
|
||||
|
||||
@dataclass
|
||||
class GDNAttentionMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_spec_decodes: int
|
||||
num_spec_decode_tokens: int
|
||||
num_actual_tokens: int
|
||||
|
||||
has_initial_state: torch.Tensor | None = None
|
||||
|
||||
spec_query_start_loc: torch.Tensor | None = None # shape: [num_spec_decodes + 1,]
|
||||
non_spec_query_start_loc: torch.Tensor | None = (
|
||||
None # shape: [batch - num_spec_decodes + 1,]
|
||||
)
|
||||
|
||||
spec_state_indices_tensor: torch.Tensor | None = None # shape: [batch, num_spec]
|
||||
non_spec_state_indices_tensor: torch.Tensor | None = (
|
||||
None # shape: [batch - num_spec_decodes,]
|
||||
)
|
||||
spec_sequence_masks: torch.Tensor | None = None # shape: [batch,]
|
||||
spec_token_indx: torch.Tensor | None = None
|
||||
non_spec_token_indx: torch.Tensor | None = None
|
||||
|
||||
num_accepted_tokens: torch.Tensor | None = None # shape: [batch,]
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: dict | None = None
|
||||
batch_ptr: torch.Tensor | None = None
|
||||
token_chunk_offset_ptr: torch.Tensor | None = None
|
||||
|
||||
|
||||
class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]):
|
||||
_cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
reorder_batch_threshold: int = 1
|
||||
|
||||
def __init__(
    self,
    kv_cache_spec: AttentionSpec,
    layer_names: list[str],
    vllm_config: VllmConfig,
    device: torch.device,
):
    """Record configuration and preallocate buffers on ``device``.

    The buffers are uninitialized (``torch.empty``) and sized from
    ``decode_cudagraph_max_bs`` and ``num_spec + 1``.
    """
    assert isinstance(kv_cache_spec, MambaSpec)
    self.vllm_config = vllm_config
    self.compilation_config = vllm_config.compilation_config
    self.speculative_config = vllm_config.speculative_config
    self.kv_cache_spec = kv_cache_spec

    spec_cfg = self.speculative_config
    self.num_spec = spec_cfg.num_speculative_tokens if spec_cfg else 0
    self.use_spec_decode = self.num_spec > 0
    self._init_reorder_batch_threshold(1, self.use_spec_decode)

    self.use_full_cuda_graph = (
        self.compilation_config.cudagraph_mode.has_full_cudagraphs()
    )
    self.decode_cudagraph_max_bs = min(
        self.vllm_config.scheduler_config.max_num_seqs * (self.num_spec + 1),
        self.compilation_config.max_cudagraph_capture_size,
    )

    max_bs = self.decode_cudagraph_max_bs
    tokens_per_seq = self.num_spec + 1

    def alloc(shape: tuple[int, ...], dtype: torch.dtype = torch.int32):
        # Every buffer lives on the builder's device.
        return torch.empty(shape, dtype=dtype, device=device)

    self.spec_state_indices_tensor = alloc((max_bs, tokens_per_seq))
    self.non_spec_state_indices_tensor = alloc((max_bs,))
    self.spec_sequence_masks = alloc((max_bs,), dtype=torch.bool)
    self.spec_token_indx = alloc((max_bs * tokens_per_seq,))
    self.non_spec_token_indx = alloc((max_bs * tokens_per_seq,))
    self.spec_query_start_loc = alloc((max_bs + 1,))
    self.non_spec_query_start_loc = alloc((max_bs + 1,))
    self.num_accepted_tokens = alloc((max_bs,))
|
||||
|
||||
def build(  # type: ignore[override]
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    num_accepted_tokens: torch.Tensor | None = None,
    num_decode_draft_tokens_cpu: torch.Tensor | None = None,
    fast_build: bool = False,
) -> GDNAttentionMetadata:
    """Build the GDN attention metadata for one batch.

    The batch is partitioned into speculative-decode sequences (entries of
    ``num_decode_draft_tokens_cpu`` that are ``>= 0``) and the remaining
    sequences, which are further split into single-token decodes and
    prefills. When full CUDA graphs are enabled and the batch is
    decode-only, the per-batch tensors are copied into the persistent
    buffers held on ``self`` and padded up to the padded batch size.

    Args:
        common_prefix_len: Not referenced by this method.
        common_attn_metadata: Backend-agnostic batch metadata; this method
            reads ``query_start_loc``, ``num_computed_tokens_cpu``,
            ``block_table_tensor`` and ``num_actual_tokens`` from it.
        num_accepted_tokens: Indexed with the spec-sequence mask; must not
            be ``None`` when the batch contains spec-decode sequences
            (asserted below).
        num_decode_draft_tokens_cpu: Per-sequence values whose sign selects
            spec sequences (``>= 0``). Presumably the number of draft
            tokens per request, with a negative sentinel for non-spec
            requests -- TODO confirm against the caller.
        fast_build: Not referenced by this method.

    Returns:
        A populated ``GDNAttentionMetadata``.
    """
    m = common_attn_metadata

    query_start_loc = m.query_start_loc
    context_lens = m.num_computed_tokens_cpu
    # NOTE(review): this device copy is only consumed in the
    # ``num_prefills > 0`` branch further down, yet it is issued for every
    # batch -- consider moving it there.
    context_lens_tensor = context_lens.to(query_start_loc.device)
    nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

    # Treat the batch as non-speculative when spec decode is disabled, no
    # draft information was passed, or the non-negative draft counts sum to
    # zero.
    if (
        not self.use_spec_decode
        or num_decode_draft_tokens_cpu is None
        or num_decode_draft_tokens_cpu[num_decode_draft_tokens_cpu >= 0]
        .sum()
        .item()
        == 0
    ):
        spec_sequence_masks = None
        num_spec_decodes = 0
    else:
        spec_sequence_masks = num_decode_draft_tokens_cpu >= 0
        num_spec_decodes = spec_sequence_masks.sum().item()
        # NOTE(review): this branch looks unreachable -- the condition above
        # already guarantees a positive sum over the non-negative entries,
        # which implies at least one True in the mask. Confirm before
        # removing.
        if num_spec_decodes == 0:
            spec_sequence_masks = None
        else:
            spec_sequence_masks = spec_sequence_masks.to(
                query_start_loc.device, non_blocking=True
            )

    if spec_sequence_masks is None:
        # No spec sequences: every request is a plain decode or a prefill.
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(m, decode_threshold=1)
        )
        num_spec_decode_tokens = 0
        spec_token_indx = None
        non_spec_token_indx = None
        spec_state_indices_tensor = None
        # Only the first block-table column is used for non-spec requests.
        non_spec_state_indices_tensor = m.block_table_tensor[:, 0]
        spec_query_start_loc = None
        non_spec_query_start_loc = query_start_loc
        num_accepted_tokens = None
    else:
        # Per-sequence number of query tokens in this step.
        query_lens = query_start_loc[1:] - query_start_loc[:-1]

        # Among non-spec sequences, a query length of exactly 1 counts as a
        # decode; everything else counts as a prefill.
        non_spec_query_lens = query_lens[~spec_sequence_masks]
        num_decodes = (non_spec_query_lens == 1).sum().item()
        num_prefills = non_spec_query_lens.size(0) - num_decodes
        num_decode_tokens = num_decodes
        num_prefill_tokens = non_spec_query_lens.sum().item() - num_decode_tokens
        # Whatever is left of the total query tokens belongs to spec decodes.
        num_spec_decode_tokens = (
            query_lens.sum().item() - num_prefill_tokens - num_decode_tokens
        )

        if num_prefills == 0 and num_decodes == 0:
            # Pure spec-decode batch: the token order and the block table
            # can be used as they are, without any gather.
            spec_token_size = min(
                num_spec_decodes * (self.num_spec + 1),
                query_start_loc[-1].item(),
            )
            spec_token_indx = torch.arange(
                spec_token_size,
                dtype=torch.int32,
                device=query_start_loc.device,
            )
            non_spec_token_indx = torch.empty(
                0, dtype=torch.int32, device=query_start_loc.device
            )
            # Spec requests use the first ``num_spec + 1`` block-table
            # columns.
            spec_state_indices_tensor = m.block_table_tensor[:, : self.num_spec + 1]
            non_spec_state_indices_tensor = None
            spec_query_start_loc = query_start_loc
            non_spec_query_start_loc = None
        else:
            # Mixed batch: expand the per-sequence mask to a per-token mask
            # and argsort it, so that positions of non-spec tokens (False)
            # precede positions of spec tokens (True).
            spec_token_masks = torch.repeat_interleave(
                spec_sequence_masks, query_lens
            )
            index = torch.argsort(spec_token_masks)
            num_non_spec_tokens = num_prefill_tokens + num_decode_tokens
            non_spec_token_indx = index[:num_non_spec_tokens]
            spec_token_indx = index[num_non_spec_tokens:]

            spec_state_indices_tensor = m.block_table_tensor[
                spec_sequence_masks, : self.num_spec + 1
            ]
            non_spec_state_indices_tensor = m.block_table_tensor[
                ~spec_sequence_masks, 0
            ]

            # Rebuild a cumulative query_start_loc for each group; element 0
            # stays 0 and the cumsum is written into the remaining slots.
            spec_query_start_loc = torch.zeros(
                num_spec_decodes + 1,
                dtype=torch.int32,
                device=query_start_loc.device,
            )
            torch.cumsum(
                query_lens[spec_sequence_masks], dim=0, out=spec_query_start_loc[1:]
            )
            non_spec_query_start_loc = torch.zeros(
                query_lens.size(0) - num_spec_decodes + 1,
                dtype=torch.int32,
                device=query_start_loc.device,
            )
            torch.cumsum(
                query_lens[~spec_sequence_masks],
                dim=0,
                out=non_spec_query_start_loc[1:],
            )

        # Keep only the entries that belong to spec sequences.
        assert num_accepted_tokens is not None
        num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]

    if num_prefills > 0:
        # A sequence has an initial state when it already has computed
        # tokens; restrict to non-spec sequences when a mask exists.
        has_initial_state = context_lens_tensor > 0
        if spec_sequence_masks is not None:
            has_initial_state = has_initial_state[~spec_sequence_masks]
        nums_dict, batch_ptr, token_chunk_offset_ptr = (
            compute_causal_conv1d_metadata(non_spec_query_start_loc)
        )
    else:
        has_initial_state = None
    num_actual_tokens = (
        num_prefill_tokens + num_decode_tokens + num_spec_decode_tokens
    )

    # Prepare tensors for cudagraph.
    #
    # With speculative decoding, the xgrammar backend may roll back tokens,
    # leaving some sequences with fewer draft tokens than self.num_spec.
    #
    # In such cases, the max possible batch size for n tokens can be
    # min(n, cudagraph_max_bs).
    if (
        self.use_full_cuda_graph
        and num_prefills == 0
        and num_decodes == 0
        and num_spec_decodes <= self.decode_cudagraph_max_bs
        and num_spec_decode_tokens <= self.decode_cudagraph_max_bs
    ):
        # Spec-decode-only batch: stage everything in the persistent buffers
        # on ``self`` and pad the tail beyond ``num_spec_decodes``.
        num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
        batch_size = min(self.decode_cudagraph_max_bs, num_actual_tokens)

        self.spec_state_indices_tensor[:num_spec_decodes].copy_(
            spec_state_indices_tensor, non_blocking=True
        )
        spec_state_indices_tensor = self.spec_state_indices_tensor[:batch_size]
        spec_state_indices_tensor[num_spec_decodes:].fill_(PAD_SLOT_ID)

        self.spec_sequence_masks[:num_spec_decodes].copy_(
            spec_sequence_masks, non_blocking=True
        )
        spec_sequence_masks = self.spec_sequence_masks[:batch_size]
        spec_sequence_masks[num_spec_decodes:].fill_(False)

        assert non_spec_token_indx is not None and spec_token_indx is not None
        self.non_spec_token_indx[: non_spec_token_indx.size(0)].copy_(
            non_spec_token_indx, non_blocking=True
        )
        non_spec_token_indx = self.non_spec_token_indx[
            : non_spec_token_indx.size(0)
        ]

        self.spec_token_indx[: spec_token_indx.size(0)].copy_(
            spec_token_indx, non_blocking=True
        )
        spec_token_indx = self.spec_token_indx[: spec_token_indx.size(0)]

        self.spec_query_start_loc[: num_spec_decodes + 1].copy_(
            spec_query_start_loc, non_blocking=True
        )
        # Padded entries repeat the final offset, i.e. they describe
        # zero-length sequences.
        spec_num_query_tokens = spec_query_start_loc[-1]  # type: ignore[index]
        spec_query_start_loc = self.spec_query_start_loc[: batch_size + 1]
        spec_query_start_loc[num_spec_decodes + 1 :].fill_(spec_num_query_tokens)

        self.num_accepted_tokens[:num_spec_decodes].copy_(
            num_accepted_tokens, non_blocking=True
        )
        num_accepted_tokens = self.num_accepted_tokens[:batch_size]
        num_accepted_tokens[num_spec_decodes:].fill_(1)

    if (
        self.use_full_cuda_graph
        and num_prefills == 0
        and num_spec_decodes == 0
        and num_decodes <= self.decode_cudagraph_max_bs
    ):
        # Plain decode-only batch: same staging and padding scheme for the
        # non-spec buffers.
        num_actual_tokens = self.vllm_config.pad_for_cudagraph(m.num_actual_tokens)
        batch_size = num_actual_tokens

        self.non_spec_state_indices_tensor[:num_decodes].copy_(
            non_spec_state_indices_tensor, non_blocking=True
        )
        non_spec_state_indices_tensor = self.non_spec_state_indices_tensor[
            :batch_size
        ]
        non_spec_state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)

        self.non_spec_query_start_loc[: num_decodes + 1].copy_(
            non_spec_query_start_loc, non_blocking=True
        )
        non_spec_num_query_tokens = non_spec_query_start_loc[-1]  # type: ignore[index]
        non_spec_query_start_loc = self.non_spec_query_start_loc[: batch_size + 1]
        non_spec_query_start_loc[num_decodes + 1 :].fill_(non_spec_num_query_tokens)

    attn_metadata = GDNAttentionMetadata(
        num_prefills=num_prefills,
        num_prefill_tokens=num_prefill_tokens,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
        num_spec_decodes=num_spec_decodes,
        num_spec_decode_tokens=num_spec_decode_tokens,
        num_actual_tokens=num_actual_tokens,
        has_initial_state=has_initial_state,
        spec_query_start_loc=spec_query_start_loc,
        non_spec_query_start_loc=non_spec_query_start_loc,
        spec_state_indices_tensor=spec_state_indices_tensor,
        non_spec_state_indices_tensor=non_spec_state_indices_tensor,
        spec_sequence_masks=spec_sequence_masks,
        spec_token_indx=spec_token_indx,
        non_spec_token_indx=non_spec_token_indx,
        num_accepted_tokens=num_accepted_tokens,
        nums_dict=nums_dict,
        batch_ptr=batch_ptr,
        token_chunk_offset_ptr=token_chunk_offset_ptr,
    )
    return attn_metadata
|
||||
|
||||
def build_for_cudagraph_capture(
    self, common_attn_metadata: CommonAttentionMetadata
):
    """Build the metadata used while capturing a full CUDA graph.

    Only decode-only batches are supported for full CUDA graphs with GDN:
    both the number of requests and the number of tokens must fit within
    ``self.decode_cudagraph_max_bs``.

    Every request's query length (the difference of consecutive
    ``query_start_loc`` entries) is used as its accepted-token count, and
    that count minus one as its draft-token count. Note that
    ``common_attn_metadata.num_computed_tokens_cpu`` is overwritten in
    place with ``seq_lens_cpu`` minus the accepted-token count before
    delegating to ``build``.

    Args:
        common_attn_metadata: Batch metadata describing the capture batch.

    Returns:
        Whatever ``self.build`` returns for the synthesized inputs.

    Raises:
        AssertionError: If the batch exceeds ``decode_cudagraph_max_bs``
            in requests or tokens.
    """
    m = common_attn_metadata

    assert (
        m.num_reqs <= self.decode_cudagraph_max_bs
        and m.num_actual_tokens <= self.decode_cudagraph_max_bs
    ), (
        f"GDN only supports decode-only full CUDAGraph capture. "
        f"Make sure batch size ({m.num_reqs}) <= "
        f"cudagraph capture sizes ({self.decode_cudagraph_max_bs}), "
        f"and number of tokens ({m.num_actual_tokens}) <= "
        f"cudagraph capture sizes ({self.decode_cudagraph_max_bs})."
    )

    num_accepted_tokens = torch.diff(m.query_start_loc)
    # Bring the per-request counts to the host once and derive both CPU-side
    # tensors from that single copy instead of transferring twice.
    num_accepted_tokens_cpu = num_accepted_tokens.cpu()
    num_decode_draft_tokens_cpu = num_accepted_tokens_cpu - 1
    m.num_computed_tokens_cpu = m.seq_lens_cpu - num_accepted_tokens_cpu

    return self.build(0, m, num_accepted_tokens, num_decode_draft_tokens_cpu)
|
||||
74
v1/attention/backends/linear_attn.py
Normal file
74
v1/attention/backends/linear_attn.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
class LinearAttentionBackend(AttentionBackend):
    """Attention backend whose metadata is built by
    ``LinearAttentionMetadataBuilder``."""

    @staticmethod
    def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]:
        """Return the metadata-builder class associated with this backend."""
        builder_cls = LinearAttentionMetadataBuilder
        return builder_cls
|
||||
|
||||
|
||||
@dataclass
class LinearAttentionMetadata:
    """Per-batch metadata produced by ``LinearAttentionMetadataBuilder``."""

    # Request/token counts for the prefill and decode parts of the batch.
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int

    # Tensors forwarded unchanged from the common attention metadata.
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor

    # First block-table column of every request.
    state_indices_tensor: torch.Tensor  # shape: [batch,]
|
||||
|
||||
|
||||
class LinearAttentionMetadataBuilder(AttentionMetadataBuilder[LinearAttentionMetadata]):
    """Builds ``LinearAttentionMetadata`` from the common attention metadata."""

    # Passed to split_decodes_and_prefills() as the decode threshold.
    reorder_batch_threshold: int = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialize the base builder; the cache spec must be a ``MambaSpec``."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        assert isinstance(kv_cache_spec, MambaSpec)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> LinearAttentionMetadata:
        """Assemble the linear-attention metadata for one batch.

        ``common_prefix_len`` and ``fast_build`` are accepted for interface
        compatibility and are not used here.
        """
        common = common_attn_metadata

        decodes, prefills, decode_tokens, prefill_tokens = split_decodes_and_prefills(
            common, decode_threshold=self.reorder_batch_threshold
        )

        return LinearAttentionMetadata(
            num_prefills=prefills,
            num_prefill_tokens=prefill_tokens,
            num_decodes=decodes,
            num_decode_tokens=decode_tokens,
            query_start_loc=common.query_start_loc,
            seq_lens=common.seq_lens,
            # Only the first block-table column is used per request.
            state_indices_tensor=common.block_table_tensor[:, 0],
        )
|
||||
165
v1/attention/backends/mamba1_attn.py
Normal file
165
v1/attention/backends/mamba1_attn.py
Normal file
@@ -0,0 +1,165 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
|
||||
class Mamba1AttentionBackend(AttentionBackend):
    """Attention backend whose metadata is built by
    ``Mamba1AttentionMetadataBuilder``."""

    @staticmethod
    def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]:
        """Return the metadata-builder class associated with this backend."""
        builder_cls = Mamba1AttentionMetadataBuilder
        return builder_cls
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mamba1AttentionMetadata:
|
||||
query_start_loc_p: torch.Tensor
|
||||
state_indices_tensor: torch.Tensor
|
||||
has_initial_states_p: torch.Tensor | None
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
num_padded_decodes: int
|
||||
|
||||
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
|
||||
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
|
||||
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
|
||||
num_computed_tokens_p: torch.Tensor # shape: [batch,]
|
||||
|
||||
|
||||
class Mamba1AttentionMetadataBuilder(
    BaseMambaAttentionMetadataBuilder[Mamba1AttentionMetadata]
):
    """Builds ``Mamba1AttentionMetadata`` from the common attention metadata."""

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialize the base builder; the cache spec must be a ``MambaSpec``."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        assert isinstance(kv_cache_spec, MambaSpec)

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> Mamba1AttentionMetadata:
        """Assemble the Mamba1 metadata for one batch.

        Prefill-specific tensors are sliced from the trailing
        ``num_prefills`` requests. For a prefill-free batch with full CUDA
        graphs enabled, the decode tensors are staged in the persistent
        buffers on ``self`` and padded. ``common_prefix_len`` and
        ``fast_build`` are not used here.
        """
        common = common_attn_metadata
        num_reqs = common.num_reqs
        prefix_caching = self.vllm_config.cache_config.enable_prefix_caching

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common, decode_threshold=self.reorder_batch_threshold
            )
        )

        # Defaults for everything that only exists in some configurations.
        query_start_loc_p = None
        has_initial_states_p = None
        num_computed_tokens = None
        num_computed_tokens_p = None
        block_idx_first_scheduled_token = None
        block_idx_first_scheduled_token_p = None
        block_idx_last_scheduled_token = None
        block_idx_last_computed_token = None
        padded_decodes = num_decodes

        # TODO(@Josephasafg) Mamba1 and Mamba2 have a lot of code in common here.
        # We should consolidate this code
        if prefix_caching:
            # Return a tensor of shape (#requests, #max blocks)
            state_indices_tensor = common.block_table_tensor
            num_computed_tokens = common.num_computed_tokens_cpu.to(self.device)
            (
                block_idx_last_computed_token,
                block_idx_first_scheduled_token,
                block_idx_last_scheduled_token,
            ) = self._compute_prefix_caching_block_indices(
                common, self.kv_cache_spec.block_size
            )
        else:
            # Always return just a single block per each request:
            state_indices_tensor = common.block_table_tensor[:, 0]

        if num_prefills > 0:
            # Prefill requests occupy the trailing slots of the batch.
            prefill_reqs = slice(num_reqs - num_prefills, num_reqs)

            # Re-base the prefill offsets so that they start after the
            # decode tokens.
            query_start_loc_p = (
                common.query_start_loc[-num_prefills - 1 :] - num_decode_tokens
            )
            has_initial_states_cpu = common.num_computed_tokens_cpu[prefill_reqs] > 0
            has_initial_states_p = has_initial_states_cpu.to(
                common.query_start_loc.device
            )

            if prefix_caching:
                assert num_computed_tokens is not None
                num_computed_tokens_p = num_computed_tokens[prefill_reqs]
                assert block_idx_first_scheduled_token is not None
                block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
                    prefill_reqs
                ]

        elif (
            0 < num_decodes <= self.decode_cudagraph_max_bs
            and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        ):
            # Decode-only batch under full CUDA graphs: copy into the
            # persistent buffers and pad up to the padded decode count.
            padded_decodes = self.vllm_config.pad_for_cudagraph(num_decodes)
            self.state_indices_tensor[:num_decodes].copy_(
                state_indices_tensor, non_blocking=True
            )
            state_indices_tensor = self.state_indices_tensor[:padded_decodes]
            state_indices_tensor[num_decodes:].fill_(PAD_SLOT_ID)

            if prefix_caching:
                self.block_idx_last_scheduled_token[:num_decodes].copy_(
                    block_idx_last_scheduled_token, non_blocking=True
                )
                block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
                    :padded_decodes
                ]
                block_idx_last_scheduled_token[num_decodes:].fill_(0)

                self.block_idx_last_computed_token[:num_decodes].copy_(
                    block_idx_last_computed_token, non_blocking=True
                )
                block_idx_last_computed_token = self.block_idx_last_computed_token[
                    :padded_decodes
                ]
                block_idx_last_computed_token[num_decodes:].fill_(0)

        return Mamba1AttentionMetadata(
            query_start_loc_p=query_start_loc_p,
            has_initial_states_p=has_initial_states_p,
            state_indices_tensor=state_indices_tensor,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_padded_decodes=padded_decodes,
            block_idx_last_scheduled_token=block_idx_last_scheduled_token,
            block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
            block_idx_last_computed_token=block_idx_last_computed_token,
            num_computed_tokens_p=num_computed_tokens_p,
        )
|
||||
354
v1/attention/backends/mamba2_attn.py
Normal file
354
v1/attention/backends/mamba2_attn.py
Normal file
@@ -0,0 +1,354 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
def compute_varlen_chunk_metadata(
    query_start_loc: torch.Tensor,
    chunk_size: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build chunk-aligned, variable-length metadata used by Mamba2 SSD kernels.

    Given per-sequence cumulative token starts `query_start_loc` of shape [B+1]
    and a physical `chunk_size`, returns three tensors on the same device:
      - cu_chunk_seqlens:  (nchunks+1,) int32   exclusive prefix-sum of
        logical-chunk lengths (each logical chunk never crosses a sequence or
        physical-chunk boundary).
      - last_chunk_indices: (B,)       int32   index of the last logical chunk
        for each sequence (=-1 for empty sequences).
      - seq_idx_chunks:     (nchunks,) int32   sequence index for each logical
        chunk in order.

    This is intentionally lightweight and CPU-side; it mirrors the metadata
    produced by the V1 Mamba2 meta-data builder and is exported so tests
    (and other callers) can avoid duplicating the logic.

    Raises:
        ValueError: If `chunk_size` is not positive (a non-positive value
            would otherwise divide by zero or never advance the cursor).
        AssertionError: If `query_start_loc` is not 1-D, is empty, does not
            start at 0, or the chunk lengths do not add up to its last entry.
    """
    assert query_start_loc.ndim == 1, "query_start_loc must be 1-D [B+1]"
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    device = query_start_loc.device

    # Pull the offsets to the host once; everything below is plain Python
    # integer arithmetic, so no further tensor reads are needed.
    boundaries: list[int] = query_start_loc.to(torch.int64).tolist()
    assert boundaries and boundaries[0] == 0, "query_start_loc[0] must be 0"
    total = boundaries[-1]
    num_seqs = len(boundaries) - 1

    chunk_lens: list[int] = []
    seq_idx_chunks: list[int] = []
    last_chunk_indices: list[int] = [-1] * num_seqs

    for b, (s, e) in enumerate(zip(boundaries, boundaries[1:])):
        pos = s
        # Empty sequences (e <= s) never enter the loop and keep index -1.
        while pos < e:
            # split at both sequence boundaries and physical chunk boundaries
            room = chunk_size - (pos % chunk_size)
            take = min(room, e - pos)
            chunk_lens.append(take)
            seq_idx_chunks.append(b)
            pos += take
        if e > s:
            last_chunk_indices[b] = len(chunk_lens) - 1

    # Exclusive prefix sum over logical-chunk lengths ([0] when no chunks).
    cu_chunk_list = [0, *itertools.accumulate(chunk_lens)]
    # Final boundary must equal total tokens
    assert not chunk_lens or cu_chunk_list[-1] == total

    cu_chunk_seqlens = torch.tensor(cu_chunk_list, device=device, dtype=torch.int32)
    last_chunk_indices_t = torch.tensor(
        last_chunk_indices, device=device, dtype=torch.int32
    )
    seq_idx_chunks_t = torch.tensor(seq_idx_chunks, device=device, dtype=torch.int32)
    return cu_chunk_seqlens, last_chunk_indices_t, seq_idx_chunks_t
|
||||
|
||||
|
||||
class Mamba2AttentionBackend(AttentionBackend):
    """Attention backend whose metadata is built by
    ``Mamba2AttentionMetadataBuilder``."""

    @staticmethod
    def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]:
        """Return the metadata-builder class associated with this backend."""
        builder_cls = Mamba2AttentionMetadataBuilder
        return builder_cls
|
||||
|
||||
|
||||
@dataclass
|
||||
class Mamba2AttentionMetadata:
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
num_decode_tokens: int
|
||||
query_start_loc_p: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
|
||||
prep_initial_states: bool
|
||||
chunk_size: int
|
||||
|
||||
# The following tensors only contain prefill requests and will be None if
|
||||
# the batch has no prefill request.
|
||||
has_initial_states_p: torch.Tensor | None
|
||||
seq_idx_p: torch.Tensor | None
|
||||
|
||||
# cu_chunk_seqlen_p is a tensor of shape (nchunks+1,) that contains, for
|
||||
# each chunk, its offests into the varlen sequence dimension. It is defined
|
||||
# such that the i-th chunk contains tokens from cu_chunk_seqlen_p[i] to
|
||||
# cu_chunk_seqlen_p[i+1].
|
||||
cu_chunk_seqlen_p: torch.Tensor | None
|
||||
|
||||
# last_chunk_indices_p is a tensor of shape (batch,) that contains the
|
||||
# index of the last chunk for every sequence in the (prefill) batch.
|
||||
last_chunk_indices_p: torch.Tensor | None
|
||||
|
||||
state_indices_tensor: torch.Tensor # shape: [batch,]
|
||||
block_idx_last_scheduled_token: torch.Tensor # shape: [batch,]
|
||||
block_idx_first_scheduled_token_p: torch.Tensor # shape: [batch,]
|
||||
block_idx_last_computed_token: torch.Tensor # shape: [batch,]
|
||||
num_computed_tokens_p: torch.Tensor # shape: [batch,]
|
||||
|
||||
# The following attributes are for triton implementation of causal_conv1d
|
||||
nums_dict: dict | None = None
|
||||
batch_ptr: torch.Tensor | None = None
|
||||
token_chunk_offset_ptr: torch.Tensor | None = None
|
||||
|
||||
|
||||
class Mamba2AttentionMetadataBuilder(
    BaseMambaAttentionMetadataBuilder[Mamba2AttentionMetadata]
):
    """Builds ``Mamba2AttentionMetadata`` from the common attention metadata."""

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialize the base builder and read the Mamba chunk size.

        Raises:
            AssertionError: If the model config provides no chunk size.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        self.chunk_size = vllm_config.model_config.get_mamba_chunk_size()
        assert self.chunk_size is not None, (
            "chunk_size needs to be set in the model config for Mamba2 models"
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> Mamba2AttentionMetadata:
        """Assemble the Mamba2 metadata for one batch.

        When the batch contains prefill requests (taken from the trailing
        ``num_prefills`` slots), per-chunk metadata is computed for them.
        Otherwise, if full CUDA graphs are enabled and the decode batch
        fits, the decode tensors are staged in the persistent buffers on
        ``self`` and padded. ``common_prefix_len`` and ``fast_build`` are
        not used here.
        """
        num_reqs = common_attn_metadata.num_reqs
        seq_lens = common_attn_metadata.seq_lens

        query_start_loc_p = None
        seq_idx_p = None
        cu_chunk_seqlen_p = None
        last_chunk_indices_p = None

        # Need flags to indicate if there are initial states
        has_initial_states_p = None
        prep_initial_states = False

        # for causal_conv1d
        nums_dict, batch_ptr, token_chunk_offset_ptr = None, None, None

        num_computed_tokens, num_computed_tokens_p = None, None
        block_idx_first_scheduled_token = None
        block_idx_first_scheduled_token_p = None

        if self.vllm_config.cache_config.enable_prefix_caching:
            # Return a tensor of shape (#requests, #max blocks)
            state_indices_tensor = common_attn_metadata.block_table_tensor
            # Additional cache-related variables:
            mamba_block_size = self.kv_cache_spec.block_size
            num_computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
                self.device
            )
            (
                block_idx_last_computed_token,
                block_idx_first_scheduled_token,
                block_idx_last_scheduled_token,
            ) = self._compute_prefix_caching_block_indices(
                common_attn_metadata, mamba_block_size
            )
        else:
            # Always return just a single block per each request:
            state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0]
            # Additional cache-related variables:
            block_idx_last_scheduled_token = None
            block_idx_last_computed_token = None

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
            )
        )

        # Compute seq_idx for prefill only
        if num_prefills > 0:
            # [batch,]
            has_initial_states_cpu = (
                common_attn_metadata.num_computed_tokens_cpu[
                    num_reqs - num_prefills : num_reqs
                ]
                > 0
            )
            prep_initial_states = torch.any(has_initial_states_cpu).item()
            has_initial_states_p = has_initial_states_cpu.to(
                common_attn_metadata.query_start_loc.device
            )

            # Re-base the prefill offsets so that they start after the
            # decode tokens.
            query_start_loc_p = (
                common_attn_metadata.query_start_loc[-num_prefills - 1 :]
                - num_decode_tokens
            )

            if self.vllm_config.cache_config.enable_prefix_caching:
                assert num_computed_tokens is not None
                num_computed_tokens_p = num_computed_tokens[
                    num_reqs - num_prefills : num_reqs
                ]
                assert block_idx_first_scheduled_token is not None
                block_idx_first_scheduled_token_p = block_idx_first_scheduled_token[
                    num_reqs - num_prefills : num_reqs
                ]
            num_computed_tokens_p_cpu = common_attn_metadata.num_computed_tokens_cpu[
                num_reqs - num_prefills : num_reqs
            ]
            query_start_loc_p_cpu = (
                common_attn_metadata.query_start_loc_cpu[-num_prefills - 1 :]
                - num_decode_tokens
            )

            # Convert to Python ints once instead of extracting scalars with
            # `.item()` for every request inside the loop below.
            num_computed_per_req = num_computed_tokens_p_cpu.tolist()
            prefill_offsets = query_start_loc_p_cpu.tolist()

            # The code below carefully constructs the chunks such that:
            # 1. Chunks contain tokens from a *single* sequence only.
            # 2. For every sequence, we are guaranteed that we can
            #    retrieve the mamba state *every* chunk_size tokens.
            # Constraint (1) dramatically simplifies the mamba2 kernels.
            # Constraint (2) dramatically simplifies the implementation
            # of prefix caching for mamba2 (wip). We need to take care
            # of the interaction with chunked prefill in order to
            # satisfy constraint (2).
            # TODO (tdoublep): This code could probably be optimized.
            cu_chunk_seqlen = []
            seq_idx = []
            last_chunk_indices = []
            seqlen_pos = 0
            for req_idx in range(num_prefills):
                this_num_computed = num_computed_per_req[req_idx]
                this_new_tokens = (
                    prefill_offsets[req_idx + 1] - prefill_offsets[req_idx]
                )

                # if computed tokens are not chunk-aligned, use the first
                # chunk to finish it off
                if this_num_computed % self.chunk_size != 0:
                    seq_idx.append(req_idx)
                    cu_chunk_seqlen.append(seqlen_pos)
                    # how many tokens to finish the chunk?
                    chunk_len = (
                        cdiv(this_num_computed, self.chunk_size) * self.chunk_size
                        - this_num_computed
                    )
                    # we can only use at most this_new_tokens
                    chunk_len = min(chunk_len, this_new_tokens)
                    seqlen_pos += chunk_len
                    this_new_tokens -= chunk_len

                n_chunks = cdiv(this_new_tokens, self.chunk_size)
                for _ in range(n_chunks):
                    seq_idx.append(req_idx)
                    cu_chunk_seqlen.append(seqlen_pos)
                    chunk_len = min(self.chunk_size, this_new_tokens)
                    seqlen_pos += chunk_len
                    this_new_tokens -= chunk_len

                assert this_new_tokens == 0
                last_chunk_indices.append(len(cu_chunk_seqlen) - 1)

            # Closing boundary of the final chunk.
            cu_chunk_seqlen.append(seqlen_pos)

            seq_idx_p = torch.as_tensor(
                seq_idx, device=query_start_loc_p.device, dtype=torch.int32
            )
            cu_chunk_seqlen_p = torch.as_tensor(
                cu_chunk_seqlen, device=query_start_loc_p.device, dtype=torch.int32
            )
            last_chunk_indices_p = torch.as_tensor(
                last_chunk_indices, device=query_start_loc_p.device, dtype=torch.int32
            )

            nums_dict, batch_ptr, token_chunk_offset_ptr = (
                compute_causal_conv1d_metadata(query_start_loc_p)
            )

        elif (
            num_decodes <= self.decode_cudagraph_max_bs
            and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        ):
            # Pad state tensor for CUDA graph
            num_input_tokens = self.vllm_config.pad_for_cudagraph(num_decodes)
            self.state_indices_tensor[:num_decodes].copy_(
                state_indices_tensor, non_blocking=True
            )
            state_indices_tensor = self.state_indices_tensor[:num_input_tokens]
            state_indices_tensor[num_decodes:] = PAD_SLOT_ID

            if self.vllm_config.cache_config.enable_prefix_caching:
                self.block_idx_last_scheduled_token[:num_decodes].copy_(
                    block_idx_last_scheduled_token, non_blocking=True
                )
                block_idx_last_scheduled_token = self.block_idx_last_scheduled_token[
                    :num_input_tokens
                ]
                block_idx_last_scheduled_token[num_decodes:] = 0

                self.block_idx_last_computed_token[:num_decodes].copy_(
                    block_idx_last_computed_token, non_blocking=True
                )
                block_idx_last_computed_token = self.block_idx_last_computed_token[
                    :num_input_tokens
                ]
                block_idx_last_computed_token[num_decodes:] = 0

        attn_metadata = Mamba2AttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc_p=query_start_loc_p,
            seq_lens=seq_lens,
            prep_initial_states=prep_initial_states,
            chunk_size=self.chunk_size,
            has_initial_states_p=has_initial_states_p,
            seq_idx_p=seq_idx_p,
            state_indices_tensor=state_indices_tensor,
            cu_chunk_seqlen_p=cu_chunk_seqlen_p,
            last_chunk_indices_p=last_chunk_indices_p,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
            block_idx_last_scheduled_token=block_idx_last_scheduled_token,
            block_idx_first_scheduled_token_p=block_idx_first_scheduled_token_p,
            block_idx_last_computed_token=block_idx_last_computed_token,
            num_computed_tokens_p=num_computed_tokens_p,
        )
        return attn_metadata
|
||||
115
v1/attention/backends/mamba_attn.py
Normal file
115
v1/attention/backends/mamba_attn.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import abc
|
||||
from typing import ClassVar, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec, MambaSpec
|
||||
|
||||
M = TypeVar("M")
|
||||
|
||||
|
||||
class BaseMambaAttentionMetadataBuilder(AttentionMetadataBuilder[M], abc.ABC):
    """Shared base for Mamba-style attention metadata builders.

    The constructor pre-allocates persistent int32 buffers whose first
    dimension is ``decode_cudagraph_max_bs``. The class also provides the
    decode-only entry point used for full CUDA graph capture and a helper that
    derives per-request block indices when prefix caching is enabled.
    """

    reorder_batch_threshold: int = 1
    _cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    )

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialise the builder and allocate its persistent buffers.

        Args:
            kv_cache_spec: Cache spec for the layers; must be a ``MambaSpec``.
            layer_names: Names of the layers served by this builder.
            vllm_config: Global configuration object.
            device: Device on which the persistent buffers are allocated.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        assert isinstance(kv_cache_spec, MambaSpec)

        self.compilation_config = vllm_config.compilation_config

        # Rows needed: the smaller of the scheduler's sequence limit and the
        # largest CUDA graph capture size.
        max_bs = min(
            self.vllm_config.scheduler_config.max_num_seqs,
            self.compilation_config.max_cudagraph_capture_size,
        )
        self.decode_cudagraph_max_bs = max_bs

        def _int32_buffer(shape: tuple[int, ...]) -> torch.Tensor:
            # Uninitialised storage; callers are expected to fill it.
            return torch.empty(shape, dtype=torch.int32, device=device)

        prefix_caching = self.vllm_config.cache_config.enable_prefix_caching
        if prefix_caching:
            # One column per block required to cover ``max_model_len`` tokens.
            num_block_cols = cdiv(
                self.vllm_config.model_config.max_model_len,
                self.kv_cache_spec.block_size,
            )
            state_shape: tuple[int, ...] = (max_bs, num_block_cols)
        else:
            state_shape = (max_bs,)
        self.state_indices_tensor = _int32_buffer(state_shape)

        # The per-request block-index buffers only exist with prefix caching.
        if prefix_caching:
            self.block_idx_last_scheduled_token = _int32_buffer((max_bs,))
            self.block_idx_last_computed_token = _int32_buffer((max_bs,))

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> M:
        """Build metadata for full CUDA graph capture.

        Only decode batches are supported: the number of requests must equal
        the number of actual tokens. ``max_query_len`` on the passed metadata
        object is overwritten with 1 before delegating to ``build``.
        """
        metadata = common_attn_metadata

        assert metadata.num_reqs == metadata.num_actual_tokens, (
            "Mamba only supports decode-only full CUDAGraph capture. "
            "Make sure all cudagraph capture sizes <= max_num_seq."
        )

        # Decode-only: every request contributes exactly one query token.
        metadata.max_query_len = 1

        return self.build(0, metadata)

    def _compute_prefix_caching_block_indices(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        mamba_block_size: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Derive per-request block indices used with prefix caching.

        Returns:
            A tuple ``(last_computed, first_scheduled, last_scheduled)`` of
            block-index tensors, where
            ``last_computed <= first_scheduled <= last_scheduled`` before
            clamping. ``last_computed`` is clamped to be non-negative.
        """
        computed_tokens = common_attn_metadata.num_computed_tokens_cpu.to(
            self.device
        )

        # Block holding the last already-computed token. With zero computed
        # tokens the raw value is -1, which would break later indexing, so it
        # is clamped to 0.
        last_computed = (cdiv(computed_tokens, mamba_block_size) - 1).clamp(min=0)

        # Block holding the first token scheduled in this step.
        first_scheduled = cdiv(computed_tokens + 1, mamba_block_size) - 1

        # Block holding the last token scheduled in this step.
        last_scheduled = cdiv(common_attn_metadata.seq_lens, mamba_block_size) - 1

        return last_computed, first_scheduled, last_scheduled
|
||||
0
v1/attention/backends/mla/__init__.py
Normal file
0
v1/attention/backends/mla/__init__.py
Normal file
BIN
v1/attention/backends/mla/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
v1/attention/backends/mla/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/attention/backends/mla/__pycache__/common.cpython-312.pyc
Normal file
BIN
v1/attention/backends/mla/__pycache__/common.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
v1/attention/backends/mla/__pycache__/flashmla.cpython-312.pyc
Normal file
BIN
v1/attention/backends/mla/__pycache__/flashmla.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
v1/attention/backends/mla/__pycache__/indexer.cpython-312.pyc
Normal file
BIN
v1/attention/backends/mla/__pycache__/indexer.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
v1/attention/backends/mla/__pycache__/triton_mla.cpython-312.pyc
Normal file
BIN
v1/attention/backends/mla/__pycache__/triton_mla.cpython-312.pyc
Normal file
Binary file not shown.
2200
v1/attention/backends/mla/common.py
Normal file
2200
v1/attention/backends/mla/common.py
Normal file
File diff suppressed because it is too large
Load Diff
275
v1/attention/backends/mla/cutlass_mla.py
Normal file
275
v1/attention/backends/mla/cutlass_mla.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
    """Metadata builder for the CUTLASS MLA backend.

    All build logic is inherited from ``MLACommonMetadataBuilder``; this
    subclass only overrides the declared CUDA graph support level.
    """

    # Enables full CUDA graph support for decode-only capture.
    _cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    )
|
||||
|
||||
|
||||
class CutlassMLABackend(MLACommonBackend):
    """Backend descriptor for the CUTLASS MLA attention kernels.

    Declares the dtypes, kernel block size and KV-cache dtypes the backend
    advertises, and names its implementation and metadata-builder classes.
    """

    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8", "fp8_e4m3"]

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "CUTLASS_MLA"

    @staticmethod
    def get_impl_cls() -> type["CutlassMLAImpl"]:
        """Return the attention implementation class for this backend."""
        return CutlassMLAImpl

    @staticmethod
    def get_builder_cls() -> type["CutlassMLAMetadataBuilder"]:
        """Return the metadata builder class for this backend."""
        return CutlassMLAMetadataBuilder

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return True only for devices whose compute capability major is 10."""
        required_major = 10
        return capability.major == required_major
|
||||
|
||||
|
||||
class SM100Workspace:
    """Growable byte buffer used as scratch space by the SM100 CUTLASS MLA op.

    The buffer is a 1-D ``uint8`` tensor on the ``"cuda"`` device. It only
    ever grows: ``ensure_size`` resizes it when the size reported by
    ``ops.sm100_cutlass_mla_get_workspace_size`` exceeds the current length.
    """

    def __init__(self, initial_workspace_size):
        """Allocate the initial buffer and cache the device SM count.

        Args:
            initial_workspace_size: Initial buffer length in bytes.
        """
        self._workspace_buf = torch.empty(
            initial_workspace_size, device="cuda", dtype=torch.uint8
        )

        # Block size is forced to 128.
        self._block_size = 128

        # The SM count is looked up once. Device 0 is used as a proxy, which
        # assumes all devices are similar.
        device_props = torch.cuda.get_device_properties(torch.device("cuda:0"))
        self._sm_count = device_props.multi_processor_count

    def get_buf(self):
        """Return the current workspace tensor."""
        return self._workspace_buf

    def ensure_size(self, attn_metadata: MLACommonMetadata, num_kv_splits: int):
        """Grow the workspace if the upcoming call needs more than it holds.

        Args:
            attn_metadata: Metadata supplying ``num_reqs`` and
                ``max_query_len`` for the size query.
            num_kv_splits: Number of KV splits forwarded to the size query.
        """
        # NOTE(review): the length passed to the size query is derived from
        # ``max_query_len`` (scaled by the block size), not from a KV sequence
        # length -- confirm this is the intended upper bound.
        required_bytes = ops.sm100_cutlass_mla_get_workspace_size(
            attn_metadata.max_query_len * self._block_size,
            attn_metadata.num_reqs,
            self._sm_count,
            num_kv_splits=num_kv_splits,
        )

        if required_bytes <= self._workspace_buf.shape[0]:
            return
        self._workspace_buf.resize_(required_bytes)
|
||||
|
||||
|
||||
# Workspace created at import time and shared by every CutlassMLAImpl
# instance; it starts at 128 MiB.
g_sm100_workspace = SM100Workspace(128 * 1024**2)

# Head count that CutlassMLAImpl passes to its base class as q_pad_num_heads.
MAX_HEADS = 128
|
||||
|
||||
|
||||
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA attention implementation backed by the SM100 CUTLASS decode op.

    Queries are padded to ``MAX_HEADS`` heads (via ``q_pad_num_heads``) and the
    kernel output is sliced back to the real head count. All instances share
    the module-level ``g_sm100_workspace`` scratch buffer.
    """

    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        """Validate the configuration and pick the KV-split policy.

        Raises:
            NotImplementedError: If ``alibi_slopes``, ``sliding_window`` or
                ``logits_soft_cap`` is set, or if ``attn_type`` is not
                ``AttentionType.DECODER``.
            ValueError: If ``FORCE_NUM_KV_SPLITS`` is set to a non-integer.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            q_pad_num_heads=MAX_HEADS,
            **mla_args,
        )

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "CutlassMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "CutlassMLAImpl"
            )

        # TODO: Currently, num_kv_splits is limited to 16 to avoid hanging
        #       issues. In case the code hangs, use:
        #       FORCE_NUM_KV_SPLITS=1
        force_num_kv_splits = os.environ.get("FORCE_NUM_KV_SPLITS")
        if force_num_kv_splits:
            # Parse once and reuse the value for both the attribute and the log.
            self._num_kv_splits = int(force_num_kv_splits)
            logger.debug_once("Forcing num_kv_splits to %d", self._num_kv_splits)
        else:
            self._num_kv_splits = -1  # => Auto-detect

        # Share workspace buffer across all executions
        self._workspace = g_sm100_workspace

    def _sm100_cutlass_mla_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        seq_lens: torch.Tensor,
        page_table: torch.Tensor,
        workspace: torch.Tensor,
        sm_scale: float,
        num_kv_splits: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Validate the inputs and launch ``ops.sm100_cutlass_mla_decode``.

        Args:
            q_nope: 3-D tensor ``(B, H, 512)``.
            q_pe: 3-D tensor ``(B, H, 64)`` with the same dtype as ``q_nope``.
            kv_c_and_k_pe_cache: 3-D tensor ``(_, PAGE_SIZE, 576)`` with the
                same dtype as ``q_nope``.
            seq_lens: int32 tensor forwarded to the op.
            page_table: 2-D int32 tensor ``(B, block_num)``.
            workspace: Scratch buffer forwarded to the op.
            sm_scale: Scale forwarded to the op.
            num_kv_splits: KV split count forwarded to the op.

        Returns:
            ``(out, lse)``. ``out`` has shape ``(B, H, 512)``. ``lse`` has shape
            ``(B, H)`` when LSE is requested, otherwise it is an empty tensor.
        """
        assert q_nope.ndim == 3, f"q_nope must be a 3D tensor, but got {q_nope.ndim}"
        assert q_pe.ndim == 3, f"q_pe must be a 3D tensor, but got {q_pe.ndim}"
        assert kv_c_and_k_pe_cache.ndim == 3, (
            "kv_c_and_k_pe_cache must be a 3D tensor, "
            f"but got {kv_c_and_k_pe_cache.ndim}"
        )

        B_q, H, D_q_nope = q_nope.shape
        B_q_2, H_2, D_q_pe = q_pe.shape
        assert (B_q == B_q_2) and (H == H_2)

        _, PAGE_SIZE, D_ckv = kv_c_and_k_pe_cache.shape

        D_latent = 512
        D_rope = 64
        assert D_q_nope == D_latent
        assert D_q_pe == D_rope
        assert D_ckv == D_latent + D_rope

        # Uses the module-level MAX_HEADS so the output padding always agrees
        # with the q_pad_num_heads value given to the base class.
        assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"

        assert len(page_table.shape) == 2
        B_block_table, block_num = page_table.shape
        assert B_block_table == B_q
        assert block_num > 0, f"block num must be greater than 0, got {block_num}"
        assert block_num % (128 / PAGE_SIZE) == 0

        assert q_nope.dtype in (torch.float16, torch.bfloat16, torch.float8_e4m3fn), (
            f"q_nope.dtype needs to be fp16 or bf16 or e4m3 but got {q_nope.dtype}."
        )
        assert q_nope.dtype == q_pe.dtype == kv_c_and_k_pe_cache.dtype
        assert seq_lens.dtype == torch.int32, (
            f"seq_lens.dtype needs to be int32 but got {seq_lens.dtype}."
        )
        assert page_table.dtype == torch.int32, (
            f"page_table.dtype needs to be int32 but got {page_table.dtype}."
        )

        # With a quantized KV cache the output is produced in bfloat16;
        # otherwise it matches the query dtype.
        dtype = (
            torch.bfloat16
            if is_quantized_kv_cache(self.kv_cache_dtype)
            else q_nope.dtype
        )
        out = q_nope.new_empty((B_q, MAX_HEADS, D_latent), dtype=dtype)
        lse = (
            torch.empty((B_q, MAX_HEADS), dtype=torch.float32, device=q_nope.device)
            if self.need_to_return_lse_for_decode
            else torch.Tensor()
        )

        ops.sm100_cutlass_mla_decode(
            out,
            lse,
            q_nope,
            q_pe,
            kv_c_and_k_pe_cache,
            seq_lens,
            page_table,
            workspace,
            sm_scale,
            num_kv_splits,
        )

        if H < MAX_HEADS:
            # Extract the subsets of the outputs
            lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
            out = out[:, :H]

        return out, lse

    def _forward_decode(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run decode attention for the batch described by ``attn_metadata``.

        Args:
            q: Either a ``(q_nope, q_pe)`` tuple or a single tensor that is
                split along the last dimension into
                ``[kv_lora_rank, qk_rope_head_dim]``.
            kv_c_and_k_pe_cache: Non-empty KV cache tensor.
            attn_metadata: Metadata whose ``decode`` field must be set.
            layer: Attention layer (unused here).

        Returns:
            ``(output, lse)``; ``lse`` is ``None`` unless LSE was requested.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        # isinstance (rather than an exact type check) also accepts tuple
        # subclasses and matches the check used by FlashInferMLAImpl.
        if isinstance(q, tuple):
            q_nope, q_pe = q
        else:
            q_nope, q_pe = torch.split(
                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )

        # Adjust workspace size (if necessary)
        self._workspace.ensure_size(attn_metadata, self._num_kv_splits)

        # Run MLA
        o, lse = self._sm100_cutlass_mla_decode(
            q_nope,
            q_pe,
            kv_c_and_k_pe_cache,
            attn_metadata.decode.seq_lens,
            attn_metadata.decode.block_table,
            self._workspace.get_buf(),
            self.scale,
            self._num_kv_splits,
        )

        return o, (lse if self.need_to_return_lse_for_decode else None)
|
||||
337
v1/attention/backends/mla/flashattn_mla.py
Normal file
337
v1/attention/backends/mla/flashattn_mla.py
Normal file
@@ -0,0 +1,337 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.attention.utils.fa_utils import (
|
||||
flash_attn_supports_mla,
|
||||
get_flash_attn_version,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashAttnMLABackend(MLACommonBackend):
    """Backend descriptor for the FlashAttention MLA kernels.

    Declares the supported dtypes, kernel block sizes and KV-cache dtypes, and
    names the implementation and metadata-builder classes.
    """

    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "FLASH_ATTN_MLA"

    @staticmethod
    def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
        """Return the metadata builder class for this backend."""
        return FlashAttnMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashAttnMLAImpl"]:
        """Return the attention implementation class for this backend."""
        return FlashAttnMLAImpl

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return True only for devices whose compute capability major is 9."""
        required_major = 9
        return capability.major == required_major

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: CacheDType | None,
        block_size: int,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: DeviceCapability,
    ) -> str | None:
        """Return a reason string if the combination is unsupported, else None.

        The arguments are not inspected; the only check performed is whether
        ``flash_attn_supports_mla()`` reports support.
        """
        if flash_attn_supports_mla():
            return None
        return "FlashAttention MLA not supported on this device"
|
||||
|
||||
|
||||
@dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode metadata extended with the values FlashAttention MLA consumes."""

    # Passed to the kernel as ``cu_seqlens_q``.
    query_start_loc: torch.Tensor
    # Largest query length in the decode batch.
    max_query_len: int
    # Largest sequence length in the decode batch (``max_seqlen_k``).
    max_seq_len: int
    # Scheduler metadata for the kernel, or None when none was produced.
    scheduler_metadata: torch.Tensor | None = None
    # Passed to the kernel as ``num_splits``.
    max_num_splits: int = 0
|
||||
|
||||
|
||||
@dataclass
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
    """``MLACommonMetadata`` specialised to ``FlashAttnMLADecodeMetadata``."""
|
||||
|
||||
|
||||
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
    """Builds decode metadata for the FlashAttention MLA backend.

    When the installed FlashAttention version is 3, scheduler metadata is
    computed ahead of time. With full CUDA graphs it is copied into a
    persistent buffer owned by the builder.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
    # Small prefills are processed with the decode pathway.
    reorder_batch_threshold: int = 512

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialise the builder and its CUDA-graph related state."""
        super().__init__(
            kv_cache_spec,
            layer_names,
            vllm_config,
            device,
            FlashAttnMLAMetadata,
            supports_dcp_with_varlen=True,
        )
        # 0 means no upper bound on the number of splits.
        self.max_num_splits = 0
        self.fa_aot_schedule = get_flash_attn_version() == 3

        compilation_config = self.compilation_config
        self.use_full_cuda_graph = (
            compilation_config.cudagraph_mode.has_full_cudagraphs()
        )
        self.max_cudagraph_size = compilation_config.max_cudagraph_capture_size

        if self.use_full_cuda_graph and self.fa_aot_schedule:
            # Persistent buffer the per-step scheduler metadata is copied into.
            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # With CUDA graphs the number of splits needs an upper bound so
            # that large enough intermediate buffers are pre-allocated during
            # capture.
            self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

        # Batch-invariant mode pins the split count to one.
        if vllm_is_batch_invariant():
            self.max_num_splits = 1

    def _schedule_decode(
        self,
        num_reqs,
        cu_query_lens,
        max_query_len,
        seqlens,
        max_seq_len,
        causal,
        max_num_splits,
    ):
        """Return ahead-of-time scheduler metadata, or None when disabled."""
        if not self.fa_aot_schedule:
            return None

        return get_scheduler_metadata(
            batch_size=num_reqs,
            max_seqlen_q=max_query_len,
            max_seqlen_k=max_seq_len,
            num_heads_q=self.num_heads * self.dcp_world_size,
            num_heads_kv=1,
            headdim=self.mla_dims.qk_rope_head_dim,
            cache_seqlens=seqlens,
            qkv_dtype=self.kv_cache_spec.dtype,
            headdim_v=self.mla_dims.kv_lora_rank,
            page_size=self.page_size,
            cu_seqlens_q=cu_query_lens,
            causal=causal,
            num_splits=max_num_splits,
        )

    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        seq_lens_device: torch.Tensor,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
    ) -> FlashAttnMLADecodeMetadata:
        """Assemble the decode metadata for one step.

        Maximum query and sequence lengths are computed from the CPU tensors,
        while the device tensors are the ones stored in the returned metadata.
        """
        per_request_query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        max_query_len = per_request_query_lens.max().item()
        max_seq_len = seq_lens_cpu.max().item()

        # For Flash Attention MLA + full cudagraph.
        # NOTE(woosuk): Setting num_splits > 1 may increase the memory usage,
        # because the intermediate buffers of size [num_splits, num_heads,
        # num_tokens, head_size] are allocated. Therefore, num_splits is only
        # set when using cuda graphs.
        fits_cuda_graph = (
            self.use_full_cuda_graph and num_decode_tokens <= self.max_cudagraph_size
        )
        max_num_splits = self.max_num_splits if fits_cuda_graph else 0

        if vllm_is_batch_invariant():
            max_num_splits = 1

        scheduler_metadata = self._schedule_decode(
            num_reqs=seq_lens_cpu.numel(),
            cu_query_lens=query_start_loc_device,
            max_query_len=max_query_len,
            seqlens=seq_lens_device,
            max_seq_len=max_seq_len,
            causal=True,
            max_num_splits=max_num_splits,
        )

        if self.use_full_cuda_graph and scheduler_metadata is not None:
            used = scheduler_metadata.shape[0]
            # Ensure the persistent buffer is large enough
            assert used <= self.scheduler_metadata.shape[0], (
                f"Scheduler metadata size {used} exceeds buffer size "
                f"{self.scheduler_metadata.shape[0]}"
            )
            self.scheduler_metadata[:used] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[used:] = 0
            scheduler_metadata = self.scheduler_metadata[:used]

        return FlashAttnMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            query_start_loc=query_start_loc_device,
            max_query_len=max_query_len,
            max_seq_len=max_seq_len,
            scheduler_metadata=scheduler_metadata,
            max_num_splits=max_num_splits,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )
|
||||
|
||||
|
||||
class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
    """MLA attention implementation backed by ``flash_attn_varlen_func``.

    Only decoder attention with an unquantized KV cache is supported.
    """

    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        """Validate that the requested configuration is supported.

        Raises:
            AssertionError: If ``flash_attn_supports_mla()`` is false.
            NotImplementedError: If ``alibi_slopes``, ``sliding_window`` or
                ``logits_soft_cap`` is set, if ``attn_type`` is not
                ``AttentionType.DECODER``, or if the KV cache is quantized.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        assert flash_attn_supports_mla(), "FlashAttnMLA is not supported on this device"

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashAttnMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashAttnMLAImpl"
            )

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashAttnMLA V1 with FP8 KV cache not yet supported"
            )

    def _forward_decode(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashAttnMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run decode attention for the batch described by ``attn_metadata``.

        Args:
            q: Either a ``(q_nope, q_pe)`` tuple or a single tensor that is
                split along the last dimension into
                ``[kv_lora_rank, qk_rope_head_dim]``.
            kv_c_and_k_pe_cache: Non-empty cache; the first ``kv_lora_rank``
                entries of the last dimension are used as values and the
                remaining entries as keys.
            attn_metadata: Metadata whose ``decode`` field must be set.
            layer: Attention layer (unused here).

        Returns:
            ``(output, lse)``; ``lse`` is transposed to ``[B, H]`` when
            requested and ``None`` otherwise.

        Raises:
            NotImplementedError: If the KV cache dtype starts with ``"fp8"``.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        # isinstance (rather than an exact type check) also accepts tuple
        # subclasses and matches the check used by FlashInferMLAImpl.
        if isinstance(q, tuple):
            q_nope, q_pe = q
        else:
            q_nope, q_pe = torch.split(
                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 FlashAttention MLA not yet supported")

        kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
        k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]

        # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
        # kernel uses this to calculate grid dimensions. Ensure it's at least 1
        # to prevent invalid grid configuration during graph capture.
        max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)

        attn_out = flash_attn_varlen_func(
            q=q_pe,
            k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
            q_v=q_nope,
            max_seqlen_q=max_seqlen_q,
            cu_seqlens_q=attn_metadata.decode.query_start_loc,
            max_seqlen_k=attn_metadata.decode.max_seq_len,
            seqused_k=attn_metadata.decode.seq_lens,
            block_table=attn_metadata.decode.block_table,
            softmax_scale=self.scale,
            causal=True,
            return_softmax_lse=self.need_to_return_lse_for_decode,
            fa_version=3,  # only version 3 is supported
            scheduler_metadata=attn_metadata.decode.scheduler_metadata,
            num_splits=attn_metadata.decode.max_num_splits,
            cp_world_size=self.dcp_world_size,
            cp_rank=self.dcp_rank,
            cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
        )

        if self.need_to_return_lse_for_decode:
            o, lse = attn_out
            # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
            return o, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]

        return attn_out, None
|
||||
171
v1/attention/backends/mla/flashinfer_mla.py
Normal file
171
v1/attention/backends/mla/flashinfer_mla.py
Normal file
@@ -0,0 +1,171 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport, KVCacheLayoutType
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE = 128 * 1024 * 1024
|
||||
|
||||
|
||||
class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
    """Metadata builder for the FlashInfer MLA backend.

    All build logic is inherited from ``MLACommonMetadataBuilder``; this
    subclass only overrides the declared CUDA graph and query-length support.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
|
||||
|
||||
|
||||
class FlashInferMLABackend(MLACommonBackend):
    """Backend descriptor for the FlashInfer MLA decode kernel.

    Declares the supported dtypes, kernel block sizes and KV-cache dtypes,
    names the implementation and builder classes, and requires the ``"HND"``
    KV-cache layout.
    """

    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8", "fp8_e4m3"]

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "FLASHINFER_MLA"

    @staticmethod
    def get_impl_cls() -> type["FlashInferMLAImpl"]:
        """Return the attention implementation class for this backend."""
        return FlashInferMLAImpl

    @staticmethod
    def get_builder_cls() -> type["FlashInferMLAMetadataBuilder"]:
        """Return the metadata builder class for this backend."""
        return FlashInferMLAMetadataBuilder

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return True only for devices whose compute capability major is 10."""
        required_major = 10
        return capability.major == required_major

    @classmethod
    def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None":
        """Return the KV-cache layout this backend requires."""
        return "HND"
|
||||
|
||||
|
||||
# Zero-initialised byte buffer allocated on the CUDA device at import time;
# every FlashInferMLAImpl instance uses it as its workspace buffer.
g_fi_workspace = torch.zeros(
    FLASHINFER_MLA_WORKSPACE_BUFFER_SIZE, device="cuda", dtype=torch.uint8
)
|
||||
|
||||
|
||||
class FlashInferMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA attention implementation backed by FlashInfer's TRT-LLM decode op.

    All instances share the module-level ``g_fi_workspace`` buffer. The two
    matmul scales are read from the layer on the first decode call and cached.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        """Validate that the requested configuration is supported.

        Raises:
            NotImplementedError: If ``alibi_slopes``, ``sliding_window`` or
                ``logits_soft_cap`` is set, or if ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashInferMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashInferMLAImpl"
            )

        self._workspace_buffer = g_fi_workspace
        # Filled in lazily by the first _forward_decode call.
        self.bmm1_scale: float | None = None
        self.bmm2_scale: float | None = None

    def _forward_decode(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run decode attention for the batch described by ``attn_metadata``.

        Args:
            q: Either a ``(q_nope, q_pe)`` tuple, which is concatenated along
                the last dimension, or an already-combined query tensor.
            kv_c_and_k_pe_cache: Non-empty KV cache tensor.
            attn_metadata: Metadata whose ``decode`` field must be set.
            layer: Attention layer supplying the q/k/v scale floats.

        Returns:
            ``(output, None)``; the output is flattened to three dimensions and
            LSE is not returned by this backend.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if isinstance(q, tuple):
            q_nope, q_pe = q
            q = torch.cat([q_nope, q_pe], dim=-1)

        # trtllm API requires extra dimension q_len_per_request for MTP
        if attn_metadata.num_decode_tokens % attn_metadata.num_decodes != 0:
            # Adjacent literals keep the log message on a single line; a
            # triple-quoted literal would embed newlines and source
            # indentation in the emitted text.
            logger.warning_once(
                "FlashInferMLAImpl got a query of uneven length. "
                "This usually indicates an issue in batch reordering "
                "or incorrect setup in dummy_run."
            )
            q = q.unsqueeze(1)
        else:
            q = q.view(attn_metadata.num_decodes, -1, q.shape[-2], q.shape[-1])

        # The scales are computed once and cached on the instance.
        if self.bmm1_scale is None:
            self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale
        if self.bmm2_scale is None:
            self.bmm2_scale = layer._v_scale_float

        o = trtllm_batch_decode_with_kv_cache_mla(
            query=q,
            kv_cache=kv_c_and_k_pe_cache.unsqueeze(1),
            workspace_buffer=self._workspace_buffer,
            qk_nope_head_dim=self.qk_nope_head_dim,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            block_tables=attn_metadata.decode.block_table,
            seq_lens=attn_metadata.decode.seq_lens,
            max_seq_len=attn_metadata.max_seq_len,
            bmm1_scale=self.bmm1_scale,
            bmm2_scale=self.bmm2_scale,
        )

        # Flatten the output for consistent shape
        o = o.view(-1, o.shape[-2], o.shape[-1])

        # TODO: Return LSE pending support from Flashinfer API:
        # https://github.com/flashinfer-ai/flashinfer/pull/1566
        return o, None
|
||||
314
v1/attention/backends/mla/flashmla.py
Normal file
314
v1/attention/backends/mla/flashmla.py
Normal file
@@ -0,0 +1,314 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionLayer, AttentionType, MultipleOf
|
||||
from vllm.attention.ops.flashmla import (
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
is_flashmla_dense_supported,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
reshape_attn_output_for_spec_decode,
|
||||
reshape_query_for_spec_decode,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class FlashMLABackend(MLACommonBackend):
    """Backend descriptor for the dense FlashMLA attention kernels.

    Declares the dtypes, kernel block size and KV-cache dtypes this backend
    accepts, and names the metadata builder / implementation classes that
    go with it.
    """

    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "fp8",
        "fp8_e4m3",
    ]

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHMLA"

    @staticmethod
    def get_builder_cls() -> type["FlashMLAMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLAImpl"]:
        """Return the attention implementation class of this backend."""
        return FlashMLAImpl

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return True when the device's major compute capability is 9 or 10."""
        return capability.major in [9, 10]

    @classmethod
    def supports_combination(
        cls,
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: CacheDType | None,
        block_size: int,
        use_mla: bool,
        has_sink: bool,
        use_sparse: bool,
        device_capability: DeviceCapability,
    ) -> str | None:
        """Report why FlashMLA cannot be used, based on the support probes.

        Only ``use_sparse`` is consulted; it selects between the sparse and
        the dense support probe. The second element of the probe's result is
        returned (the "reason" slot; presumably ``None`` when the kernels are
        available -- confirm against ``vllm.attention.ops.flashmla``).
        """
        # The probes are imported lazily, exactly where they are needed.
        if not use_sparse:
            from vllm.attention.ops.flashmla import is_flashmla_dense_supported

            return is_flashmla_dense_supported()[1]

        from vllm.attention.ops.flashmla import is_flashmla_sparse_supported

        return is_flashmla_sparse_supported()[1]
|
||||
|
||||
|
||||
@dataclass
class FlashMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode metadata extended with the FlashMLA scheduling tensors."""

    # First value returned by `get_mla_metadata` (possibly a view into a
    # persistent CUDA-graph buffer, see FlashMLAMetadataBuilder._build_decode).
    tile_scheduler_metadata: torch.Tensor
    # Second value returned by `get_mla_metadata` (same remark as above).
    num_splits: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
class FlashMLAMetadata(MLACommonMetadata[FlashMLADecodeMetadata]):
    """Common MLA metadata specialised to carry FlashMLADecodeMetadata.

    Adds no fields of its own.
    """
|
||||
|
||||
|
||||
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
    """Builds FlashMLA decode metadata, with persistent buffers for CUDA graphs.

    When full CUDA graphs are enabled, the tensors produced by
    `get_mla_metadata` are copied into buffers allocated once in `__init__`
    so that the metadata handed to the kernel always lives in the same
    storage.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.UNIFORM
    reorder_batch_threshold: int = 128  # process small prefills with decode pathway
    # ^ TODO(matt): tune this

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialise the common builder and, if needed, the CUDA-graph buffers."""
        super().__init__(
            kv_cache_spec, layer_names, vllm_config, device, FlashMLAMetadata
        )

        model_config = vllm_config.model_config
        self.num_q_heads = model_config.get_num_attention_heads(
            vllm_config.parallel_config
        )

        # Stay None unless full CUDA graphs are enabled (see below).
        self.cg_buf_tile_scheduler_metadata = None
        self.cg_buf_num_splits = None
        self.is_fp8_kvcache = vllm_config.cache_config.cache_dtype.startswith("fp8")

        num_sms = torch.cuda.get_device_properties(self.device).multi_processor_count

        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
            # TileSchedulerMetaDataSize = 8
            self.cg_buf_tile_scheduler_metadata = torch.zeros(
                (num_sms, 8),
                device=self.device,
                dtype=torch.int32,
            )
            # One entry more than the maximum number of sequences.
            self.cg_buf_num_splits = torch.empty(
                (vllm_config.scheduler_config.max_num_seqs + 1),
                device=self.device,
                dtype=torch.int32,
            )

    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        seq_lens_device: torch.Tensor,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
    ) -> FlashMLADecodeMetadata:
        """Compute the FlashMLA scheduling metadata for the decode requests.

        `seq_lens_cpu`, `query_start_loc_device` and `num_decode_tokens` are
        accepted but not read here.
        """
        per_request_query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        # The maximum is used, but all lengths should be identical because of
        # the uniform query length requirement.
        uniform_query_len = per_request_query_lens.max().item()
        # "// 1": a single KV head (MQA) on the decode path.
        q_tokens_per_kv_head = uniform_query_len * self.num_q_heads // 1
        sched_meta, splits = get_mla_metadata(
            seq_lens_device,
            q_tokens_per_kv_head,
            1,  # MQA for the decode path
            is_fp8_kvcache=self.is_fp8_kvcache,
        )

        # TODO: we can disambiguate between decode and mixed-prefill decode here
        # so we can only use the persistent buffer if a cudagraph is actually
        # being used.
        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            sched_buf = self.cg_buf_tile_scheduler_metadata
            splits_buf = self.cg_buf_num_splits
            assert sched_buf is not None
            assert splits_buf is not None

            # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
            num_parts = sched_meta.size(0)
            assert num_parts <= sched_buf.size(0)
            persistent_sched = sched_buf[:num_parts]
            persistent_sched.copy_(sched_meta)

            # Num splits is per-batch, varying size (batch_size,)
            num_entries = splits.size(0)
            # make sure static buffer is large enough
            assert num_entries <= splits_buf.size(0)
            persistent_splits = splits_buf[:num_entries]
            persistent_splits.copy_(splits)
            # Num splits needs to be monotonically increasing
            # (with: https://github.com/vllm-project/FlashMLA/pull/3, otherwise
            #  it needs to be monotonically increasing by 1), so the unused
            # tail is filled with the last computed value.
            splits_buf[num_entries:].fill_(splits[-1])

            sched_meta = persistent_sched
            splits = persistent_splits

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            tile_scheduler_metadata=sched_meta,
            num_splits=splits,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )
|
||||
|
||||
|
||||
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
    """MLA attention implementation whose decode path uses FlashMLA kernels."""

    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        """Forward all arguments to the common MLA impl and validate support.

        Raises:
            AssertionError: if the dense FlashMLA kernels are reported as
                unsupported (the probe's reason is used as the message).
            NotImplementedError: if any of alibi_slopes, sliding_window or
                logits_soft_cap is set, or if `attn_type` is not
                `AttentionType.DECODER`.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        is_supported, reason = is_flashmla_dense_supported()
        assert is_supported, reason

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashMLAImpl"
            )

    def _forward_decode(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run FlashMLA decode attention against the paged KV cache.

        Args:
            q: the query, either a single tensor or a tuple of tensors that
                is concatenated along the last dimension.
            kv_c_and_k_pe_cache: the KV cache; must be non-empty.
            attn_metadata: metadata whose `decode` part must be set.
            layer: supplies the `_q_scale` / `_k_scale` descale tensors.

        Returns:
            The attention output (reshaped by
            `reshape_attn_output_for_spec_decode`) and the LSE tensor
            returned by the kernel.
        """
        # TODO: (zyongye) decode function for mla here
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        # isinstance (rather than an exact type check) also accepts tuple
        # subclasses and matches the other MLA backends.
        if isinstance(q, tuple):
            q = torch.cat(q, dim=-1)

        # mypy assertion: q is now always a tensor
        assert isinstance(q, torch.Tensor)

        num_decodes = attn_metadata.num_decodes
        q = reshape_query_for_spec_decode(q, num_decodes)

        tile_scheduler_metadata = attn_metadata.decode.tile_scheduler_metadata
        num_splits = attn_metadata.decode.num_splits
        if vllm_is_batch_invariant():
            # Batch-invariant mode: replace the computed schedule with a
            # fixed single-partition one built here.
            device = q.device
            dtype = torch.int32

            B = q.shape[0]
            # block_table shape: [batch_size, max_num_blocks_per_seq]
            # The number of blocks per sequence is in the second dimension
            topk = attn_metadata.decode.block_table.shape[-1]
            B_TOPK = 64
            assert topk % B_TOPK == 0, f"topk ({topk}) must be divisible by {B_TOPK}"
            end_block_idx = topk // B_TOPK

            # Single partition => num_sm_parts = 1
            # TileSchedulerMetaDataSize = 8, layout:
            # [begin_idx, begin_block_idx, end_idx, end_block_idx,
            #  begin_n_split_idx, _, _, _]
            tile_scheduler_metadata = torch.zeros((1, 8), dtype=dtype, device=device)
            tile_scheduler_metadata[0, 0] = 0  # begin_idx
            tile_scheduler_metadata[0, 1] = 0  # sched_begin_block_idx
            tile_scheduler_metadata[0, 2] = B - 1  # end_idx
            tile_scheduler_metadata[0, 3] = end_block_idx
            tile_scheduler_metadata[0, 4] = 0  # begin_n_split_idx
            # fields [5..7] stay 0

            # Non-split path ignores num_splits, but the API requires it:
            # zeros of length B+1
            num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)

        o, lse = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=attn_metadata.decode.block_table,
            cache_seqlens=attn_metadata.decode.seq_lens,
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=tile_scheduler_metadata,
            num_splits=num_splits,
            softmax_scale=self.scale,
            causal=True,
            descale_q=layer._q_scale.reshape(1),
            descale_k=layer._k_scale.reshape(1),
        )

        o = reshape_attn_output_for_spec_decode(o)

        return o, lse
|
||||
560
v1/attention/backends/mla/flashmla_sparse.py
Normal file
560
v1/attention/backends/mla/flashmla_sparse.py
Normal file
@@ -0,0 +1,560 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, ClassVar, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionLayer,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.backends.utils import get_mla_dims
|
||||
from vllm.attention.ops.flashmla import (
|
||||
flash_mla_sparse_prefill,
|
||||
flash_mla_with_kvcache,
|
||||
get_mla_metadata,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.deepseek_v2 import Indexer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
"""
|
||||
NOTE: FlashMLA Sparse uses an fp8 cache with the following format
|
||||
|
||||
In the "FP8 with scale" format, each token's KV cache is 656 Bytes,
|
||||
structured as:
|
||||
- **First 512 bytes:** The "quantized NoPE" part, containing 512
|
||||
`float8_e4m3` values.
|
||||
- **Next 16 bytes:** Scale factors, containing 4 `float32` values.
|
||||
The first `float32` is the scale for the first 128 `float8_e4m3` values,
|
||||
the second for the next 128, and so on.
|
||||
- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This
|
||||
part is not quantized for accuracy.
|
||||
"""
|
||||
|
||||
|
||||
class FlashMLASparseBackend(AttentionBackend):
    """Backend descriptor for the sparse FlashMLA attention kernels."""

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "FLASHMLA_SPARSE"

    @staticmethod
    def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return FlashMLASparseMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashMLASparseImpl"]:
        """Return the attention implementation class of this backend."""
        return FlashMLASparseImpl

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the only supported head size, 576."""
        return [576]

    @classmethod
    def is_mla(cls) -> bool:
        """This backend is an MLA backend."""
        return True

    @classmethod
    def is_sparse(cls) -> bool:
        """This backend is a sparse-attention backend."""
        return True

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Return True when the device's major compute capability is 9 or 10."""
        return capability.major in [9, 10]

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,  # assumed to be 1 for MLA
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV-cache shape ``(num_blocks, block_size, last_dim)``.

        For the "fp8_ds_mla" cache dtype the last dimension is the fixed
        656-byte custom storage format (see the FlashMLA readme.md for
        details); otherwise it is `head_size`. `num_kv_heads` is not used.
        """
        last_dim = 656 if cache_dtype_str == "fp8_ds_mla" else head_size
        return (num_blocks, block_size, last_dim)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlashMLASparseMetadata:
|
||||
num_reqs: int
|
||||
max_query_len: int
|
||||
max_seq_len: int
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
query_start_loc: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
block_table: torch.Tensor
|
||||
req_id_per_token: torch.Tensor
|
||||
block_size: int = 64
|
||||
topk_tokens: int = 2048
|
||||
|
||||
@dataclass
|
||||
class FP8KernelMetadata:
|
||||
scheduler_metadata: torch.Tensor | None
|
||||
num_splits: torch.Tensor
|
||||
dummy_block_table: torch.Tensor
|
||||
cache_lens: torch.Tensor
|
||||
|
||||
fp8_extra_metadata: FP8KernelMetadata | None = None
|
||||
|
||||
|
||||
# Triton kernel: translates per-request token indices into global KV-cache
# slot indices through the block table. One program handles BLOCK_N
# consecutive columns of one token row. Negative token indices and indices
# whose block lies outside the block table are written out as -1.
@triton.jit
def _convert_req_index_to_global_index_kernel(
    req_id_ptr,  # int32 [num_tokens]
    block_table_ptr,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    out_ptr,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    # shapes (compile-time where possible)
    max_num_blocks_per_req: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_N: tl.constexpr,  # tile width along columns
    # strides (in elements)
    bt_stride0,
    bt_stride1,
    ti_stride0,
    ti_stride1,
    out_stride0,
    out_stride1,
):
    # program_id(0) -> token_id (row)
    # program_id(1) -> tile index along columns
    token_id = tl.program_id(0)
    tile_id = tl.program_id(1)

    # Each program covers BLOCK_N consecutive columns
    indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)

    # Load request id for this token (no mask: grid is exact)
    req = tl.load(req_id_ptr + token_id)

    # Load token indices for this tile
    ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1
    tok = tl.load(ti_ptr)  # int32

    # Compute block id and in-block offset
    block_id = tok // BLOCK_SIZE
    inblock_off = tok % BLOCK_SIZE

    # An entry is usable only if the token index is non-negative (negative
    # values such as the -1 sentinel are invalid) and its block lies inside
    # the block table.
    valid = (tok >= 0) & (block_id < max_num_blocks_per_req)

    # Guard block_table access: invalid entries are masked out so that no
    # address derived from a negative or out-of-range index is ever read.
    bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1
    base = tl.load(bt_ptr, mask=valid, other=0)

    # Invalid entries produce -1; otherwise base * BLOCK_SIZE + offset
    out_val = tl.where(valid, base * BLOCK_SIZE + inblock_off, -1)

    # Store results
    out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1
    tl.store(out_ptr_ij, out_val)
|
||||
|
||||
|
||||
def triton_convert_req_index_to_global_index(
    req_id: torch.Tensor,  # int32 [num_tokens]
    block_table: torch.Tensor,  # int32 [num_requests, max_num_blocks_per_req]
    token_indices: torch.Tensor,  # int32 [num_tokens, NUM_TOPK_TOKENS]
    BLOCK_SIZE: int = 64,
    NUM_TOPK_TOKENS: int = 2048,
    BLOCK_N: int = 128,  # tile width along columns
):
    """
    Convert per-request token indices into global KV-cache slot indices.

    out[token_id, indice_id] =
        block_table[req_id[token_id],
            token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE
        + token_indices[token_id, indice_id] % BLOCK_SIZE

    Negative entries of token_indices (the -1 sentinel) produce -1.
    For safety, -1 is also produced if the derived block_id would be
    out-of-bounds.

    Returns:
        A new int32 tensor with the same shape as `token_indices`.

    Raises:
        AssertionError: if any input is not int32, if
            `token_indices.shape[1] != NUM_TOPK_TOKENS`, or if
            NUM_TOPK_TOKENS is not a multiple of BLOCK_N.
    """
    assert req_id.dtype == torch.int32
    assert block_table.dtype == torch.int32
    assert token_indices.dtype == torch.int32
    assert token_indices.shape[1] == NUM_TOPK_TOKENS
    assert NUM_TOPK_TOKENS % BLOCK_N == 0, (
        f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by "
        f"BLOCK_N ({BLOCK_N})"
    )

    num_tokens = req_id.shape[0]
    # The unpacking also enforces that block_table is two-dimensional.
    _, max_num_blocks_per_req = block_table.shape
    tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N

    # Ensure contiguous tensors
    req_id_c = req_id.contiguous()
    block_table_c = block_table.contiguous()
    token_indices_c = token_indices.contiguous()
    out = torch.empty_like(token_indices_c)

    # Strides in elements
    bt_stride0, bt_stride1 = block_table_c.stride()
    ti_stride0, ti_stride1 = token_indices_c.stride()
    out_stride0, out_stride1 = out.stride()

    # Exact 2D grid: tokens x column tiles
    grid = (num_tokens, tiles_per_row)

    _convert_req_index_to_global_index_kernel[grid](
        req_id_c,
        block_table_c,
        token_indices_c,
        out,
        # shapes / constexprs
        max_num_blocks_per_req,
        BLOCK_SIZE,
        BLOCK_N,
        # strides
        bt_stride0,
        bt_stride1,
        ti_stride0,
        ti_stride1,
        out_stride0,
        out_stride1,
    )
    return out
|
||||
|
||||
|
||||
@dataclass
class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetadata]):
    """Builds FlashMLASparseMetadata, reusing buffers allocated up front.

    All tensors handed to the kernels (request ids per token, tile scheduler
    metadata, num_splits) are written into buffers created once in
    `__init__`.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Read the relevant configuration and allocate the reusable buffers.

        `layer_names` is accepted but not used here.
        """
        cache_config = vllm_config.cache_config
        parallel_config = vllm_config.parallel_config
        self.kv_cache_spec = kv_cache_spec
        self.model_config = vllm_config.model_config
        self.device = device

        sm_count = torch.cuda.get_device_properties(device).multi_processor_count

        self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
        self.topk_tokens = vllm_config.model_config.hf_config.index_topk
        self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla"

        # Single-element tensors passed to the kernels as sequence lengths.
        self.topk_tokens_tensor = torch.tensor(
            [self.topk_tokens], device=device, dtype=torch.int32
        )
        self.max_model_len_tensor = torch.tensor(
            [self.model_config.max_model_len], device=device, dtype=torch.int32
        )
        # this is ignored by `flash_mla_with_kvcache` if indices not None
        self.dummy_block_table = torch.empty(
            (1, 1), dtype=torch.int32, device=self.device
        )

        # Equation taken from FlashMLA/csrc/pybind.cpp
        h_q, h_k = self.num_heads, 1
        s_q = 1  # inversely proportional to s_q, so s_q = 1 is the largest
        half_sms_per_kv_head = (sm_count // 2) / h_k
        q_head_tiles = cdiv(h_q // h_k, 2 * 64) * s_q
        max_num_sm_parts = int(max(half_sms_per_kv_head // q_head_tiles, 1))
        if current_platform.is_device_capability(100):
            max_num_sm_parts *= 2

        # TileSchedulerMetaDataSize = 8
        # see: FlashMLA/csrc/params.h
        self.tile_scheduler_metadata_buffer = torch.empty(
            (max_num_sm_parts, 8),
            dtype=torch.int32,
            device=device,
        )
        # We pack all the tokens into one batch for sparse attention.
        # Otherwise, we can exceed the sm of `get_mla_metadata`.
        self.num_splits_buffer = torch.empty(
            (2,),
            dtype=torch.int32,
            device=device,
        )
        self.req_id_per_token_buffer = torch.empty(
            (vllm_config.scheduler_config.max_num_batched_tokens,),
            dtype=torch.int32,
            device=device,
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> FlashMLASparseMetadata:
        """Assemble the metadata for one step from the common attention metadata.

        `common_prefix_len` and `fast_build` are accepted but not used here.
        """
        num_tokens = common_attn_metadata.num_actual_tokens

        # Expand the query start offsets into one request index per token.
        query_starts = np.asarray(
            common_attn_metadata.query_start_loc_cpu, dtype=np.int32
        )
        tokens_per_req = np.diff(query_starts)
        req_ids_np = np.repeat(
            np.arange(tokens_per_req.shape[0], dtype=np.int32), tokens_per_req
        )

        # Zero-fill for cudagraphs
        req_id_buffer = self.req_id_per_token_buffer
        req_id_buffer.fill_(0)
        req_id_buffer[: req_ids_np.shape[0]].copy_(
            torch.from_numpy(req_ids_np), non_blocking=True
        )
        req_id_per_token = req_id_buffer[:num_tokens]

        fp8_extra_metadata = None
        if self.use_fp8_kv_cache:
            computed_sched_meta, computed_splits = get_mla_metadata(
                cache_seqlens=self.topk_tokens_tensor,
                num_q_tokens_per_head_k=num_tokens * self.num_heads,
                topk=self.topk_tokens,
                num_heads_q=self.num_heads,
                num_heads_k=1,
                is_fp8_kvcache=True,
            )

            # Copy to persistent buffer for full-CG support
            num_sm_parts = computed_sched_meta.size(0)
            sched_meta_view = self.tile_scheduler_metadata_buffer[:num_sm_parts]
            sched_meta_view.copy_(computed_sched_meta)
            self.num_splits_buffer.copy_(computed_splits)

            # cache_lens and block_table are basically unused in sparse case
            # but the decode kernel will treat -1 and indices >= cache_lens
            # as invalid so we make sure cache_lens is large enough to not
            # accidentally mark indices invalid, we will use -1 exclusively
            # to mark invalid indices
            fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata(
                scheduler_metadata=sched_meta_view,
                num_splits=self.num_splits_buffer,
                cache_lens=self.max_model_len_tensor,
                dummy_block_table=self.dummy_block_table,
            )

        return FlashMLASparseMetadata(
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            block_table=common_attn_metadata.block_table_tensor,
            req_id_per_token=req_id_per_token,
            block_size=self.kv_cache_spec.block_size,
            topk_tokens=self.topk_tokens,
            fp8_extra_metadata=fp8_extra_metadata,
        )
|
||||
|
||||
|
||||
class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
    """Sparse MLA attention implementation built on the FlashMLA kernels.

    The top-k token indices are read from the indexer's
    `topk_indices_buffer`. `forward` reads `self.positions`, which in this
    class is only assigned by `forward_prepare`.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        topk_indice_buffer: torch.Tensor | None = None,
        indexer: Optional["Indexer"] = None,
        **mla_args,
    ) -> None:
        """Initialise the base impl and bind the indexer's top-k buffer.

        `topk_indice_buffer` is accepted but not used; the buffer is taken
        from `indexer`, which must not be None.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )
        self.softmax_scale = scale
        assert indexer is not None
        self.topk_indices_buffer = indexer.topk_indices_buffer
        # Head-count multiple used for padding in _forward_bf16_kv.
        self.padding = 128 if current_platform.is_device_capability(100) else 64

    def _forward_bf16_kv(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """Run the sparse prefill kernel on a non-fp8 KV cache.

        The query is padded along the head dimension when the kernel's
        head-count requirement is not met; the padding is sliced off the
        result again. `attn_metadata` is not read here.
        """
        num_tokens = q.shape[0]
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
            -1, 1, kv_c_and_k_pe_cache.shape[-1]
        )

        # NOTE(Chen): kernel requires num_local_head to be a multiple of
        # 64 on hopper and 128 on blackwell
        if self.num_heads % self.padding != 0:
            assert self.padding % self.num_heads == 0
            # Adjacent literals instead of a backslash continuation, so that
            # source indentation does not leak into the logged message.
            logger.warning_once(
                f"padding num_heads to {self.padding} "
                "due to sparse attn kernel requirement"
            )
            # The padded head rows are left uninitialised; their outputs are
            # discarded by the slice below.
            q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2]))
            q_padded[:, : self.num_heads, :] = q
            q = q_padded

        topk_indices = topk_indices.view(num_tokens, 1, -1)
        output = flash_mla_sparse_prefill(
            q, kv_c_and_k_pe_cache, topk_indices, self.softmax_scale
        )
        output = output[:, : self.num_heads, :]
        return output

    def _forward_fp8_kv(
        self,
        q: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        topk_indices: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
    ) -> torch.Tensor:
        """Run the FlashMLA kv-cache kernel on the fp8 ("fp8_ds_mla") cache.

        Requires `attn_metadata.fp8_extra_metadata` to be set. All tokens
        are passed as a single batch (a batch dimension of 1 is added to
        `q` and to the indices).
        """
        assert attn_metadata.fp8_extra_metadata is not None
        extra_metadata = attn_metadata.fp8_extra_metadata

        _attn_out, _ = flash_mla_with_kvcache(
            q=q.unsqueeze(0),  # unsqueeze to add batch_dim
            k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2),
            block_table=extra_metadata.dummy_block_table,
            head_dim_v=512,
            cache_seqlens=extra_metadata.cache_lens,
            tile_scheduler_metadata=extra_metadata.scheduler_metadata,
            num_splits=extra_metadata.num_splits,
            is_fp8_kvcache=True,
            indices=topk_indices.unsqueeze(0),  # unsqueeze to add batch_dim
            softmax_scale=self.softmax_scale,
        )

        return _attn_out

    def forward_prepare(
        self,
        positions: torch.Tensor,
    ) -> None:
        """Store the token positions that `forward` passes to the rotary embedding."""
        self.positions = positions

    def forward(
        self,
        layer: AttentionLayer,
        q: torch.Tensor,
        k_c_normed: torch.Tensor,  # key in unified attn
        k_pe: torch.Tensor,  # value in unified attn
        kv_cache: torch.Tensor,
        attn_metadata: FlashMLASparseMetadata,
        output: torch.Tensor | None = None,
        kv_cache_scale: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply rotary embedding, update the KV cache and run sparse attention.

        Note that the result is returned in a newly allocated tensor with
        `output.shape[0]` rows and `v_head_dim * num_heads` columns; the
        `output` argument is only used for its first dimension.
        `kv_cache_scale` is not read.

        Raises:
            AssertionError: if `output` is None.
            NotImplementedError: if `output_scale` or `output_block_scale`
                is given.
        """
        # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
        # MQA 576/512 approach for both prefill and decode

        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported for MLACommonImpl"
            )

        if attn_metadata is None:
            # The zero fill is required when used with DP + EP
            # to ensure all ranks within a DP group compute the
            # same expert outputs. torch.zeros (not torch.empty) is
            # therefore used so the returned tensor is deterministic.
            output = torch.zeros(
                output.shape[0],
                self.v_head_dim * self.num_heads,
                device=q.device,
                dtype=q.dtype,
            )
            return output

        num_actual_toks = attn_metadata.num_actual_tokens

        # Inputs and outputs may be padded for CUDA graphs
        k_pe = k_pe.unsqueeze(1)
        q = q[:num_actual_toks, ...]
        k_c_normed = k_c_normed[:num_actual_toks, ...]
        k_pe = k_pe[:num_actual_toks, ...]

        q_nope, q_pe = q.split(
            [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
        )
        # self.positions is provided by forward_prepare().
        q_pe, k_pe = self.rotary_emb(self.positions[:num_actual_toks], q_pe, k_pe)

        q_nope = self._k_up_proj(q_nope)
        q_nope = q_nope.view(-1, self.num_heads, self.kv_lora_rank)

        topk_indices = self.topk_indices_buffer[:num_actual_toks]

        # TODO: handle index / kv_cache correctly
        topk_indices_global = triton_convert_req_index_to_global_index(
            attn_metadata.req_id_per_token,
            attn_metadata.block_table,
            topk_indices,
            BLOCK_SIZE=attn_metadata.block_size,
            NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
        )

        q = torch.cat([q_nope, q_pe], dim=-1)

        # write the latent and rope to kv cache
        if kv_cache.numel() > 0:
            ops.concat_and_cache_mla(
                k_c_normed,
                k_pe,
                kv_cache,
                attn_metadata.slot_mapping.flatten(),
                kv_cache_dtype=self.kv_cache_dtype,
                scale=layer._k_scale,
            )

        if self.kv_cache_dtype != "fp8_ds_mla":
            attn_out = self._forward_bf16_kv(
                q, kv_cache, topk_indices_global, attn_metadata
            )
        else:
            attn_out = self._forward_fp8_kv(
                q, kv_cache, topk_indices_global, attn_metadata
            )

        # Rows past num_actual_toks (CUDA-graph padding) are not written.
        output = torch.empty(
            output.shape[0],
            self.num_heads,
            self.v_head_dim,
            device=q.device,
            dtype=q.dtype,
        )

        output[:num_actual_toks] = self._v_up_proj(attn_out)
        return output.view(output.shape[0], self.v_head_dim * self.num_heads)
|
||||
362
v1/attention/backends/mla/indexer.py
Normal file
362
v1/attention/backends/mla/indexer.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV32IndexerBackend(AttentionBackend):
    """Attention-backend descriptor for the DeepSeek V3.2 indexer cache.

    Only describes the cache layout and points at the metadata builder; no
    implementation class is declared in this block.
    """

    # Kernel block sizes this backend declares support for.
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend declares support for."""
        return [32, 64, 128]

    @staticmethod
    def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]:
        """Return the metadata builder class paired with this backend."""
        return DeepseekV32IndexerMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the cache shape ``(num_blocks, block_size, head_size)``.

        The cache has no head dimension, so exactly one KV head is required.
        ``cache_dtype_str`` is accepted but not used here.
        """
        assert num_kv_heads == 1
        return (num_blocks, block_size, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        """Return the identity stride order for the 3-D cache shape."""
        return (0, 1, 2)
|
||||
|
||||
|
||||
@dataclass
class DeepseekV32IndexerPrefillChunkMetadata:
    """Indexer metadata for one chunk of consecutive prefill requests.

    Instances are created by
    ``DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk``.
    """

    # Rows of the batch block table that belong to this chunk's requests.
    block_table: torch.Tensor
    # Per-query-token start offsets into the chunk's concatenated KV
    # (first result of ``kv_spans_from_batches``).
    cu_seqlen_ks: torch.Tensor
    # Per-query-token exclusive end offsets into the chunk's concatenated KV
    # (second result of ``kv_spans_from_batches``).
    cu_seqlen_ke: torch.Tensor
    # Cumulative sequence lengths of the chunk's requests with a leading 0.
    cu_seq_lens: torch.Tensor
    # Sum of the sequence lengths of the chunk's requests.
    total_seq_lens: int
    # Batch-level query_start_loc value at the chunk's first request.
    token_start: int
    # Batch-level query_start_loc value one past the chunk's last request.
    token_end: int
    # Number of requests in the chunk.
    num_reqs: int
|
||||
|
||||
|
||||
@dataclass
class DeepseekV32IndexerPrefillMetadata:
    """Prefill-side indexer metadata: the list of prefill chunks."""

    # One entry per chunk, in the order produced by ``split_prefill_chunks``.
    chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
|
||||
|
||||
|
||||
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
    """Decode-side indexer metadata covering the decode requests of a batch."""

    # Block-table rows of the decode requests.
    block_table: torch.Tensor
    # Sequence lengths of the decode requests.
    seq_lens: torch.Tensor
    # Number of query tokens per decode request (diff of query_start_loc).
    decode_lens: torch.Tensor
    # True when the decode requests do not all have the same query length.
    requires_padding: bool
    # schedule_metadata: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
class DeepseekV32IndexerMetadata:
    """Per-batch metadata for the indexer.

    The batch-level fields are copied from the common attention metadata by
    ``DeepseekV32IndexerMetadataBuilder.build``; ``decode`` and ``prefill``
    are populated only when the batch contains requests of that kind.
    """

    # FIXME (zyongye)
    # hacky way to access the data now, need to be in chunked meta
    seq_lens: torch.Tensor

    num_reqs: int
    max_query_len: int
    max_seq_len: int

    num_actual_tokens: int  # Number of tokens excluding padding.
    query_start_loc: torch.Tensor
    slot_mapping: torch.Tensor
    # The dimension of the attention heads
    head_dim: int

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int

    # None when the batch has no decode requests.
    decode: DeepSeekV32IndexerDecodeMetadata | None = None
    # None when the batch has no prefill requests.
    prefill: DeepseekV32IndexerPrefillMetadata | None = None
|
||||
|
||||
|
||||
# TODO (zyongye) optimize this, this is now vibe coded
|
||||
def kv_spans_from_batches(
    start_seq_loc: torch.Tensor, seq_len_per_batch: torch.Tensor, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute, for every selected (query) token, the KV span it may attend to
    inside a KV buffer formed by concatenating the sequences of all batches.

    Args:
      start_seq_loc: 1D integer tensor [B+1], cumulative counts of
                     selected tokens per batch. Must start at 0.
            Example: [0, 2, 4, 7] ->
                     batch sizes (selected) [2, 2, 3], N=7 tokens total.
      seq_len_per_batch: 1D integer tensor [B],
                         full sequence length (KV length) of each batch.
            Example: [5, 9, 4].
      device: device on which the two returned tensors are placed.

    Returns:
      start_tensor: 1D int32 tensor [N], start offset in the
                    concatenated KV cache for each token's batch.
      end_location: 1D int32 tensor [N],
                    **exclusive** end = start + token's local position.
                    (So the attended KV slice is kv[start:end].)

      For the example above:
        start_tensor = [0, 0, 5, 5, 14, 14, 14]
        end_location = [4, 5, 13, 14, 16, 17, 18]

    Assumes each batch contributes its full `seq_len_per_batch[i]`
    keys to the KV cache, and the selected tokens within a batch
    are the **last** `counts[i]` positions of that sequence.
    """
    q = start_seq_loc.to(dtype=torch.long)
    L = seq_len_per_batch.to(dtype=torch.long)
    assert q.dim() == 1 and L.dim() == 1
    assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1"

    # Selected tokens per batch and in total.
    counts = q[1:] - q[:-1]  # [B]
    N = int(q[-1].item())  # total selected tokens (relies on q[0] == 0)

    if N == 0:
        # Same dtype as the non-empty path so callers always get int32.
        return (
            torch.empty(0, dtype=torch.int32, device=device),
            torch.empty(0, dtype=torch.int32, device=device),
        )

    # KV start offset of each batch in the concatenated KV cache.
    kv_starts_per_batch = torch.cumsum(L, dim=0) - L  # [B]

    # Every selected token inherits the KV start of its batch.
    start_tensor = torch.repeat_interleave(kv_starts_per_batch, counts)  # [N]

    # 0-based rank of each token inside its batch's selected block.
    # Built on the input's device so non-CPU inputs do not hit a device
    # mismatch.
    rank_in_batch = torch.arange(
        N, dtype=torch.long, device=q.device
    ) - torch.repeat_interleave(q[:-1], counts)  # [N]

    # The selected tokens are end-aligned: the first selected token of batch b
    # sits at 1-based local position L[b] - counts[b] + 1, so its exclusive
    # end is kv_start[b] + L[b] - counts[b] + 1; later tokens add their rank.
    first_end_per_batch = kv_starts_per_batch + L - counts + 1  # [B]
    end_location = torch.repeat_interleave(first_end_per_batch, counts) + rank_in_batch

    return start_tensor.int().to(device), end_location.int().to(device)
|
||||
|
||||
|
||||
def get_max_prefill_buffer_size(vllm_config: VllmConfig):
    """Return the prefill buffer capacity derived from the model config.

    The capacity is the configured ``max_model_len`` scaled by a fixed factor.
    """
    # NOTE(Chen): 2 is a magic number for controlling the prefill buffer size.
    # May be tuned later.
    buffer_factor = 2
    model_config = vllm_config.model_config
    return model_config.max_model_len * buffer_factor
|
||||
|
||||
|
||||
def split_prefill_chunks(
    seq_lens_cpu: torch.Tensor, max_prefill_buffer_size: int, reqs_start: int
) -> list[tuple[int, int]]:
    """
    Greedily group consecutive prefill requests into chunks whose summed
    sequence length does not exceed the maximum prefill buffer size.

    Args:
        seq_lens_cpu: The sequence lengths of all requests; only entries
            from index ``reqs_start`` onwards are considered.
        max_prefill_buffer_size: The maximum prefill buffer size. Every
            individual sequence length must fit within it.
        reqs_start: The start index of the prefill requests.

    Returns:
        A list of half-open request ranges ``(reqs_start, reqs_end)``.
    """
    num_reqs = len(seq_lens_cpu)
    chunks = []
    chunk_begin = reqs_start
    chunk_len = 0
    for req_idx in range(reqs_start, num_reqs):
        req_len = seq_lens_cpu[req_idx].item()
        assert req_len <= max_prefill_buffer_size
        if chunk_len + req_len > max_prefill_buffer_size:
            # Adding this request would overflow: close the running chunk
            # and start a new one with this request.
            chunks.append((chunk_begin, req_idx))
            chunk_begin, chunk_len = req_idx, req_len
        else:
            chunk_len += req_len
    # Flush the trailing chunk, if it holds any tokens.
    if chunk_len > 0:
        chunks.append((chunk_begin, num_reqs))
    return chunks
|
||||
|
||||
|
||||
class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
    """Builds ``DeepseekV32IndexerMetadata`` from the common attention metadata.

    The batch is split into decode and prefill requests; prefill requests are
    further grouped into chunks bounded by ``max_prefill_buffer_size``.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    )

    reorder_batch_threshold: int = 1

    def __init__(self, *args, **kwargs):
        """Set up sizing constants and the persistent device buffers."""
        super().__init__(*args, **kwargs)
        scheduler_config = self.vllm_config.scheduler_config
        # NOTE(Chen): an estimated max size of flattened_kv. Need to double check.
        self.max_prefill_buffer_size = get_max_prefill_buffer_size(self.vllm_config)
        self.num_speculative_tokens = (
            self.vllm_config.speculative_config.num_speculative_tokens
            if self.vllm_config.speculative_config
            else 0
        )
        # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2
        self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)

        props = torch.cuda.get_device_properties(self.device)
        self.num_sms = props.multi_processor_count

        # Persistent buffer receiving the per-request decode query lengths.
        self.decode_lens_buffer = torch.empty(
            (scheduler_config.max_num_seqs,), dtype=torch.int32, device=self.device
        )

        # See: DeepGEMM/csrc/apis/attention.hpp
        self.scheduler_metadata_buffer = torch.empty(
            (self.num_sms + 1, 2), dtype=torch.int32, device=self.device
        )

    def build_one_prefill_chunk(
        self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
    ):
        """Build the metadata of the prefill chunk ``[reqs_start, reqs_end)``.

        Args:
            reqs_start: Index of the first request of the chunk.
            reqs_end: Index one past the last request of the chunk.
            query_start_loc_cpu: Batch-level cumulative query lengths (CPU).
            seq_lens_cpu: Batch-level sequence lengths (CPU).
            block_table: Batch-level block table.

        Returns:
            The ``DeepseekV32IndexerPrefillChunkMetadata`` of the chunk.
        """
        chunk_seq_lens = seq_lens_cpu[reqs_start:reqs_end]
        # Re-base the cumulative query lengths so the chunk starts at 0.
        prefill_query_start_loc = (
            query_start_loc_cpu[reqs_start : reqs_end + 1]
            - query_start_loc_cpu[reqs_start]
        )
        cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
            prefill_query_start_loc, chunk_seq_lens, self.device
        )
        token_start = query_start_loc_cpu[reqs_start].item()
        token_end = query_start_loc_cpu[reqs_end].item()
        # Convert to a Python int: the metadata field is declared `int`, and
        # keeping a 0-dim tensor here would leak a tensor into size arguments.
        total_seq_lens = int(chunk_seq_lens.sum().item())
        assert total_seq_lens <= self.max_prefill_buffer_size
        cu_seq_lens = (
            torch.cat(
                [
                    torch.zeros(1, dtype=torch.int32),
                    chunk_seq_lens.cumsum(dim=0),
                ]
            )
            .to(torch.int32)
            .to(self.device)
        )
        return DeepseekV32IndexerPrefillChunkMetadata(
            cu_seqlen_ks=cu_seqlen_ks,
            cu_seqlen_ke=cu_seqlen_ke,
            cu_seq_lens=cu_seq_lens,
            total_seq_lens=total_seq_lens,
            block_table=block_table[reqs_start:reqs_end],
            token_start=token_start,
            token_end=token_end,
            num_reqs=reqs_end - reqs_start,
        )

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> DeepseekV32IndexerMetadata:
        """Build the indexer metadata for one batch.

        Args:
            common_prefix_len: Unused by this builder.
            common_attn_metadata: Batch-level attention metadata; decode
                requests are expected to precede prefill requests.
            fast_build: Unused by this builder.

        Returns:
            The populated ``DeepseekV32IndexerMetadata``.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_tokens = common_attn_metadata.num_actual_tokens

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata, decode_threshold=self.reorder_batch_threshold
            )
        )

        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_tokens

        prefill_metadata = None
        if num_prefills > 0:
            # Prefill requests start right after the decode requests.
            chunk_seq_ids = split_prefill_chunks(
                common_attn_metadata.seq_lens_cpu,
                self.max_prefill_buffer_size,
                num_decodes,
            )
            chunks = [
                self.build_one_prefill_chunk(
                    reqs_start,
                    reqs_end,
                    query_start_loc_cpu,
                    common_attn_metadata.seq_lens_cpu,
                    common_attn_metadata.block_table_tensor,
                )
                for reqs_start, reqs_end in chunk_seq_ids
            ]
            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
                chunks=chunks,
            )

        decode_metadata = None
        if num_decodes > 0:
            # Write the per-request query lengths into the persistent buffer.
            torch.diff(
                common_attn_metadata.query_start_loc[: num_decodes + 1],
                out=self.decode_lens_buffer[:num_decodes],
            )
            decode_lens = self.decode_lens_buffer[:num_decodes]
            decode_lens_cpu = torch.diff(
                common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
            )

            # Use CPU to avoid GPU sync; breaking async scheduling
            requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()

            seq_lens = common_attn_metadata.seq_lens[:num_decodes]

            # TODO: the paged-MQA schedule metadata
            # (get_paged_mqa_logits_metadata -> scheduler_metadata_buffer)
            # is currently not computed here.
            decode_metadata = DeepSeekV32IndexerDecodeMetadata(
                block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
                seq_lens=seq_lens,
                decode_lens=decode_lens,
                requires_padding=requires_padding,
            )

        attn_metadata = DeepseekV32IndexerMetadata(
            seq_lens=common_attn_metadata.seq_lens,
            num_reqs=common_attn_metadata.num_reqs,
            max_query_len=common_attn_metadata.max_query_len,
            max_seq_len=common_attn_metadata.max_seq_len,
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            query_start_loc=common_attn_metadata.query_start_loc,
            slot_mapping=common_attn_metadata.slot_mapping,
            # NOTE(review): hard-coded head dimension; confirm it matches the
            # indexer head size of the served model.
            head_dim=128,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            prefill=prefill_metadata,
            decode=decode_metadata,
        )

        return attn_metadata
|
||||
294
v1/attention/backends/mla/rocm_aiter_mla.py
Normal file
294
v1/attention/backends/mla/rocm_aiter_mla.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.abstract import AttentionLayer
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonDecodeMetadata,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
MLACommonMetadataBuilder,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
class AiterMLABackend(MLACommonBackend):
    """Backend descriptor for the ROCm AITER MLA attention implementation."""

    @staticmethod
    def get_name() -> str:
        """Return the backend's registry name."""
        return "ROCM_AITER_MLA"

    @staticmethod
    def get_impl_cls() -> type["AiterMLAImpl"]:
        """Return the attention implementation class of this backend."""
        return AiterMLAImpl

    @staticmethod
    def get_builder_cls() -> type["AiterMLAMetadataBuilder"]:
        """Return the metadata builder class of this backend."""
        return AiterMLAMetadataBuilder
|
||||
|
||||
|
||||
@dataclass
class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
    """Decode metadata extended with the paged-KV index tensors that
    ``AiterMLAImpl._forward_decode`` passes to the AITER decode kernel."""

    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: torch.Tensor | None = None
    # The page indices of the paged kv cache
    paged_kv_indices: torch.Tensor | None = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: torch.Tensor | None = None
    # The query indptr, shape : [num_decode + 1]
    qo_indptr: torch.Tensor | None = None
|
||||
|
||||
|
||||
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
    """Common MLA metadata specialised to ``AiterMLADecodeMetadata``."""

    pass
|
||||
|
||||
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
    """Metadata builder for the ROCm AITER MLA backend.

    For decode it expands the block table to token granularity (page size 1)
    and derives the paged-KV index tensors from it. When full CUDA graphs are
    enabled, results are copied into persistent buffers allocated once in
    ``__init__``.
    """

    # TODO(luka, lucas): audit this as part of:
    # https://github.com/vllm-project/vllm/issues/22945
    _cudagraph_support: ClassVar[AttentionCGSupport] = (
        AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    )

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialise the base builder and, if needed, persistent buffers."""
        super().__init__(
            kv_cache_spec, layer_names, vllm_config, device, AiterMLAMetadata
        )

        self.compilation_config = vllm_config.compilation_config
        max_num_pages_per_req = cdiv(
            vllm_config.model_config.max_model_len, self.kv_cache_spec.block_size
        )
        max_num_reqs = vllm_config.scheduler_config.max_num_seqs
        # Width of one remapped (page size 1) block-table row: one entry per
        # token slot of every page a request can own.
        max_num_slots_per_req = max_num_pages_per_req * self.kv_cache_spec.block_size

        # Preparing persistent buffers
        # TODO: we can disambiguate between decode and mixed-prefill decode here
        # so we can only use the persistent buffer if a cudagraph is actually
        # being used.
        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            self.block_table_remapping = torch.zeros(
                [max_num_reqs, max_num_slots_per_req],
                dtype=torch.int32,
                device=device,
            )
            self.paged_kv_indptr = torch.zeros(
                max_num_reqs + 1, dtype=torch.int32, device=device
            )
            # `_build_decode` stores one index per *token* (the block table is
            # remapped to page size 1 first), so this buffer must hold
            # max_num_reqs * max_num_slots_per_req entries. Sizing it in pages
            # would overflow whenever block_size > 1.
            self.paged_kv_indices = torch.zeros(
                max_num_reqs * max_num_slots_per_req,
                dtype=torch.int32,
                device=device,
            )
            self.paged_kv_last_page_len = torch.zeros(
                max_num_reqs, dtype=torch.int32, device=device
            )

            self.qo_indptr = torch.arange(
                0, max_num_reqs + 1, dtype=torch.int32, device=device
            )

    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        seq_lens_device: torch.Tensor,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
    ) -> AiterMLADecodeMetadata:
        """Build the decode metadata for the AITER kernel.

        Only ``block_table_tensor``, ``seq_lens_device`` and
        ``dcp_tot_seq_lens_device`` are read; the remaining parameters are
        part of the builder interface and unused here.

        Returns:
            ``AiterMLADecodeMetadata`` whose block table and paged-KV tensors
            are expressed at token granularity.
        """
        page_size = self.kv_cache_spec.block_size
        device = self.device
        num_reqs = seq_lens_device.size(0)
        bs, _ = block_table_tensor.shape
        # Expand every block id into the `page_size` token-slot ids it covers:
        # slot = block_id * page_size + offset, offset in [0, page_size).
        block_table_tensor = (
            block_table_tensor.unsqueeze(-1).expand(-1, -1, page_size) * page_size
        )
        block_table_tensor = (
            block_table_tensor
            + torch.arange(
                0,
                page_size,
                device=block_table_tensor.device,
                dtype=block_table_tensor.dtype,
            )[None, None, :]
        )
        block_table_tensor = block_table_tensor.view(bs, -1)

        # after remapping, we assume the block size already equals to 1

        max_blk_size_per_req = block_table_tensor.shape[-1]
        # Keep, per request, only the first seq_len slot ids.
        mask = torch.arange(
            block_table_tensor.size(1), dtype=block_table_tensor.dtype, device=device
        ).unsqueeze(0) < seq_lens_device.unsqueeze(1)
        paged_kv_indices = block_table_tensor[mask]

        # NOTE(review): this is still computed with the original page size even
        # though the indices above are per token; confirm what the kernel
        # expects after the remapping.
        paged_kv_last_page_len = seq_lens_device % page_size
        paged_kv_last_page_len = torch.where(
            paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len
        )

        paged_kv_indptr = torch.cat(
            [
                torch.zeros(1, dtype=seq_lens_device.dtype, device=device),
                seq_lens_device.cumsum(dim=0, dtype=torch.int32),
            ]
        )

        if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
            # Copy into the persistent buffers and hand out views of them;
            # the unused tails are overwritten with filler values.
            num_actual_pages = paged_kv_indices.size(0)
            self.block_table_remapping[:num_reqs, :max_blk_size_per_req].copy_(
                block_table_tensor, non_blocking=True
            )
            block_table_tensor = self.block_table_remapping[
                :num_reqs, :max_blk_size_per_req
            ]

            self.paged_kv_indices[:num_actual_pages].copy_(
                paged_kv_indices, non_blocking=True
            )
            self.paged_kv_indices[num_actual_pages:].fill_(-1)
            paged_kv_indices = self.paged_kv_indices[:num_actual_pages]

            self.paged_kv_indptr[: 1 + num_reqs].copy_(
                paged_kv_indptr, non_blocking=True
            )
            self.paged_kv_indptr[1 + num_reqs :].fill_(paged_kv_indptr[-1])
            paged_kv_indptr = self.paged_kv_indptr[: 1 + num_reqs]

            self.paged_kv_last_page_len[:num_reqs].copy_(
                paged_kv_last_page_len, non_blocking=True
            )
            self.paged_kv_last_page_len[num_reqs:].fill_(1)
            paged_kv_last_page_len = self.paged_kv_last_page_len[:num_reqs]

            qo_indptr = self.qo_indptr[: 1 + num_reqs]

        else:
            # One query token per request: indptr is simply 0..num_reqs.
            qo_indptr = torch.arange(
                0, num_reqs + 1, step=1, dtype=torch.int32, device=device
            )

        attn_metadata = AiterMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            paged_kv_indptr=paged_kv_indptr,
            paged_kv_indices=paged_kv_indices,
            paged_kv_last_page_len=paged_kv_last_page_len,
            qo_indptr=qo_indptr,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )

        return attn_metadata
|
||||
|
||||
|
||||
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
    """MLA attention implementation backed by the ROCm AITER kernels."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        """Validate the configuration and bind the AITER varlen attention op.

        Raises:
            NotImplementedError: if alibi slopes, a sliding window or a
                logits soft cap is requested.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )
        assert num_heads in (16, 128), (
            f"Aiter MLA only supports 16 or 128 number of heads.\n"
            f"Provided {num_heads} number of heads.\n"
            "Try adjusting tensor_parallel_size value."
        )
        if alibi_slopes or sliding_window or logits_soft_cap:
            raise NotImplementedError(
                "Aiter MLA does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        # Imported lazily so the module can be loaded without aiter installed.
        from aiter import flash_attn_varlen_func

        self.flash_attn_varlen_func = flash_attn_varlen_func

    def _flash_attn_varlen_diff_headdims(
        self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
    ):
        """Forward to AITER's varlen flash attention.

        ``return_softmax_lse`` is passed on under AITER's ``return_lse`` name.
        """
        return self.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            softmax_scale=softmax_scale,
            return_lse=return_softmax_lse,
            **kwargs,
        )

    def _forward_decode(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: AiterMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run the AITER MLA decode kernel.

        Args:
            q: Query tensor, or a tuple of query parts that is concatenated
                along the last dimension.
            kv_c_and_k_pe_cache: Non-empty KV cache tensor.
            attn_metadata: Metadata whose ``decode`` part must be set.
            layer: Unused here.

        Returns:
            ``(output, None)``; no LSE is returned.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        decode_meta = attn_metadata.decode
        assert decode_meta is not None

        if type(q) is tuple:
            q = torch.cat(q, dim=-1)

        assert isinstance(q, torch.Tensor)
        num_tokens = q.shape[0]
        out = torch.zeros(
            num_tokens,
            self.num_heads,
            self.kv_lora_rank,
            dtype=q.dtype,
            device=q.device,
        )

        # Insert a singleton dimension at index 2 for the kernel's KV buffer.
        kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)

        # max_seqlen_qo must be 1 except for MTP
        # TODO: Find the best value for MTP
        max_seqlen_qo = 1
        rocm_aiter_ops.mla_decode_fwd(
            q,
            kv_buffer,
            out,
            self.scale,
            decode_meta.qo_indptr,
            max_seqlen_qo,
            decode_meta.paged_kv_indptr,
            decode_meta.paged_kv_indices,
            decode_meta.paged_kv_last_page_len,
        )

        return out, None
|
||||
206
v1/attention/backends/mla/triton_mla.py
Normal file
206
v1/attention/backends/mla/triton_mla.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache,
|
||||
)
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.distributed.parallel_state import get_dcp_group
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.mla.common import (
|
||||
MLACommonBackend,
|
||||
MLACommonImpl,
|
||||
MLACommonMetadata,
|
||||
)
|
||||
import ixformer.inference.functions as ixf_ops
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TritonMLABackend(MLACommonBackend):
    """Backend descriptor for the ``TRITON_MLA`` attention backend."""

    # Model dtypes this backend declares support for.
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    # KV-cache dtype strings this backend declares support for.
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]

    @staticmethod
    def get_name() -> str:
        """Return the backend's registry name."""
        return "TRITON_MLA"

    @staticmethod
    def get_impl_cls() -> type["TritonMLAImpl"]:
        """Return the attention implementation class of this backend."""
        return TritonMLAImpl

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Accept every device capability."""
        return True
|
||||
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
    """MLA implementation whose decode path calls the ixformer paged-MLA ops."""

    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        """Initialise the common MLA state and reject unsupported settings.

        Raises:
            NotImplementedError: for alibi slopes, sliding window, logits
                soft cap, non-decoder attention, or a quantized KV cache.
        """
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        if alibi_slopes or sliding_window or logits_soft_cap:
            raise NotImplementedError(
                "TritonMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "TritonMLAImpl"
            )

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "TritonMLA V1 with FP8 KV cache not yet supported"
            )

    def _flash_attn_varlen_diff_headdims(
        self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
    ):
        """Delegate unchanged to the common MLA implementation."""
        return super()._flash_attn_varlen_diff_headdims(
            q,
            k,
            v,
            return_softmax_lse=return_softmax_lse,
            softmax_scale=softmax_scale,
            **kwargs,
        )

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
        k_c_normed: torch.Tensor | None,
        k_pe: torch.Tensor | None,
        kv_c_and_k_pe_cache_scale: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Run decode attention over the paged MLA cache.

        With decode context parallelism (``dcp_world_size > 1``) the raw
        kernel output and its softmax LSE are returned; otherwise the output
        goes through ``_v_up_proj`` and no LSE is returned.
        """
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Triton MLA not yet supported")

        decode_meta = attn_metadata.decode
        q_nope = self._k_up_proj(q_nope).view(
            -1, self.num_heads, self.kv_lora_rank
        )

        if self.dcp_world_size > 1:
            return self._decode_with_dcp(
                q_nope,
                q_pe,
                kv_c_and_k_pe_cache,
                kv_c_and_k_pe_cache_scale,
                decode_meta,
            )

        attn_out = self._decode_without_dcp(
            q_nope,
            q_pe,
            kv_c_and_k_pe_cache,
            kv_c_and_k_pe_cache_scale,
            decode_meta,
            k_c_normed,
            k_pe,
        )
        return self._v_up_proj(attn_out), None

    def _decode_with_dcp(
        self, q_nope, q_pe, kv_cache, kv_cache_scale, decode_meta
    ):
        """Decode with queries all-gathered across the DCP group; returns (out, lse)."""
        num_tokens = q_nope.shape[0]
        query = torch.cat([q_nope, q_pe], dim=-1)
        query = get_dcp_group().all_gather(query, dim=1)
        out_buf = torch.empty(
            num_tokens,
            query.shape[1],
            self.kv_lora_rank,
            dtype=q_nope.dtype,
            device=q_nope.device,
        )
        if envs.VLLM_USE_INT8_MLA:
            q_int8, q_scale = ops.quant_kv(query)
            attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla_int8(
                out_buf,
                q_int8,
                q_scale,
                kv_cache,
                kv_cache_scale,
                self.scale,
                decode_meta.block_table,
                decode_meta.seq_lens,
                decode_meta.max_decode_seq_len,
                return_softmax_lse=True,
            )
        else:
            attn_out, softmax_lse = ixf_ops.vllm_paged_attention_mla(
                output=out_buf,
                query=query,
                kv_cache=kv_cache,
                scale=self.scale,
                block_tables=decode_meta.block_table,
                context_lens=decode_meta.seq_lens,
                max_context_len=decode_meta.max_decode_seq_len,
                return_softmax_lse=True,
            )
        return attn_out, softmax_lse

    def _decode_without_dcp(
        self, q_nope, q_pe, kv_cache, kv_cache_scale, decode_meta, k_c_normed, k_pe
    ):
        """Decode on the local rank only; returns the filled output buffer."""
        out_buf = torch.empty(
            q_nope.shape[0],
            self.num_heads,
            self.kv_lora_rank,
            dtype=q_nope.dtype,
            device=q_nope.device,
        )

        if envs.VLLM_USE_INT8_MLA:
            query = torch.cat([q_nope, q_pe], dim=-1)
            q_int8, q_scale = ops.quant_kv(query)
            ixf_ops.vllm_paged_attention_mla_int8(
                out_buf,
                q_int8,
                q_scale,
                kv_cache,
                kv_cache_scale,
                self.scale,
                decode_meta.block_table,
                decode_meta.seq_lens,
                decode_meta.max_decode_seq_len,
                decode_meta.use_cuda_graph,
            )
        else:
            # fused q concat & cache write
            ixf_ops.vllm_paged_attention_mla_fused(
                output=out_buf,
                q_nope=q_nope,
                q_pe=q_pe.contiguous(),
                kv_cache=kv_cache,
                scale=self.scale,
                block_tables=decode_meta.block_table,
                context_lens=decode_meta.seq_lens,
                max_context_len=decode_meta.max_decode_seq_len,
                k_c_normed=k_c_normed,
                k_pe=k_pe,
                use_cuda_graph=decode_meta.use_cuda_graph,
            )
        return out_buf
|
||||
436
v1/attention/backends/pallas.py
Normal file
436
v1/attention/backends/pallas.py
Normal file
@@ -0,0 +1,436 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv, next_power_of_2
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# TPU requires the head size to be a multiple of 128.
|
||||
TPU_HEAD_SIZE_ALIGNMENT = 128
|
||||
|
||||
# Note: TPU can use fp8 as a storage dtype but doesn't support converting
|
||||
# from uint8 to fp32 directly. That's why its dtype mapping differs from the GPU one.
|
||||
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"half": torch.half,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float": torch.float,
|
||||
"fp8": torch.float8_e4m3fn,
|
||||
"fp8_e4m3": torch.float8_e4m3fn,
|
||||
"fp8_e5m2": torch.float8_e5m2,
|
||||
"int8": torch.int8,
|
||||
"uint8": torch.uint8,
|
||||
}
|
||||
|
||||
try:
|
||||
import tpu_inference # noqa: F401
|
||||
except ImportError:
|
||||
# Lazy import torch_xla
|
||||
import torch_xla.core.xla_builder as xb
|
||||
import torch_xla.experimental.custom_kernel # noqa: F401
|
||||
from torch.library import impl
|
||||
from torch_xla._internal.jax_workarounds import requires_jax
|
||||
from torch_xla.experimental.custom_kernel import XLA_LIB
|
||||
|
||||
@requires_jax
|
||||
def kv_cache_update_op_impl(
|
||||
kv: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_update_slices: torch.Tensor,
|
||||
page_size: int,
|
||||
num_slices_per_block: int,
|
||||
):
|
||||
from vllm.attention.ops.pallas_kv_cache_update import kv_cache_update
|
||||
|
||||
new_kv_cache = xb.call_jax(
|
||||
kv_cache_update,
|
||||
(kv, slot_mapping, kv_cache, num_kv_update_slices),
|
||||
{"page_size": page_size, "num_slices_per_block": num_slices_per_block},
|
||||
)
|
||||
return new_kv_cache
|
||||
|
||||
XLA_LIB.define(
|
||||
"kv_cache_update_op(Tensor kv, Tensor slot_mapping,"
|
||||
"Tensor kv_cache, Tensor num_kv_update_slices, int page_size,"
|
||||
"int num_slices_per_block)"
|
||||
"-> Tensor",
|
||||
)
|
||||
|
||||
@impl(XLA_LIB, "kv_cache_update_op", "XLA")
def kv_cache_update_op_xla(
    kv: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache: torch.Tensor,
    num_kv_update_slices: torch.Tensor,
    page_size: int,
    num_slices_per_block: int,
) -> torch.Tensor:
    """"XLA" dispatch-key implementation of ``kv_cache_update_op``.

    Forwards every argument, in order, to ``kv_cache_update_op_impl`` and
    returns its result.
    """
    return kv_cache_update_op_impl(
        kv,
        slot_mapping,
        kv_cache,
        num_kv_update_slices,
        page_size,
        num_slices_per_block,
    )
|
||||
|
||||
@impl(XLA_LIB, "kv_cache_update_op", "CompositeExplicitAutograd")
def kv_cache_update_op_non_xla(
    kv: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache: torch.Tensor,
    num_kv_update_slices: torch.Tensor,
    page_size: int,
    num_slices_per_block: int,
) -> torch.Tensor:
    """"CompositeExplicitAutograd" implementation of ``kv_cache_update_op``.

    Performs no update: ``kv_cache`` is returned as-is and every other
    argument is ignored.
    """
    return kv_cache
|
||||
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
    """Backend descriptor for the Pallas attention kernels on TPU."""

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier string."""
        return "PALLAS"

    @staticmethod
    def get_impl_cls() -> type["PallasAttentionBackendImpl"]:
        """Return the attention implementation class for this backend."""
        return PallasAttentionBackendImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache shape, with the head size padded for TPU.

        The head size is rounded up to a multiple of
        ``TPU_HEAD_SIZE_ALIGNMENT``; keys and values share the third axis,
        hence ``num_kv_heads * 2``.
        """
        aligned_chunks = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT)
        return (
            num_blocks,
            block_size,
            num_kv_heads * 2,
            aligned_chunks * TPU_HEAD_SIZE_ALIGNMENT,
        )

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
        """Unsupported on TPU; always raises ``RuntimeError``."""
        raise RuntimeError("swap_blocks is not used for the TPU backend.")

    # In recent TPU generations, up to v6e, the SMEM size is 1MB. The
    # block_tables within the PallasMetadata constitute almost the entire SMEM
    # requirement. Its size is max_num_seqs * num_page_per_seq * 4 (Int). Here
    # we simply make sure that the size is smaller than half of SMEM capacity.
    @staticmethod
    def get_min_page_size(vllm_config: VllmConfig) -> int:
        """Return the smallest power-of-two page size that fits the SMEM budget."""
        max_num_seqs = vllm_config.scheduler_config.max_num_seqs
        max_model_len = vllm_config.model_config.max_model_len
        # Half of the 1 MiB SMEM, split across sequences, 4 bytes per entry.
        pages_per_req_limit = 1024 * 1024 // 2 // max_num_seqs // 4
        smallest_page = cdiv(max_model_len, pages_per_req_limit)
        # Round up to a power of two.
        return 1 << (smallest_page - 1).bit_length()

    @staticmethod
    def get_max_num_seqs(model_len: int, page_size: int) -> int:
        """Return how many sequences fit in half of SMEM for this page size."""
        pages_per_req = cdiv(model_len, page_size)
        return 1024 * 1024 // 2 // pages_per_req // 4

    # TPU has limited SREGs (scalar registers); if page_size is too small, we
    # can spill SREGs easily, which leads to bad performance. The strategy we
    # apply here is trying to split max-model-len into 16 pages, which makes
    # the spill less likely. Meanwhile we make sure the page size is in
    # [16, 256].
    @staticmethod
    def get_page_size(vllm_config: VllmConfig) -> int:
        """Return the page size derived from ``max_model_len``, in [16, 256]."""
        max_model_len = vllm_config.model_config.max_model_len
        # TODO: This is a temporary fix for vmem OOM.
        # For long model length, we use 16 page-size to avoid too much
        # VMEM spill. A more robust solution should be implemented to
        # handle VREG spills.
        if max_model_len > 8192:
            return 16
        candidate = next_power_of_2(max_model_len) // 16
        return min(max(candidate, 16), 256)
|
||||
|
||||
|
||||
@dataclass
class PallasMetadata:
    """Per-step attention metadata read by ``PallasAttentionBackendImpl.forward``."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    # Used in the PallasAttentionBackendImpl
    # Forwarded to write_to_kv_cache (and on to kv_cache_update_op).
    slot_mapping: torch.Tensor
    # The next four tensors are passed, in this order after the KV cache, to
    # torch.ops.xla.ragged_paged_attention.
    block_tables: torch.Tensor
    context_lens: torch.Tensor
    query_start_loc: torch.Tensor
    num_seqs: torch.Tensor
    # Forwarded to kv_cache_update_op as `num_kv_update_slices`.
    num_kv_update_slices: torch.Tensor
    # Forwarded to kv_cache_update_op as `num_slices_per_block`.
    num_slices_per_kv_cache_update_block: int
|
||||
|
||||
|
||||
class PallasAttentionBackendImpl(AttentionImpl):
    """Attention implementation built on ``torch.ops.xla.ragged_paged_attention``.

    Only decoder attention without ALiBi slopes is supported.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        """Store the attention hyper-parameters and resolve the KV cache dtype.

        Args:
            kv_cache_dtype: ``"auto"`` keeps the cache unquantized; any other
                value is looked up (lower-cased, stripped) in
                ``TPU_STR_DTYPE_TO_TORCH_DTYPE``.
            kv_sharing_target_layer_name: when not ``None``, ``forward`` does
                not write new keys/values into the KV cache.

        Raises:
            NotImplementedError: if ``alibi_slopes`` is given or ``attn_type``
                is not ``AttentionType.DECODER``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "PallasAttentionBackendImpl"
            )

        self.kv_cache_quantized_dtype = None
        if kv_cache_dtype != "auto":
            # NOTE(review): an unrecognised dtype string resolves to None
            # here, i.e. it is handled like "auto" — confirm this is intended.
            self.kv_cache_quantized_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE.get(
                kv_cache_dtype.lower().strip()
            )

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: PallasMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            layer: attention layer providing ``_k_scale_float`` and
                ``_v_scale_float``.
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            kv_cache: shape =
                [num_blocks, block_size, num_kv_heads * 2, head_size]
            attn_metadata: Metadata for attention.
            output: only used when ``kv_cache`` is empty, in which case it is
                returned as-is (or replaced by ``ones_like(query)`` if None).
            output_scale: must be None (fused output quantization unsupported).
            output_block_scale: must be None.

        Returns:
            shape = [num_tokens, num_heads * head_size]

        Raises:
            NotImplementedError: if an output scale is supplied.
            ValueError: if the KV cache is quantized and either scale is zero.
        """
        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for PallasAttentionBackendImpl"
            )

        # For determine_available_memory case.
        if kv_cache.numel() == 0:
            if output is None:
                output = torch.ones_like(query)
            return output

        # Validate the scales BEFORE touching the cache: write_to_kv_cache
        # divides by them, so a zero scale must be rejected before any write.
        if self.kv_cache_quantized_dtype is not None and (
            layer._k_scale_float == 0.0 or layer._v_scale_float == 0.0
        ):
            raise ValueError("k_scale_float and v_scale_float must be non-zero")

        num_tokens, hidden_size = query.shape
        query = query.view(num_tokens, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        needs_head_padding = self.head_size % TPU_HEAD_SIZE_ALIGNMENT != 0
        if needs_head_padding:
            # Zero-pad the last (head) dimension up to the TPU alignment.
            padded_head_size = (
                cdiv(self.head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
            )
            pad = (0, padded_head_size - self.head_size)
            query = torch.nn.functional.pad(query, pad, value=0.0)
            key = torch.nn.functional.pad(key, pad, value=0.0)
            value = torch.nn.functional.pad(value, pad, value=0.0)

        # The cache is known to be non-empty here (see the early return above).
        if self.kv_sharing_target_layer_name is None:
            # Write input keys and values to the KV cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            write_to_kv_cache(
                key,
                value,
                kv_cache,
                attn_metadata.slot_mapping,
                attn_metadata.num_slices_per_kv_cache_update_block,
                attn_metadata.num_kv_update_slices,
                self.kv_cache_quantized_dtype,
                layer._k_scale_float,
                layer._v_scale_float,
            )

        output = torch.ops.xla.ragged_paged_attention(
            query,
            kv_cache,
            attn_metadata.context_lens,
            attn_metadata.block_tables,
            attn_metadata.query_start_loc,
            attn_metadata.num_seqs,
            # By default, the system utilizes optimized block size and
            # vmem_limit_bytes parameters from the kernel repository. However,
            # these can be manually adjusted for debugging if necessary.
            num_kv_pages_per_block=None,
            num_queries_per_block=None,
            vmem_limit_bytes=None,
            use_kernel=True,
            sm_scale=self.scale,
            sliding_window=self.sliding_window,
            soft_cap=self.logits_soft_cap,
            k_scale=layer._k_scale_float,
            v_scale=layer._v_scale_float,
        )

        if needs_head_padding:
            # Drop the padding that was added to the head dimension.
            output = output[:, :, : self.head_size]

        return output.reshape(num_tokens, hidden_size)
|
||||
|
||||
|
||||
def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    num_slices_per_kv_cache_update_block: int,
    num_kv_update_slices: torch.Tensor,
    kv_cache_quantized_dtype: torch.dtype | None = None,
    k_scale: float = 1.0,
    v_scale: float = 1.0,
) -> None:
    """Write the key and values to the KV cache.

    Args:
        key: shape = [num_tokens, num_kv_heads, head_size]
        value: shape = [num_tokens, num_kv_heads, head_size]
        kv_cache: shape = [num_blocks, block_size, num_kv_heads * 2, head_size]
        slot_mapping: forwarded unchanged to ``kv_cache_update_op``.
        num_slices_per_kv_cache_update_block: forwarded to
            ``kv_cache_update_op`` as ``num_slices_per_block``.
        num_kv_update_slices: forwarded unchanged to ``kv_cache_update_op``.
        kv_cache_quantized_dtype: when not None, key/value are divided by
            their scale, clamped to this dtype's range and cast to it.
        k_scale: divisor applied to ``key`` when quantizing.
        v_scale: divisor applied to ``value`` when quantizing.
    """
    _, page_size, num_combined_kv_heads, head_size = kv_cache.shape
    head_size = cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT

    if kv_cache_quantized_dtype is not None:
        # torch.finfo only accepts floating-point dtypes; integer cache dtypes
        # (e.g. int8 / uint8) need torch.iinfo, otherwise a TypeError is raised.
        if kv_cache_quantized_dtype.is_floating_point:
            dtype_info = torch.finfo(kv_cache_quantized_dtype)
        else:
            dtype_info = torch.iinfo(kv_cache_quantized_dtype)
        key = key.to(torch.float32) / k_scale
        # NOTE: clamp is added here to avoid out of range of quantized dtype
        key = torch.clamp(key, dtype_info.min, dtype_info.max)
        key = key.to(kv_cache_quantized_dtype)
        value = value.to(torch.float32) / v_scale
        value = torch.clamp(value, dtype_info.min, dtype_info.max)
        value = value.to(kv_cache_quantized_dtype)

    # Concatenate K and V along the last axis, then regroup so that they
    # occupy the combined-heads axis of the cache layout.
    kv = torch.cat([key, value], dim=-1).reshape(-1, num_combined_kv_heads, head_size)

    torch.ops.xla.dynamo_set_buffer_donor_(kv_cache, True)

    # Merge the block and in-block axes before handing the cache to the op.
    kv_cache = kv_cache.flatten(0, 1)
    new_kv_cache = torch.ops.xla.kv_cache_update_op(
        kv,
        slot_mapping,
        kv_cache,
        num_kv_update_slices,
        page_size,
        num_slices_per_kv_cache_update_block,
    )
    # NOTE: the in-place copy will be optimized away by XLA compiler.
    kv_cache.copy_(new_kv_cache)
|
||||
|
||||
|
||||
# We can move this function to a common utils file if it's also useful for other
|
||||
# hardware.
|
||||
def dtype_bits(dtype: torch.dtype):
    """Return the bit width of ``dtype``.

    Floating-point dtypes are resolved with ``torch.finfo``, the three
    complex dtypes from a fixed table, and everything else with
    ``torch.iinfo``. When none of those gives an answer, dtypes whose name
    starts with ``torch.int``/``torch.uint`` use the final digit of the name.

    Raises:
        TypeError: if the bit width cannot be determined.
    """
    info_fn = None
    if dtype.is_floating_point:
        info_fn = torch.finfo
    elif dtype.is_complex:
        complex_widths = {
            torch.complex32: 32,
            torch.complex64: 64,
            torch.complex128: 128,
        }
        width = complex_widths.get(dtype)
        if width is not None:
            return width
    else:
        info_fn = torch.iinfo

    if info_fn is not None:
        try:
            return info_fn(dtype).bits
        # torch.iinfo cannot support int4, int2, bits8...
        except TypeError:
            pass

    dtype_name = str(dtype)
    # support torch.int4, torch.int5, torch.uint5...
    if dtype_name.startswith(("torch.int", "torch.uint")):
        return int(dtype_name[-1])
    raise TypeError(f"Getting the bit width of {dtype} is not supported")
|
||||
|
||||
|
||||
def get_dtype_packing(dtype: torch.dtype) -> int:
    """Return how many elements of ``dtype`` fit in 32 bits.

    Raises:
        ValueError: if 32 is not a multiple of the dtype's bit width.
        TypeError: propagated from ``dtype_bits`` for unsupported dtypes.
    """
    bits = dtype_bits(dtype)
    if 32 % bits != 0:
        # Both fragments must be f-strings, otherwise "{dtype}" is printed
        # literally instead of the offending dtype.
        raise ValueError(
            f"32 must be divisible by the bit width, but got bits={bits}, "
            f"dtype={dtype}"
        )
    return 32 // bits
|
||||
|
||||
|
||||
def get_page_size_bytes(
    block_size: int, num_kv_heads: int, head_size: int, kv_cache_dtype: torch.dtype
) -> int:
    """Returns the size in bytes of one page of the KV cache."""
    aligned_head_size = (
        cdiv(head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT
    )

    # NOTE: for the implicit padding in XLA
    # The combined K+V head count is rounded up to a multiple of the number
    # of elements that fit in 32 bits.
    elems_per_word = get_dtype_packing(kv_cache_dtype)
    padded_kv_heads = cdiv(num_kv_heads * 2, elems_per_word) * elems_per_word

    bits_per_elem = dtype_bits(kv_cache_dtype)
    page_bits = block_size * padded_kv_heads * aligned_head_size * bits_per_elem
    return page_bits // 8
|
||||
816
v1/attention/backends/rocm_aiter_fa.py
Normal file
816
v1/attention/backends/rocm_aiter_fa.py
Normal file
@@ -0,0 +1,816 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with AiterFlashAttention."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.merge_attn_states import merge_attn_states
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import get_cu_count
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_prefills_and_extends,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
_PARTITION_SIZE_ROCM = 256
|
||||
_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
|
||||
|
||||
if current_platform.is_rocm():
|
||||
import aiter
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
def block_size(x, head_dim):
    """Return the Triton block size used by the gather kernel.

    It is the next power of two of ``head_dim``, capped at
    ``65536 // x.element_size()``.
    """
    element_cap = 65536 // x.element_size()
    padded_head_dim = triton.next_power_of_2(head_dim)
    return min(element_cap, padded_head_dim)
|
||||
|
||||
def num_programs(total_tokens):
    """Return the launch size: ``total_tokens`` capped at ``get_cu_count()``."""
    return min(total_tokens, get_cu_count())
|
||||
|
||||
# Gathers K/V rows for a flat list of tokens out of the paged KV cache and
# writes them contiguously into `key_ptr` / `value_ptr`, optionally multiplying
# by a scalar scale on the way (DEQUANT).
@triton.jit
def cp_mha_gather_cache_kernel(
    key_cache_ptr,  # [num_blocks, page_size, num_head, head_size]
    value_cache_ptr,  # [num_blocks, page_size, num_head, head_size]
    key_ptr,  # [num_tokens, num_heads, head_size]
    value_ptr,  # [num_tokens, num_heads, head_size]
    block_table_ptr,  # [num_batches, max_block_num]
    cu_seqlens_kv_ptr,  # [num_batches + 1]
    token_to_batch_ptr,  # [max_cum_tokens]
    seq_start_ptr,  # [num_batches]
    k_scale_ptr,
    v_scale_ptr,
    num_heads,
    head_size,
    x,  # not referenced in the kernel body
    max_block_num,  # row stride of the block table
    num_tokens,
    num_programs,  # stride of the per-program token loop
    DEQUANT: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    CACHE_FORMAT: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    bid = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    if DEQUANT:
        # A single scalar scale is loaded for K and for V.
        k_scale = tl.load(k_scale_ptr)
        v_scale = tl.load(v_scale_ptr)

    # Program `bid` handles tokens bid, bid + num_programs, bid + 2*num_programs, ...
    for token_id in tl.range(bid, num_tokens, num_programs):
        # Destination row for this token in the contiguous output buffers.
        key_ptr_offset = key_ptr + token_id * head_size * num_heads
        value_ptr_offset = value_ptr + token_id * head_size * num_heads
        batch_idx = tl.load(token_to_batch_ptr + token_id)
        batch_start = tl.load(seq_start_ptr + batch_idx)
        token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
        # Position of the token inside its sequence: offset within this
        # batch entry's gathered range plus that entry's start position.
        batch_offset = token_id - token_start + batch_start
        # Translate the position into (cache block, slot inside the block)
        # through the block table.
        block_offset = batch_offset // PAGE_SIZE
        block_id = tl.load(
            block_table_ptr + max_block_num * batch_idx + block_offset
        )
        slot_id = batch_offset % PAGE_SIZE

        if CACHE_FORMAT == "NHD":
            # for kv cache layout as
            # K: [num_blocks, page_size, num_head, head_dim]
            # V: [num_blocks, page_size, num_head, head_dim]
            key_cache_ptr_offset = (
                key_cache_ptr
                + block_id * num_heads * head_size * PAGE_SIZE
                + slot_id * num_heads * head_size
            )
            value_cache_ptr_offset = (
                value_cache_ptr
                + block_id * num_heads * head_size * PAGE_SIZE
                + slot_id * num_heads * head_size
            )

            # Copy the num_heads * head_size elements of this token in
            # BLOCK_SIZE-wide pieces; the mask guards the final partial piece.
            for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
                mask = (col_offsets + i) < head_size * num_heads
                k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
                v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
                if DEQUANT:
                    # NOTE(review): the scaled value is cast back to the
                    # dtype it was loaded with before being stored — confirm
                    # this is the intended precision for the dequantized data.
                    k_dtype = k_reg.dtype
                    v_dtype = v_reg.dtype
                    k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
                    v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
                tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
                tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
|
||||
|
||||
def cp_mha_gather_cache(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    block_tables: torch.Tensor,
    k_scales: torch.Tensor,
    v_scales: torch.Tensor,
    cu_seqlens_kv: torch.Tensor,
    token_to_batch: torch.Tensor,
    seq_starts: torch.Tensor,
    dequant: bool,
    kv_cache_layout: str,
    total_tokens: int,
):
    """Launch ``cp_mha_gather_cache_kernel`` to gather K/V out of the paged cache.

    ``key`` and ``value`` are the destination buffers. Only the "NHD" cache
    layout is accepted; ``key.shape[2]`` must equal ``key_cache.shape[3]``.
    The launch uses ``num_programs(total_tokens)`` programs.
    """
    assert kv_cache_layout in ["v0", "NHD", "HND"], (
        "kv_cache_layout only support v0, NHD, HND"
    )
    # assert dequant is True, "Currently, we only support "\
    # "gather cache with dequant"
    # For k cache layout: [num_blocks, num_heads, page_size, head_dim]
    assert kv_cache_layout == "NHD", (
        "ROCM_AITER_FA_BACKEND Only support NHD kv cache layout for now"
    )
    head_dim = key.shape[2]
    _, page_size, num_heads, cache_head_dim = key_cache.shape[:4]
    assert head_dim == cache_head_dim, (
        "We assume your kv cache layout is [num_blocks, "
        "page_size, num_heads, head_dim], but got otherwise"
    )

    launch_size = num_programs(total_tokens)
    tile = block_size(key_cache, head_dim)
    max_blocks_per_seq = block_tables.size(1)
    unused_x = 0  # placeholder for the kernel's `x` argument

    cp_mha_gather_cache_kernel[(launch_size,)](
        key_cache,
        value_cache,
        key,
        value,
        block_tables,
        cu_seqlens_kv,
        token_to_batch,
        seq_starts,
        k_scales,
        v_scales,
        num_heads,
        head_dim,
        unused_x,
        max_blocks_per_seq,
        total_tokens,
        launch_size,
        DEQUANT=dequant,
        PAGE_SIZE=page_size,
        CACHE_FORMAT=kv_cache_layout,
        BLOCK_SIZE=tile,
    )
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class AiterFlashAttentionDecodeMetadata:
    """Metadata for the decode requests (the first ``num_decodes`` requests)."""

    # Largest / smallest query length among the decode requests.
    max_query_len: int
    min_query_len: int
    # Largest sequence length among the decode requests.
    max_seq_len: int
    # First `num_decodes + 1` entries of the batch-wide query_start_loc.
    query_start_loc: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
class AiterFlashAttentionPrefillMetadata:
    """Metadata for the prefill requests (those after decodes and extends)."""

    # Largest / smallest query length among the prefill requests.
    max_query_len: int
    min_query_len: int
    # Largest sequence length among the prefill requests.
    max_seq_len: int
    # query_start_loc of the prefill requests, rebased so it starts at 0.
    query_start_loc: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
class AiterChunkContextMetadata:
    """Chunking plan for gathering the cached context of extend requests."""

    # Builder-owned buffer of shape
    # [2, _CP_TOKENS_PER_ITER_ROCM, num_heads_kv, headdim].
    workspace: torch.Tensor
    # Per-chunk cumulative context lengths, [num_chunks, num_extends + 1].
    cu_seq_lens_chunk: torch.Tensor
    # Start offset of each chunk for each request, [num_chunks, num_extends].
    chunk_starts: torch.Tensor
    # Per-chunk token -> request index map, [num_chunks, max_cum_tokens].
    token_to_batch: torch.Tensor
    # Per-chunk sum of the chunked context lengths.
    seq_tot: list[int]
    # Per-chunk maximum of the chunked context lengths.
    max_seq_lens: list[int]
    # Chunked context lengths, [num_chunks, num_extends].
    seq_lens: torch.Tensor
    num_chunks: int
    # Last column of the per-chunk cumulative lengths, as a Python list.
    total_token_per_batch: list[int]
|
||||
|
||||
|
||||
@dataclass
class AiterFlashAttentionChunkPrefillMetadata:
    """Metadata for the extend requests (those between decodes and prefills)."""

    # Largest / smallest query length among the extend requests.
    max_query_len: int
    min_query_len: int
    # Largest sequence length among the extend requests.
    max_seq_len: int
    # query_start_loc of the extend requests, rebased so it starts at 0.
    query_start_loc: torch.Tensor
    # Chunking plan used to gather the already-cached context.
    chunk_context_metadata: AiterChunkContextMetadata
|
||||
|
||||
|
||||
@dataclass
class AiterFlashAttentionMetadata:
    """Attention metadata produced by ``AiterFlashAttentionMetadataBuilder.build``."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    num_actual_kv_tokens: int  # Sum of all sequence lengths.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    slot_mapping: torch.Tensor
    block_table: torch.Tensor

    # prefill and decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int
    num_prefill_tokens: int
    num_extends: int
    num_extend_tokens: int

    # Each is None when the batch has no request of that kind.
    decode_metadata: AiterFlashAttentionDecodeMetadata | None
    prefill_metadata: AiterFlashAttentionPrefillMetadata | None
    extend_metadata: AiterFlashAttentionChunkPrefillMetadata | None

    # For cascade attention.
    use_cascade: bool
    common_prefix_len: int
    # Copied from the builder's `total_tokens` at build time.
    total_tokens: int
|
||||
|
||||
|
||||
class AiterFlashAttentionMetadataBuilder(
    AttentionMetadataBuilder[AiterFlashAttentionMetadata]
):
    """Builds ``AiterFlashAttentionMetadata`` from the common attention metadata.

    The batch is split into decode, extend and prefill requests (in that
    order) and a dedicated sub-metadata object is produced for each group
    that is present.
    """

    _cudagraph_support = AttentionCGSupport.UNIFORM_SINGLE_TOKEN_DECODE
    reorder_batch_threshold: int = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Cache model dimensions and allocate the extend gather workspace."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.model_config = vllm_config.model_config
        self.parallel_config = vllm_config.parallel_config
        self.cache_config = vllm_config.cache_config

        self.num_heads_q = self.model_config.get_num_attention_heads(
            self.parallel_config
        )
        self.num_heads_kv = self.model_config.get_num_kv_heads(self.parallel_config)
        self.headdim = self.model_config.get_head_size()
        self.block_size = kv_cache_spec.block_size
        # Sliding window size to be used with the AOT scheduler will be
        # populated on first build() call.
        self.aot_sliding_window: tuple[int, int] | None = None
        self.total_tokens: int = 0

        # One K slab and one V slab, each holding up to
        # _CP_TOKENS_PER_ITER_ROCM gathered context tokens.
        self.extend_workspace = torch.empty(
            [2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim],
            dtype=self.model_config.dtype,
            device=device,
        )

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ):
        """Build metadata for CUDA graph capture.

        ``total_tokens`` is temporarily set to
        ``max_model_len * max_num_partial_prefills`` for the duration of the
        build and is reset to 0 afterwards, even if ``build`` raises.
        """
        self.total_tokens = (
            self.model_config.max_model_len
            * self.vllm_config.scheduler_config.max_num_partial_prefills
        )
        try:
            return self.build(
                common_prefix_len=0, common_attn_metadata=common_attn_metadata
            )
        finally:
            # Always restore the default so a failed capture cannot leak the
            # capture-time value into later regular builds.
            self.total_tokens = 0

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> "AiterFlashAttentionMetadata":
        """Build the attention metadata for one batch.

        Args:
            common_prefix_len: shared prefix length; ``use_cascade`` is set
                when it is positive.
            common_attn_metadata: batch-wide metadata to split.
            fast_build: accepted for interface compatibility; not used.
        """
        split_ret = split_decodes_prefills_and_extends(
            common_attn_metadata,
            decode_threshold=self.reorder_batch_threshold,
        )

        (
            num_decodes,
            num_extends,
            num_prefills,
            num_decode_tokens,
            num_extend_tokens,
            num_prefill_tokens,
        ) = split_ret

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        seq_lens = common_attn_metadata.seq_lens_cpu

        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]

        decode_metadata = None
        if num_decodes > 0:
            decode_metadata = AiterFlashAttentionDecodeMetadata(
                max_query_len=query_lens_cpu[:num_decodes].max().item(),
                min_query_len=query_lens_cpu[:num_decodes].min().item(),
                max_seq_len=seq_lens[:num_decodes].max().item(),
                query_start_loc=common_attn_metadata.query_start_loc[: num_decodes + 1],
            )

        prefill_metadata = None
        if num_prefills > 0:
            query_lens_for_prefill = query_lens_cpu[num_decodes + num_extends :]
            query_start_loc_device = common_attn_metadata.query_start_loc[
                num_decodes + num_extends :
            ]
            prefill_metadata = AiterFlashAttentionPrefillMetadata(
                max_query_len=query_lens_for_prefill.max().item(),
                min_query_len=query_lens_for_prefill.min().item(),
                max_seq_len=seq_lens[num_decodes + num_extends :].max().item(),
                # Rebase so the first prefill request starts at offset 0.
                query_start_loc=query_start_loc_device - query_start_loc_device[0],
            )

        extend_metadata = None
        if num_extends > 0:
            num_extends_slice = slice(num_decodes, num_decodes + num_extends)
            query_lens_for_extend = query_lens_cpu[num_extends_slice]
            seq_lens_for_extend = common_attn_metadata.seq_lens_cpu[num_extends_slice]
            # Context already in the KV cache = sequence length - new tokens.
            computed_kv_lens = seq_lens_for_extend - query_lens_for_extend

            # allocate the equal amount of workspace for
            # each chunk prefill request
            max_context_chunk = _CP_TOKENS_PER_ITER_ROCM // num_extends
            num_chunks = cdiv(computed_kv_lens.max().item(), max_context_chunk)

            chunk_starts = (
                torch.arange(num_chunks, dtype=torch.int32)
                .unsqueeze(1)
                .expand(-1, num_extends)
                * max_context_chunk
            )
            chunk_ends = torch.min(
                computed_kv_lens.unsqueeze(0), chunk_starts + max_context_chunk
            )
            chunk_seq_lens = (chunk_ends - chunk_starts).clamp(
                min=0
            )  # [num_chunks, num_extends]
            cu_seq_lens_cpu = torch.zeros(
                [num_chunks, num_extends + 1], dtype=torch.int32, pin_memory=True
            )
            torch.cumsum(
                chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
            )
            max_cum_tokens = cu_seq_lens_cpu[:, -1].max().item()

            # Mark the positions where a request boundary falls, then take a
            # running count of the boundaries passed to map each token
            # position to the index of the request it belongs to.
            range_idx = torch.arange(max_cum_tokens, dtype=torch.int32)[None, None, :]
            idx_to_batch_tensor = range_idx == cu_seq_lens_cpu[:, 1:][:, :, None]
            idx_to_batch_tensor = idx_to_batch_tensor.sum(
                dim=1
            )  # [num_chunks, max_cum_tokens]
            token_to_batch_tensor = torch.cumsum(idx_to_batch_tensor, dim=1)

            chunk_context_metadata = AiterChunkContextMetadata(
                workspace=self.extend_workspace,
                cu_seq_lens_chunk=cu_seq_lens_cpu.to(self.device, non_blocking=True),
                chunk_starts=chunk_starts.to(self.device, non_blocking=True),
                seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                seq_lens=chunk_seq_lens,
                token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True),
                num_chunks=num_chunks,
                total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(),
            )

            query_start_loc_device = common_attn_metadata.query_start_loc[
                num_decodes : num_decodes + num_extends + 1
            ]
            # NOTE: a device-side cumulative sum of the extend sequence
            # lengths used to be computed here but was never read, so it has
            # been removed.
            extend_metadata = AiterFlashAttentionChunkPrefillMetadata(
                max_query_len=query_lens_for_extend.max().item(),
                min_query_len=query_lens_for_extend.min().item(),
                max_seq_len=seq_lens[num_extends_slice].max().item(),
                # Rebase so the first extend request starts at offset 0.
                query_start_loc=query_start_loc_device - query_start_loc_device[0],
                chunk_context_metadata=chunk_context_metadata,
            )

        num_actual_kv_tokens = torch.sum(seq_lens).item()

        use_cascade = common_prefix_len > 0

        attn_metadata = AiterFlashAttentionMetadata(
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            num_actual_kv_tokens=num_actual_kv_tokens,
            max_query_len=common_attn_metadata.max_query_len,
            query_start_loc=common_attn_metadata.query_start_loc,
            max_seq_len=common_attn_metadata.max_seq_len,
            seq_lens=common_attn_metadata.seq_lens,
            block_table=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_extends=num_extends,
            num_extend_tokens=num_extend_tokens,
            decode_metadata=decode_metadata,
            prefill_metadata=prefill_metadata,
            extend_metadata=extend_metadata,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            total_tokens=self.total_tokens,
        )
        return attn_metadata

    def use_cascade_attention(self, *args, **kwargs) -> bool:
        """Cascade attention is never used by this backend."""
        return False
|
||||
|
||||
|
||||
class AiterFlashAttentionBackend(AttentionBackend):
    """Backend descriptor for the AITER flash-attention implementation."""

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend accepts."""
        return [64, 128, 256]

    @staticmethod
    def get_name() -> str:
        """Return the backend identifier string."""
        return "FLASH_ATTN"

    @staticmethod
    def get_impl_cls() -> type["AiterFlashAttentionImpl"]:
        """Return the attention implementation class."""
        return AiterFlashAttentionImpl

    @staticmethod
    def get_builder_cls() -> type["AiterFlashAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return AiterFlashAttentionMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache shape; the leading axis of size 2 holds K and V.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        is_aligned = block_size % 16 == 0
        if not is_aligned:
            raise ValueError("Block size must be a multiple of 16.")

        cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
        return cache_shape
|
||||
|
||||
|
||||
class AiterFlashAttentionImpl(AttentionImpl):
|
||||
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: list[float] | None,
    sliding_window: int | None,
    kv_cache_dtype: str,
    logits_soft_cap: float | None = None,
    attn_type: AttentionType = AttentionType.DECODER,
    kv_sharing_target_layer_name: int | None = None,
) -> None:
    """Store the attention hyper-parameters in the form the kernels expect.

    ``alibi_slopes`` becomes a float32 tensor when given, a missing
    ``sliding_window`` becomes ``[-1, -1]`` (otherwise
    ``[sliding_window - 1, 0]``), and a missing ``logits_soft_cap`` becomes
    ``0.0``.

    Raises:
        AssertionError: if ``num_heads`` is not a multiple of ``num_kv_heads``.
        NotImplementedError: if ``attn_type`` is not ``AttentionType.DECODER``.
    """
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads

    self.alibi_slopes = (
        None
        if alibi_slopes is None
        else torch.tensor(alibi_slopes, dtype=torch.float32)
    )
    self.sliding_window = (
        [-1, -1] if sliding_window is None else [sliding_window - 1, 0]
    )
    self.kv_cache_dtype = kv_cache_dtype
    # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
    self.logits_soft_cap = 0.0 if logits_soft_cap is None else logits_soft_cap
    self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

    assert self.num_heads % self.num_kv_heads == 0
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads

    if attn_type == AttentionType.DECODER:
        return
    raise NotImplementedError(
        "Encoder self-attention and "
        "encoder/decoder cross-attention "
        "are not implemented for "
        "FlashAttentionImpl"
    )
|
||||
|
||||
def extend_forward(
|
||||
self,
|
||||
attn_metadata: AiterFlashAttentionMetadata,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
min_seqlen_q: int,
|
||||
block_table: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
k_scale: float,
|
||||
v_scale: float,
|
||||
):
|
||||
out, lse = aiter.flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_q,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_q,
|
||||
min_seqlen_q=min_seqlen_q,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
return_lse=True,
|
||||
)
|
||||
assert attn_metadata.extend_metadata is not None
|
||||
chunk_context_metadata = attn_metadata.extend_metadata.chunk_context_metadata
|
||||
num_chunks = chunk_context_metadata.num_chunks
|
||||
workspace = chunk_context_metadata.workspace
|
||||
cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk
|
||||
max_seqlens = chunk_context_metadata.max_seq_lens
|
||||
chunk_starts = chunk_context_metadata.chunk_starts
|
||||
token_to_batch = chunk_context_metadata.token_to_batch
|
||||
total_token_per_batch = chunk_context_metadata.total_token_per_batch
|
||||
key_fetched, value_fetched = workspace[0], workspace[1]
|
||||
chunked_output = None
|
||||
chunked_lse = None
|
||||
for chunk_idx in range(num_chunks):
|
||||
cp_mha_gather_cache(
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
key=key_fetched,
|
||||
value=value_fetched,
|
||||
block_tables=block_table,
|
||||
k_scales=k_scale,
|
||||
v_scales=v_scale,
|
||||
cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
|
||||
token_to_batch=token_to_batch[chunk_idx],
|
||||
seq_starts=chunk_starts[chunk_idx],
|
||||
dequant=False,
|
||||
kv_cache_layout="NHD",
|
||||
total_tokens=total_token_per_batch[chunk_idx],
|
||||
)
|
||||
|
||||
suf_out, suf_lse = aiter.flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_fetched,
|
||||
v=value_fetched,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_kv[chunk_idx],
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlens[chunk_idx],
|
||||
min_seqlen_q=min_seqlen_q,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=self.scale,
|
||||
causal=False,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
return_lse=True,
|
||||
)
|
||||
if chunked_output is None:
|
||||
chunked_output = suf_out
|
||||
chunked_lse = suf_lse
|
||||
else:
|
||||
tmp_output = torch.empty_like(out)
|
||||
tmp_lse = torch.empty_like(lse)
|
||||
merge_attn_states(
|
||||
output=tmp_output,
|
||||
output_lse=tmp_lse,
|
||||
prefix_output=chunked_output,
|
||||
prefix_lse=chunked_lse,
|
||||
suffix_output=suf_out,
|
||||
suffix_lse=suf_lse,
|
||||
)
|
||||
chunked_output = tmp_output
|
||||
chunked_lse = tmp_lse
|
||||
|
||||
merge_attn_states(
|
||||
output=output,
|
||||
prefix_output=chunked_output,
|
||||
prefix_lse=chunked_lse,
|
||||
suffix_output=out,
|
||||
suffix_lse=lse,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: AiterFlashAttentionMetadata,
|
||||
output: torch.Tensor | None = None,
|
||||
output_scale: torch.Tensor | None = None,
|
||||
output_block_scale: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with AiterFlashAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads, head_size]
|
||||
key: shape = [num_tokens, num_kv_heads, head_size]
|
||||
value: shape = [num_tokens, num_kv_heads, head_size]
|
||||
kv_cache: shape =
|
||||
[2, num_blocks, block_size, num_kv_heads, head_size]
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
NOTE: FP8 quantization, flash-attn expect the size of
|
||||
{q,k,v}_descale to be (num_sequences, num_kv_heads).
|
||||
We use torch's .expand() to avoid duplicating values
|
||||
"""
|
||||
assert output is not None, "Output tensor must be provided."
|
||||
|
||||
if output_scale is not None or output_block_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported for FlashAttentionImpl"
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# Profiling run.
|
||||
return output.fill_(0)
|
||||
|
||||
# IMPORTANT!
|
||||
# NOTE(woosuk): With piece-wise CUDA graphs, this method is
|
||||
# executed in eager-mode PyTorch. Thus, we need to be careful
|
||||
# about any CPU overhead in this method. For example, `view`
|
||||
# and `slice` (or `[:n]`) operations are surprisingly slow even
|
||||
# in the case they do not invoke any GPU ops.
|
||||
# Minimize the PyTorch ops in this method as much as possible.
|
||||
# Whenever making a change in this method, please benchmark the
|
||||
# performance to make sure it does not introduce any overhead.
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
key_cache, value_cache = kv_cache.unbind(0)
|
||||
if self.kv_sharing_target_layer_name is None:
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
# Skip this if sharing KV cache with an earlier attention layer.
|
||||
# NOTE(woosuk): Here, key and value are padded while slot_mapping
|
||||
# is not padded. However, we don't need to do
|
||||
# key[:num_actual_tokens] and value[:num_actual_tokens] because
|
||||
# the reshape_and_cache_flash op uses the slot_mapping's shape
|
||||
# to determine the number of actual tokens.
|
||||
|
||||
torch.ops._C_cache_ops.reshape_and_cache_flash(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
)
|
||||
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
key_cache = key_cache.view(current_platform.fp8_dtype())
|
||||
value_cache = value_cache.view(current_platform.fp8_dtype())
|
||||
|
||||
# decode:extend:prefill
|
||||
query = query[:num_actual_tokens]
|
||||
key = key[:num_actual_tokens]
|
||||
value = value[:num_actual_tokens]
|
||||
|
||||
output_actual_tokens = output[:num_actual_tokens]
|
||||
|
||||
num_decodes = attn_metadata.num_decodes
|
||||
num_prefills = attn_metadata.num_prefills
|
||||
num_extends = attn_metadata.num_extends
|
||||
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_extend_tokens = attn_metadata.num_extend_tokens
|
||||
if not attn_metadata.use_cascade:
|
||||
# calculate for pure prefills
|
||||
if num_prefills > 0:
|
||||
assert attn_metadata.prefill_metadata is not None
|
||||
|
||||
prefill_query = query[num_decode_tokens + num_extend_tokens :]
|
||||
prefill_key = key[num_decode_tokens + num_extend_tokens :]
|
||||
prefill_value = value[num_decode_tokens + num_extend_tokens :]
|
||||
|
||||
aiter.flash_attn_varlen_func(
|
||||
q=prefill_query,
|
||||
k=prefill_key,
|
||||
v=prefill_value,
|
||||
cu_seqlens_q=attn_metadata.prefill_metadata.query_start_loc,
|
||||
cu_seqlens_k=attn_metadata.prefill_metadata.query_start_loc,
|
||||
max_seqlen_q=attn_metadata.prefill_metadata.max_query_len,
|
||||
max_seqlen_k=attn_metadata.prefill_metadata.max_seq_len,
|
||||
min_seqlen_q=1,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
window_size=self.sliding_window,
|
||||
alibi_slopes=self.alibi_slopes,
|
||||
out=output_actual_tokens[num_decode_tokens + num_extend_tokens :],
|
||||
)
|
||||
|
||||
# calculate for extends
|
||||
if num_extends > 0:
|
||||
assert attn_metadata.extend_metadata is not None
|
||||
extend_tokens_slice = slice(
|
||||
num_decode_tokens, num_decode_tokens + num_extend_tokens
|
||||
)
|
||||
extend_querys = query[extend_tokens_slice]
|
||||
extend_keys = key[extend_tokens_slice]
|
||||
extend_values = value[extend_tokens_slice]
|
||||
extend_outputs = output[extend_tokens_slice]
|
||||
self.extend_forward(
|
||||
attn_metadata=attn_metadata,
|
||||
query=extend_querys,
|
||||
key=extend_keys,
|
||||
value=extend_values,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
output=extend_outputs,
|
||||
cu_seqlens_q=attn_metadata.extend_metadata.query_start_loc,
|
||||
max_seqlen_q=attn_metadata.extend_metadata.max_query_len,
|
||||
max_seqlen_k=attn_metadata.extend_metadata.max_seq_len,
|
||||
min_seqlen_q=1,
|
||||
block_table=attn_metadata.block_table[
|
||||
num_decodes : num_decodes + num_extends
|
||||
],
|
||||
slot_mapping=attn_metadata.slot_mapping[
|
||||
num_decodes : num_decodes + num_extends
|
||||
],
|
||||
k_scale=layer._k_scale,
|
||||
v_scale=layer._v_scale,
|
||||
)
|
||||
|
||||
# calculate for decodes
|
||||
if num_decodes > 0:
|
||||
assert attn_metadata.decode_metadata is not None
|
||||
_, num_heads, head_size = query.shape
|
||||
nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
|
||||
num_seqs = attn_metadata.seq_lens.shape[0]
|
||||
max_num_partitions = (
|
||||
attn_metadata.max_seq_len + _PARTITION_SIZE_ROCM - 1
|
||||
) // _PARTITION_SIZE_ROCM
|
||||
|
||||
workspace_buffer = torch.empty(
|
||||
(num_seqs * num_heads * max_num_partitions * head_size)
|
||||
* nbytes_per_qo_elem
|
||||
+ 2 * (num_seqs * num_heads * max_num_partitions) * 4,
|
||||
dtype=torch.uint8,
|
||||
device=output.device,
|
||||
)
|
||||
|
||||
torch.ops.aiter.paged_attention_v1(
|
||||
output[:num_decode_tokens],
|
||||
workspace_buffer,
|
||||
query[:num_decode_tokens],
|
||||
key_cache,
|
||||
value_cache,
|
||||
self.scale,
|
||||
attn_metadata.block_table[:num_decodes],
|
||||
attn_metadata.query_start_loc[:num_decodes],
|
||||
attn_metadata.seq_lens[:num_decodes],
|
||||
attn_metadata.max_seq_len,
|
||||
self.alibi_slopes,
|
||||
self.kv_cache_dtype,
|
||||
"NHD",
|
||||
self.logits_soft_cap,
|
||||
layer._k_scale,
|
||||
layer._v_scale,
|
||||
None,
|
||||
_PARTITION_SIZE_ROCM,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Cascade attention is not implemented for ROCM AITER"
|
||||
)
|
||||
|
||||
return output
|
||||
196
v1/attention/backends/rocm_aiter_unified_attn.py
Normal file
196
v1/attention/backends/rocm_aiter_unified_attn.py
Normal file
@@ -0,0 +1,196 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.rocm_attn import (
|
||||
RocmAttentionBackend,
|
||||
RocmAttentionImpl,
|
||||
RocmAttentionMetadataBuilder,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
    """Backend descriptor for the AITER unified-attention implementation.

    Reuses the metadata builder of ``RocmAttentionBackend`` and swaps in
    ``RocmAiterUnifiedAttentionImpl`` as the implementation class.
    """

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "ROCM_AITER_UNIFIED_ATTN"

    @staticmethod
    def get_impl_cls() -> type["RocmAiterUnifiedAttentionImpl"]:
        """Return the attention implementation class."""
        return RocmAiterUnifiedAttentionImpl

    @staticmethod
    def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return RocmAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Always report that cascade attention is not to be used."""
        return False

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        shape = (2, num_blocks, block_size, num_kv_heads, head_size)
        return shape
|
||||
|
||||
|
||||
class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
    """``RocmAttentionImpl`` variant that runs AITER's Triton unified attention.

    Configuration handling is inherited; ``forward`` is replaced so that a
    single ``unified_attention`` call reads keys/values from the paged cache.
    """

    def fused_output_quant_supported(self, quant_key: QuantKey):
        """Return True only for static per-tensor symmetric FP8 output quant."""
        return quant_key == kFp8StaticTensorSym

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: int | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        """Forward all arguments to the parent and bind the AITER kernel."""
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            sinks,
        )
        logger.info_once(
            "Using aiter unified attention for RocmAiterUnifiedAttentionImpl"
        )
        # Imported here, not at module scope, so importing this module does
        # not itself require the aiter package.
        from aiter.ops.triton.unified_attention import (
            unified_attention as aiter_unified_attention,
        )

        self.unified_attention = aiter_unified_attention

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run unified attention over the paged KV cache.

        Args:
            layer: Attention layer providing ``_k_scale``, ``_v_scale`` and
                ``_q_scale_float``.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention; ``None`` marks a profiling
                run, for which ``output`` is zero-filled and returned.
            output: Pre-allocated output tensor; required.
            output_scale: Optional scale handed to the kernel.
            output_block_scale: Must be ``None``.

        Returns:
            The ``output`` tensor, filled in place for the actual tokens.
        """
        assert output is not None, "Output tensor must be provided."

        if output_block_scale is not None:
            raise NotImplementedError(
                "fused block_scale output quantization is not yet supported"
                " for RocmAttentionImpl"
            )

        # Profiling run: nothing to attend to.
        if attn_metadata is None:
            return output.fill_(0)

        assert attn_metadata.use_cascade is False

        # This method runs eagerly even under piece-wise CUDA graphs, so CPU
        # overhead matters: keep the number of PyTorch ops (including cheap
        # looking views and slices) to a minimum and benchmark any change.

        n_tokens = attn_metadata.num_actual_tokens
        key_cache, value_cache = kv_cache.unbind(0)

        # Write the new keys/values into the cache unless this layer shares
        # the KV cache of an earlier attention layer.
        if self.kv_sharing_target_layer_name is None:
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            fp8_dtype = self.fp8_dtype
            key_cache = key_cache.view(fp8_dtype)
            value_cache = value_cache.view(fp8_dtype)
            assert layer._q_scale_float == 1.0, (
                "A non 1.0 q_scale is not currently supported."
            )

        query_start_loc = attn_metadata.query_start_loc
        # One descale entry per (sequence, kv head); expand() broadcasts the
        # layer scale to that shape.
        per_seq_head_shape = (query_start_loc.shape[0] - 1, key.shape[1])

        self.unified_attention(
            q=query[:n_tokens],
            k=key_cache,
            v=value_cache,
            out=output[:n_tokens],
            cu_seqlens_q=query_start_loc,
            max_seqlen_q=attn_metadata.max_query_len,
            seqused_k=attn_metadata.seq_lens,
            max_seqlen_k=attn_metadata.max_seq_len,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=attn_metadata.block_table,
            softcap=self.logits_soft_cap,
            q_descale=None,  # Not supported
            k_descale=layer._k_scale.expand(per_seq_head_shape),
            v_descale=layer._v_scale.expand(per_seq_head_shape),
            sinks=self.sinks,
            output_scale=output_scale,
        )

        return output
|
||||
362
v1/attention/backends/rocm_attn.py
Normal file
362
v1/attention/backends/rocm_attn.py
Normal file
@@ -0,0 +1,362 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with PagedAttention and Triton prefix prefill."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
)
|
||||
from vllm.attention.ops.chunked_prefill_paged_decode import chunked_prefill_paged_decode
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class RocmAttentionMetadata:
    """Per-batch attention metadata produced by RocmAttentionMetadataBuilder."""

    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor
    max_seq_len: int
    seq_lens: torch.Tensor
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    # For cascade attention.
    # use_cascade is True when the builder is given common_prefix_len > 0;
    # the three tensors below are None otherwise.
    use_cascade: bool
    common_prefix_len: int
    cu_prefix_query_lens: torch.Tensor | None
    prefix_kv_lens: torch.Tensor | None
    suffix_kv_lens: torch.Tensor | None

    # Optional aot scheduling
    scheduler_metadata: torch.Tensor | None = None
    prefix_scheduler_metadata: torch.Tensor | None = None
|
||||
|
||||
|
||||
class RocmAttentionMetadataBuilder(AttentionMetadataBuilder[RocmAttentionMetadata]):
    """Build ``RocmAttentionMetadata`` from the common per-batch metadata."""

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialise the base builder and cache block/head dimensions."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.block_size = kv_cache_spec.block_size

        model_config = vllm_config.model_config
        self.num_heads_q = model_config.get_num_attention_heads(
            vllm_config.parallel_config
        )
        self.num_heads_kv = model_config.get_num_kv_heads(vllm_config.parallel_config)
        self.headdim = model_config.get_head_size()

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> RocmAttentionMetadata:
        """Build metadata suitable for CUDA-graph capture.

        Note that this mutates ``common_attn_metadata`` in place (its query
        start locations are zeroed) as well as the returned ``seq_lens``.
        """
        attn_metadata = self.build(0, common_attn_metadata)
        # When doing full graph capture, setting seq_lens to
        # max_model_len will cause graph capture to be extremely
        # slow, so here we set it to 1.
        attn_metadata.seq_lens.fill_(1)

        # Here we set the query start locs to 0. This is to
        # cover up an invalid memory access in the prefix_prefill kernel
        # that we run into during graph capture (#25985)
        common_attn_metadata.query_start_loc.zero_()
        common_attn_metadata.query_start_loc_cpu.zero_()

        return attn_metadata

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> RocmAttentionMetadata:
        """Assemble the attention metadata for one batch.

        Args:
            common_prefix_len: Length of the prefix shared by all requests; a
                value greater than 0 populates the cascade-attention fields.
            common_attn_metadata: Batch-level metadata to copy fields from.
            fast_build: Accepted for interface compatibility; not used here.

        Returns:
            The populated ``RocmAttentionMetadata``.
        """
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        max_query_len = common_attn_metadata.max_query_len

        max_seq_len = common_attn_metadata.max_seq_len
        query_start_loc = common_attn_metadata.query_start_loc
        seq_lens = common_attn_metadata.seq_lens
        block_table_tensor = common_attn_metadata.block_table_tensor
        slot_mapping = common_attn_metadata.slot_mapping

        use_cascade = common_prefix_len > 0

        # No prefix scheduling metadata is produced on either path. Assigning
        # it unconditionally keeps the name bound when use_cascade is True.
        prefix_scheduler_metadata = None

        if use_cascade:
            cu_prefix_query_lens = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device=self.device
            )
            prefix_kv_lens = torch.tensor(
                [common_prefix_len], dtype=torch.int32, device=self.device
            )
            suffix_kv_lens = common_attn_metadata.seq_lens_cpu - common_prefix_len
            suffix_kv_lens = suffix_kv_lens.to(self.device)
        else:
            cu_prefix_query_lens = None
            prefix_kv_lens = None
            suffix_kv_lens = None

        attn_metadata = RocmAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table_tensor,
            slot_mapping=slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=prefix_scheduler_metadata,
        )
        return attn_metadata
|
||||
|
||||
|
||||
class RocmAttentionBackend(AttentionBackend):
    """Backend descriptor tying together the ROCm attention impl and builder."""

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the head sizes this backend accepts."""
        return [32, 64, 96, 128, 160, 192, 224, 256]

    @classmethod
    def validate_head_size(cls, head_size: int) -> None:
        """Raise ``ValueError`` if ``head_size`` is not supported."""
        if cls.supports_head_size(head_size):
            return

        attn_type = cls.__name__.removesuffix("Backend")
        raise ValueError(
            f"Head size {head_size} is not supported by {attn_type}. "
            f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
            "Set VLLM_ATTENTION_BACKEND=FLEX_ATTENTION to use "
            "FlexAttention backend which supports all head sizes."
        )

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "ROCM_ATTN"

    @staticmethod
    def get_impl_cls() -> type["RocmAttentionImpl"]:
        """Return the attention implementation class."""
        return RocmAttentionImpl

    @staticmethod
    def get_builder_cls() -> type["RocmAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return RocmAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Always report that cascade attention is not to be used."""
        return False

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return ``(2, num_blocks, block_size, num_kv_heads, head_size)``.

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        shape = (2, num_blocks, block_size, num_kv_heads, head_size)
        return shape
|
||||
|
||||
|
||||
class RocmAttentionImpl(AttentionImpl):
    """Decoder attention on ROCm using the paged cache and a chunked-prefill /
    paged-decode kernel."""

    def fused_output_quant_supported(self, quant_key: QuantKey):
        """Return True only for static per-tensor symmetric FP8 output quant."""
        return quant_key == kFp8StaticTensorSym

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: int | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        """Store the layer configuration and validate it.

        Raises:
            ValueError: If ``head_size`` is not supported by the backend.
            NotImplementedError: If ``attn_type`` is not ``DECODER``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.alibi_slopes = (
            None
            if alibi_slopes is None
            else torch.tensor(alibi_slopes, dtype=torch.float32)
        )
        # (left, right) window; (-1, -1) stands for "no sliding window".
        self.sliding_window = (
            (-1, -1) if sliding_window is None else (sliding_window - 1, 0)
        )
        self.kv_cache_dtype = kv_cache_dtype
        # In flash-attn, a soft cap of 0 means no soft cap.
        self.logits_soft_cap = 0 if logits_soft_cap is None else logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        RocmAttentionBackend.validate_head_size(head_size)

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "RocmAttentionImpl"
            )

        self.fp8_dtype = current_platform.fp8_dtype()

        self.sinks = sinks
        if sinks is not None:
            assert sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Sinks shape: {sinks.shape}, "
                f"num_heads: {num_heads}."
            )

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: FlashAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Write new keys/values to the cache and compute attention.

        Args:
            layer: Attention layer providing ``_k_scale``, ``_v_scale`` and
                ``_q_scale_float``.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention; ``None`` marks a profiling
                run, for which ``output`` is zero-filled and returned.
            output: Pre-allocated output tensor; required.
            output_scale: Optional scale handed to the kernel.
            output_block_scale: Must be ``None``.

        Returns:
            The ``output`` tensor, filled in place for the actual tokens.
        """
        assert output is not None, "Output tensor must be provided."

        if output_block_scale is not None:
            raise NotImplementedError(
                "fused block_scale output quantization is not yet supported"
                " for RocmAttentionImpl"
            )

        # Profiling run: nothing to attend to.
        if attn_metadata is None:
            return output.fill_(0)

        assert attn_metadata.use_cascade is False

        # This method runs eagerly even under piece-wise CUDA graphs, so CPU
        # overhead matters: keep the number of PyTorch ops (including cheap
        # looking views and slices) to a minimum and benchmark any change.

        n_tokens = attn_metadata.num_actual_tokens

        key_cache, value_cache = PagedAttention.split_kv_cache(
            kv_cache, self.num_kv_heads, self.head_size
        )

        # Write the new keys/values into the cache unless this layer shares
        # the KV cache of an earlier attention layer.
        if self.kv_sharing_target_layer_name is None:
            PagedAttention.write_to_paged_cache(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            fp8_dtype = self.fp8_dtype
            key_cache = key_cache.view(fp8_dtype)
            value_cache = value_cache.view(fp8_dtype)
            assert layer._q_scale_float == 1.0, (
                "A non 1.0 q_scale is not currently supported."
            )

        # Compute attention and update output up to the actual token count.
        chunked_prefill_paged_decode(
            query=query[:n_tokens],
            key=key[:n_tokens],
            value=value[:n_tokens],
            output=output[:n_tokens],
            kv_cache_dtype=self.kv_cache_dtype,
            key_cache=key_cache,
            value_cache=value_cache,
            block_table=attn_metadata.block_table,
            query_start_loc=attn_metadata.query_start_loc,
            seq_lens=attn_metadata.seq_lens,
            max_seq_len=attn_metadata.max_seq_len,
            max_query_len=attn_metadata.max_query_len,
            k_scale=layer._k_scale,
            v_scale=layer._v_scale,
            alibi_slopes=self.alibi_slopes,
            sliding_window=self.sliding_window[0],
            sm_scale=self.scale,
            output_scale=output_scale,
            sinks=self.sinks,
        )

        return output
|
||||
105
v1/attention/backends/short_conv_attn.py
Normal file
105
v1/attention/backends/short_conv_attn.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.v1.attention.backends.mamba_attn import BaseMambaAttentionMetadataBuilder
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
PAD_SLOT_ID,
|
||||
CommonAttentionMetadata,
|
||||
compute_causal_conv1d_metadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
|
||||
|
||||
class ShortConvAttentionBackend(AttentionBackend):
    """Backend descriptor for short-convolution layers."""

    @staticmethod
    def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]:
        """Return the metadata builder class used by this backend."""
        builder_cls = ShortConvAttentionMetadataBuilder
        return builder_cls
|
||||
|
||||
|
||||
@dataclass
class ShortConvAttentionMetadata:
    """Per-batch metadata produced by ShortConvAttentionMetadataBuilder."""

    # Request/token counts for the prefill and decode parts of the batch.
    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int
    num_decode_tokens: int

    query_start_loc: torch.Tensor
    # The builder fills this from the first column of the block table.
    state_indices_tensor: torch.Tensor
    # For prefill requests: whether any tokens were already computed.
    # None when the batch contains no prefills.
    has_initial_states_p: torch.Tensor | None

    # For causal_conv1d
    nums_dict: dict | None = None
    batch_ptr: torch.Tensor | None = None
    token_chunk_offset_ptr: torch.Tensor | None = None
|
||||
|
||||
|
||||
class ShortConvAttentionMetadataBuilder(
    BaseMambaAttentionMetadataBuilder[ShortConvAttentionMetadata]
):
    """Builds ``ShortConvAttentionMetadata`` from the common batch metadata."""

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> ShortConvAttentionMetadata:
        """Assemble the metadata for one batch.

        ``common_prefix_len`` and ``fast_build`` are part of the builder
        interface but are not read by this implementation.
        """
        batch = common_attn_metadata
        num_reqs = batch.num_reqs
        query_start_loc = batch.query_start_loc
        # Column 0 of the block table provides one state index per request.
        state_indices = batch.block_table_tensor[:, 0]

        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                batch, decode_threshold=self.reorder_batch_threshold
            )
        )

        # Prefill-only outputs; they stay None for batches without prefills.
        has_initial_states_p = None
        # for causal_conv1d
        nums_dict = batch_ptr = token_chunk_offset_ptr = None

        if num_prefills > 0:
            # The last ``num_prefills`` requests are treated as the prefills.
            computed_tokens_p = batch.num_computed_tokens_cpu[
                num_reqs - num_prefills : num_reqs
            ]
            has_initial_states_p = (computed_tokens_p > 0).to(
                query_start_loc.device
            )

            # Re-base the prefill start offsets so they begin after the
            # decode tokens have been removed.
            query_start_loc_p = (
                batch.query_start_loc[-num_prefills - 1 :] - num_decode_tokens
            )
            nums_dict, batch_ptr, token_chunk_offset_ptr = (
                compute_causal_conv1d_metadata(query_start_loc_p)
            )
        elif (
            0 < num_decodes <= self.decode_cudagraph_max_bs
            and self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        ):
            # Decode-only batch eligible for full CUDA graphs: copy the
            # indices into the builder-owned buffer and pad the remainder
            # up to the padded batch size with PAD_SLOT_ID.
            padded_size = self.vllm_config.pad_for_cudagraph(num_decodes)
            persistent_indices = self.state_indices_tensor
            persistent_indices[:num_decodes].copy_(
                state_indices, non_blocking=True
            )
            state_indices = persistent_indices[:padded_size]
            state_indices[num_decodes:] = PAD_SLOT_ID

        return ShortConvAttentionMetadata(
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            query_start_loc=query_start_loc,
            state_indices_tensor=state_indices,
            has_initial_states_p=has_initial_states_p,
            nums_dict=nums_dict,
            batch_ptr=batch_ptr,
            token_chunk_offset_ptr=token_chunk_offset_ptr,
        )
|
||||
425
v1/attention/backends/tree_attn.py
Normal file
425
v1/attention/backends/tree_attn.py
Normal file
@@ -0,0 +1,425 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with TreeAttention."""
|
||||
|
||||
import ast
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TreeAttentionBackend(AttentionBackend):
    """Backend descriptor tying together the tree-attention classes."""

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported head sizes: multiples of 32 from 32 to 256."""
        return list(range(32, 257, 32))

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "TREE_ATTN"

    @staticmethod
    def get_impl_cls() -> type["TreeAttentionImpl"]:
        """Return the attention implementation class."""
        return TreeAttentionImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache shape, with the key/value axis leading.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return (2, num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_builder_cls() -> type["TreeAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return TreeAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used by this backend."""
        return False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TreeAttentionMetadata:
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
num_prefill_tokens: int = 0
|
||||
num_decode_tokens: int = 0
|
||||
num_prefills: int = 0
|
||||
num_decodes: int = 0
|
||||
|
||||
tree_attn_bias: torch.Tensor | None = None
|
||||
|
||||
# Cached Prefill/decode metadata.
|
||||
_cached_prefill_metadata: Optional["TreeAttentionMetadata"] = None
|
||||
_cached_decode_metadata: Optional["TreeAttentionMetadata"] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["TreeAttentionMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
# Recover cached prefill-phase attention
|
||||
# metadata structure
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
q_start_loc = self.query_start_loc[self.num_decodes :]
|
||||
q_seqlens = torch.diff(q_start_loc)
|
||||
kv_seqlens = self.seq_lens[self.num_decodes :]
|
||||
# Construct & cache prefill-phase attention metadata structure
|
||||
self._cached_prefill_metadata = TreeAttentionMetadata(
|
||||
num_actual_tokens=self.num_prefill_tokens,
|
||||
max_query_len=int(q_seqlens.max().item()),
|
||||
query_start_loc=q_start_loc - q_start_loc[0],
|
||||
max_seq_len=int(kv_seqlens.max().item()),
|
||||
seq_lens=kv_seqlens,
|
||||
block_table=self.block_table[self.num_decodes :],
|
||||
slot_mapping=self.slot_mapping[self.num_decode_tokens :],
|
||||
)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["TreeAttentionMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
# Recover cached decode-phase attention
|
||||
# metadata structure
|
||||
return self._cached_decode_metadata
|
||||
|
||||
q_start_loc = self.query_start_loc[: self.num_decodes + 1]
|
||||
q_seqlens = torch.diff(q_start_loc)
|
||||
kv_seqlens = self.seq_lens[: self.num_decodes]
|
||||
# Construct & cache decode-phase attention metadata structure
|
||||
self._cached_decode_metadata = TreeAttentionMetadata(
|
||||
num_actual_tokens=self.num_decode_tokens,
|
||||
max_query_len=int(q_seqlens.max().item()),
|
||||
query_start_loc=q_start_loc,
|
||||
max_seq_len=int(kv_seqlens.max().item()),
|
||||
seq_lens=kv_seqlens,
|
||||
block_table=self.block_table[: self.num_decodes],
|
||||
slot_mapping=self.slot_mapping[: self.num_decode_tokens],
|
||||
tree_attn_bias=self.tree_attn_bias,
|
||||
)
|
||||
return self._cached_decode_metadata
|
||||
|
||||
|
||||
class TreeAttentionMetadataBuilder(AttentionMetadataBuilder[TreeAttentionMetadata]):
    """Builds ``TreeAttentionMetadata`` and owns the tree attention bias."""

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Parse the speculative token tree and precompute its bias.

        When no speculative config (or no token tree) is set, a single-path
        tree ``[(0,)]`` is used.
        """
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.block_size = kv_cache_spec.block_size

        spec_config = vllm_config.speculative_config
        spec_token_tree = spec_config and spec_config.speculative_token_tree
        # The tree is configured as a string holding a Python literal.
        tree_choices: list[tuple[int, ...]] = (
            ast.literal_eval(spec_token_tree) if spec_token_tree is not None else [(0,)]
        )
        # Construct the tree attention bias.
        depth_counts = _get_depth_counts(tree_choices)
        self.tree_attn_bias = _prepare_tree_attn_bias(
            tree_choices,
            depth_counts,
            dtype=torch.float32,
            device=device,
        )

        # Requests with up to (tree size incl. root) query tokens are decodes.
        self.reorder_batch_threshold = self.tree_attn_bias.shape[0]

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> TreeAttentionMetadata:
        """Build the metadata for one batch.

        The decode/prefill split uses the current size of
        ``self.tree_attn_bias`` as the decode threshold.
        ``common_prefix_len`` and ``fast_build`` are not read here.
        """
        decode_threshold = self.tree_attn_bias.shape[0]
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
            split_decodes_and_prefills(
                common_attn_metadata, decode_threshold=decode_threshold
            )
        )

        return TreeAttentionMetadata(
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=common_attn_metadata.max_query_len,
            query_start_loc=common_attn_metadata.query_start_loc,
            max_seq_len=common_attn_metadata.max_seq_len,
            seq_lens=common_attn_metadata.seq_lens,
            block_table=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping,
            tree_attn_bias=self.tree_attn_bias,
        )

    def build_for_drafting(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        draft_index: int,
    ) -> TreeAttentionMetadata:
        """Build metadata for one drafting step.

        ``self.tree_attn_bias`` is swapped temporarily so that ``build``
        sees the bias for this draft level. The original bias is restored
        in a ``finally`` block, so a failing ``build`` cannot leave the
        builder holding the truncated bias.
        """
        # Cache the original tree attention bias.
        orig_tree_attn_bias = self.tree_attn_bias
        try:
            if draft_index == 0:
                # Use prefill for drafting at the root level.
                self.tree_attn_bias = torch.empty(0)
            else:
                # Slice the tree attention bias for drafting. Exclude
                # the root level.
                start, end = 1, 1 + common_attn_metadata.max_query_len
                self.tree_attn_bias = orig_tree_attn_bias[
                    start:end, start:end
                ].contiguous()

            # Build attention bias.
            return self.build(0, common_attn_metadata, fast_build=True)
        finally:
            # Reset the tree attention bias to the original value.
            self.tree_attn_bias = orig_tree_attn_bias
|
||||
|
||||
|
||||
def _get_depth_counts(sorted_tree_choices: list[tuple[int, ...]]) -> list[int]:
|
||||
# Count the number of choices at each depth of the tree.
|
||||
depth_counts = []
|
||||
prev_depth = 0
|
||||
for path in sorted_tree_choices:
|
||||
depth = len(path)
|
||||
if depth != prev_depth:
|
||||
depth_counts.append(0)
|
||||
depth_counts[depth - 1] += 1
|
||||
prev_depth = depth
|
||||
return depth_counts
|
||||
|
||||
|
||||
def _prepare_tree_attn_bias(
|
||||
sorted_tree_choices: list[tuple[int, ...]],
|
||||
depth_counts: list[int],
|
||||
dtype: torch.dtype | None,
|
||||
device: torch.device | None,
|
||||
) -> torch.Tensor:
|
||||
# +1 comes from the additional root node.
|
||||
tree_len = len(sorted_tree_choices) + 1
|
||||
tree_attn_mask = torch.full(
|
||||
(tree_len, tree_len), -torch.inf, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
# Set diagonal to all zeros. Each token should
|
||||
# attend to itself.
|
||||
mask_val = 0
|
||||
for i in range(tree_len):
|
||||
tree_attn_mask[i, i] = mask_val
|
||||
|
||||
# Set root to all zeros. All tokens attend to it.
|
||||
tree_attn_mask[:, 0] = mask_val
|
||||
|
||||
# Set all ancestors to zeros.
|
||||
start = 0
|
||||
for i in range(len(depth_counts)):
|
||||
for j in range(depth_counts[i]):
|
||||
cur_tree_choice = sorted_tree_choices[start + j]
|
||||
# Retrieve ancestor position.
|
||||
if len(cur_tree_choice) == 1:
|
||||
continue
|
||||
ancestor_idx = []
|
||||
for c in range(len(cur_tree_choice) - 1):
|
||||
ancestor_idx.append(
|
||||
sorted_tree_choices.index(cur_tree_choice[: c + 1]) + 1
|
||||
)
|
||||
tree_attn_mask[j + start + 1, ancestor_idx] = mask_val
|
||||
start += depth_counts[i]
|
||||
return tree_attn_mask
|
||||
|
||||
|
||||
class TreeAttentionImpl(AttentionImpl):
    """Decoder attention that applies a tree bias to the decode requests."""

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        """Store the layer configuration.

        Raises:
            NotImplementedError: if ``attn_type`` is not
                ``AttentionType.DECODER``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        self.alibi_slopes = (
            None
            if alibi_slopes is None
            else torch.tensor(alibi_slopes, dtype=torch.float32)
        )
        # Setting logits_soft_cap to 0 means no soft cap.
        self.logits_soft_cap = 0 if logits_soft_cap is None else logits_soft_cap
        self.sliding_window = (
            (-1, -1) if sliding_window is None else (sliding_window - 1, 0)
        )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "TreeAttentionImpl."
            )

    def _run_unified_attention(
        self,
        layer: torch.nn.Module,
        q: torch.Tensor,
        out: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        meta: TreeAttentionMetadata,
        descale_shape: tuple[int, int],
        **extra,
    ) -> None:
        """Invoke ``unified_attention`` for one phase (prefill or decode).

        ``extra`` carries phase-specific keyword arguments such as
        ``qq_bias``.
        """
        unified_attention(
            q=q,
            k=key_cache,
            v=value_cache,
            out=out,
            cu_seqlens_q=meta.query_start_loc,
            max_seqlen_q=meta.max_query_len,
            seqused_k=meta.seq_lens,
            max_seqlen_k=meta.max_seq_len,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=meta.block_table,
            softcap=self.logits_soft_cap,
            q_descale=None,  # Not supported
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
            **extra,
        )

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TreeAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with TreeAttention.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if not (output_scale is None and output_block_scale is None):
            raise NotImplementedError(
                "fused output quantization is not yet supported for TreeAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        # Cache the input KVs.
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        total_tokens = attn_metadata.num_actual_tokens
        decode_tokens = attn_metadata.num_decode_tokens
        descale_shape = (attn_metadata.query_start_loc.shape[0] - 1, key.shape[1])

        # Prefill tokens follow the decode tokens in the flattened batch.
        prefill_meta = attn_metadata.prefill_metadata
        if prefill_meta:
            self._run_unified_attention(
                layer,
                query[decode_tokens:total_tokens],
                output[decode_tokens:total_tokens],
                key_cache,
                value_cache,
                prefill_meta,
                descale_shape,
            )

        # Only the decode phase receives the tree bias.
        decode_meta = attn_metadata.decode_metadata
        if decode_meta:
            self._run_unified_attention(
                layer,
                query[:decode_tokens],
                output[:decode_tokens],
                key_cache,
                value_cache,
                decode_meta,
                descale_shape,
                qq_bias=decode_meta.tree_attn_bias,
            )
        return output
|
||||
373
v1/attention/backends/triton_attn.py
Normal file
373
v1/attention/backends/triton_attn.py
Normal file
@@ -0,0 +1,373 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""High-Performance Triton-only Attention layer."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.triton_reshape_and_cache_flash import (
|
||||
triton_reshape_and_cache_flash,
|
||||
)
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
kFp8StaticTensorSym,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import DeviceCapability
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionCGSupport,
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TritonAttentionMetadata:
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
max_query_len: int
|
||||
query_start_loc: torch.Tensor
|
||||
max_seq_len: int
|
||||
seq_lens: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
slot_mapping: torch.Tensor
|
||||
|
||||
# For cascade attention.
|
||||
use_cascade: bool
|
||||
common_prefix_len: int
|
||||
cu_prefix_query_lens: torch.Tensor | None
|
||||
prefix_kv_lens: torch.Tensor | None
|
||||
suffix_kv_lens: torch.Tensor | None
|
||||
|
||||
# Optional aot scheduling
|
||||
scheduler_metadata: torch.Tensor | None = None
|
||||
prefix_scheduler_metadata: torch.Tensor | None = None
|
||||
|
||||
|
||||
class TritonAttentionMetadataBuilder(AttentionMetadataBuilder[TritonAttentionMetadata]):
    """Builds ``TritonAttentionMetadata`` from the common batch metadata."""

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.ALWAYS

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Record block size and head geometry from the configuration."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        self.block_size = kv_cache_spec.block_size

        model_cfg = vllm_config.model_config
        parallel_cfg = vllm_config.parallel_config
        self.num_heads_q = model_cfg.get_num_attention_heads(parallel_cfg)
        self.num_heads_kv = model_cfg.get_num_kv_heads(parallel_cfg)
        self.headdim = model_cfg.get_head_size()

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> TritonAttentionMetadata:
        """Build metadata for graph capture with every ``seq_lens`` set to 1."""
        captured = self.build(0, common_attn_metadata)
        # When doing full graph capture, setting seq_lens to
        # max_model_len will cause graph capture to be extremely
        # slow, so here we set it to 1.
        captured.seq_lens.fill_(1)
        return captured

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> TritonAttentionMetadata:
        """Build the metadata for one batch.

        Cascade fields are populated only when ``common_prefix_len`` is
        positive. ``fast_build`` is not read here.
        """
        batch = common_attn_metadata
        num_actual_tokens = batch.num_actual_tokens
        use_cascade = common_prefix_len > 0

        cu_prefix_query_lens = prefix_kv_lens = suffix_kv_lens = None
        if use_cascade:
            # One shared-prefix "sequence" spanning all query tokens.
            cu_prefix_query_lens = torch.tensor(
                [0, num_actual_tokens], dtype=torch.int32, device=self.device
            )
            prefix_kv_lens = torch.tensor(
                [common_prefix_len], dtype=torch.int32, device=self.device
            )
            # Remaining KV length per request once the prefix is removed.
            suffix_kv_lens = (batch.seq_lens_cpu - common_prefix_len).to(
                self.device
            )

        return TritonAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=batch.max_query_len,
            query_start_loc=batch.query_start_loc,
            max_seq_len=batch.max_seq_len,
            seq_lens=batch.seq_lens,
            block_table=batch.block_table_tensor,
            slot_mapping=batch.slot_mapping,
            use_cascade=use_cascade,
            common_prefix_len=common_prefix_len,
            cu_prefix_query_lens=cu_prefix_query_lens,
            prefix_kv_lens=prefix_kv_lens,
            suffix_kv_lens=suffix_kv_lens,
            prefix_scheduler_metadata=None,
        )
|
||||
|
||||
|
||||
class TritonAttentionBackend(AttentionBackend):
    """Backend descriptor tying together the Triton attention classes."""

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [
        torch.float16,
        torch.bfloat16,
        torch.float32,
    ]
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "auto",
        "fp8",
        "fp8_e4m3",
        "fp8_e5m2",
    ]

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "TRITON_ATTN"

    @staticmethod
    def get_impl_cls() -> type["TritonAttentionImpl"]:
        """Return the attention implementation class."""
        return TritonAttentionImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache shape, with the key/value axis second.

        Raises:
            ValueError: if ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used by this backend."""
        return False

    @staticmethod
    def get_builder_cls() -> type["TritonAttentionMetadataBuilder"]:
        """Return the metadata builder class."""
        return TritonAttentionMetadataBuilder

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        """Accept any head size of at least 32."""
        return head_size >= 32

    @classmethod
    def supports_sink(cls) -> bool:
        """Report that attention sinks are supported."""
        return True

    @classmethod
    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
        """Accept every compute capability."""
        return True
|
||||
|
||||
|
||||
class TritonAttentionImpl(AttentionImpl):
    """Decoder attention implemented with the Triton unified kernel."""

    def fused_output_quant_supported(self, quant_key: QuantKey) -> bool:
        """Return True only for static per-tensor symmetric FP8 output."""
        return quant_key == kFp8StaticTensorSym

    def supports_quant_query_input(self) -> bool:
        """Quantized query input is supported on CUDA platforms only."""
        return current_platform.is_cuda()

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        # Annotated as a layer *name*: the value is only compared against
        # None here, and the sibling TreeAttentionImpl declares it as str.
        kv_sharing_target_layer_name: str | None = None,
        sinks: torch.Tensor | None = None,
    ) -> None:
        """Store the layer configuration.

        Raises:
            NotImplementedError: if ``attn_type`` is not
                ``AttentionType.DECODER``.
            AssertionError: if ``sinks`` is given and its first dimension
                differs from ``num_heads``.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        if sliding_window is None:
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        self.kv_cache_dtype = kv_cache_dtype
        if logits_soft_cap is None:
            # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "TritonAttentionImpl"
            )

        self.fp8_dtype = current_platform.fp8_dtype()

        self.sinks = sinks
        if sinks is not None:
            assert sinks.shape[0] == num_heads, (
                "Sinks must have the same number of heads as the number of "
                f"heads in the layer. Sinks shape: {sinks.shape}, "
                f"num_heads: {num_heads}."
            )

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: TritonAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with Paged Attention impl. in Triton.

        Args:
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [num_blocks, 2, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        assert output is not None, "Output tensor must be provided."

        if output_block_scale is not None:
            raise NotImplementedError(
                "fused block_scale output quantization is not yet supported"
                " for TritonAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        assert attn_metadata.use_cascade is False

        # IMPORTANT!
        # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
        # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
        # in this method. For example, `view` and `slice` (or `[:n]`) operations
        # are surprisingly slow even in the case they do not invoke any GPU ops.
        # Minimize the PyTorch ops in this method as much as possible.
        # Whenever making a change in this method, please benchmark the
        # performance to make sure it does not introduce any overhead.

        num_actual_tokens = attn_metadata.num_actual_tokens
        key_cache, value_cache = kv_cache.unbind(1)

        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            if self.kv_cache_dtype.startswith("fp8"):
                key_cache = key_cache.view(self.fp8_dtype)
                value_cache = value_cache.view(self.fp8_dtype)
            # triton kernel does not support uint8 kv_cache
            #  (because some explicit casts (e.g. float8_e4m3fnuz)
            #   are not supported)
            triton_reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        if self.kv_cache_dtype.startswith("fp8"):
            # When the cache write above was skipped (shared KV cache), the
            # cache views have not been reinterpreted as FP8 yet.
            if key_cache.dtype != self.fp8_dtype:
                key_cache = key_cache.view(self.fp8_dtype)
                value_cache = value_cache.view(self.fp8_dtype)
            assert layer._q_scale_float == 1.0, (
                "A non 1.0 q_scale is not currently supported."
            )

        cu_seqlens_q = attn_metadata.query_start_loc
        seqused_k = attn_metadata.seq_lens
        max_seqlen_q = attn_metadata.max_query_len
        max_seqlen_k = attn_metadata.max_seq_len
        block_table = attn_metadata.block_table

        # (number of sequences, number of KV heads)
        descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

        unified_attention(
            q=query[:num_actual_tokens],
            k=key_cache,
            v=value_cache,
            out=output[:num_actual_tokens],
            cu_seqlens_q=cu_seqlens_q,
            max_seqlen_q=max_seqlen_q,
            seqused_k=seqused_k,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=self.scale,
            causal=True,
            alibi_slopes=self.alibi_slopes,
            window_size=self.sliding_window,
            block_table=block_table,
            softcap=self.logits_soft_cap,
            q_descale=None,  # Not supported
            k_descale=layer._k_scale.expand(descale_shape),
            v_descale=layer._v_scale.expand(descale_shape),
            sinks=self.sinks,
            output_scale=output_scale,
        )

        return output
|
||||
1117
v1/attention/backends/utils.py
Normal file
1117
v1/attention/backends/utils.py
Normal file
File diff suppressed because it is too large
Load Diff
417
v1/attention/backends/xformers.py
Normal file
417
v1/attention/backends/xformers.py
Normal file
@@ -0,0 +1,417 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Attention layer with XFormersAttention."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import (
|
||||
AttentionBackend,
|
||||
AttentionImpl,
|
||||
AttentionType,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.attention.ops.triton_unified_attention import unified_attention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
try:
|
||||
from xformers import ops as xops
|
||||
from xformers.ops.fmha.attn_bias import (
|
||||
AttentionBias,
|
||||
PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
|
||||
)
|
||||
|
||||
XFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
XFORMERS_AVAILABLE = False
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XFormersAttentionBackend(AttentionBackend):
    """Backend descriptor for the XFormers attention implementation.

    Exposes the static capabilities of the backend (dtypes, kernel block
    sizes, head sizes) together with the implementation and metadata-builder
    classes that go with it.
    """

    accept_output_buffer: bool = True
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]

    @classmethod
    def get_supported_head_sizes(cls) -> list[int]:
        """Return the supported head sizes: every multiple of 8 in [32, 256]."""
        return list(range(32, 257, 8))

    @staticmethod
    def get_name() -> str:
        """Return the identifier string of this backend."""
        return "XFORMERS"

    @staticmethod
    def get_impl_cls() -> type["XFormersAttentionImpl"]:
        """Return the attention implementation class of this backend."""
        return XFormersAttentionImpl

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "auto",
    ) -> tuple[int, ...]:
        """Return the KV cache tensor shape for the given geometry.

        The leading dimension of size 2 holds the key and value caches.

        Raises:
            ValueError: If ``block_size`` is not a multiple of 16.
        """
        if block_size % 16:
            raise ValueError("Block size must be a multiple of 16.")
        kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
        return kv_cache_shape

    @staticmethod
    def get_builder_cls() -> type["XFormersAttentionMetadataBuilder"]:
        """Return the metadata builder class of this backend."""
        return XFormersAttentionMetadataBuilder

    @staticmethod
    def use_cascade_attention(*args, **kwargs) -> bool:
        """Cascade attention is never used by this backend."""
        return False
|
||||
|
||||
|
||||
@dataclass
class XFormersAttentionMetadata:
    """Per-batch attention metadata for the XFormers backend.

    The batch is treated as decode requests first, followed by prefill
    requests; the ``prefill_metadata`` / ``decode_metadata`` properties slice
    out (and cache) a view restricted to one of the two phases.
    """

    num_actual_tokens: int  # Number of tokens excluding padding.
    max_query_len: int
    query_start_loc: torch.Tensor  # Cumulative query offsets, one per request + 1.
    max_seq_len: int
    seq_lens: torch.Tensor  # Per-request KV sequence lengths.
    block_table: torch.Tensor
    slot_mapping: torch.Tensor

    num_prefill_tokens: int = 0
    num_decode_tokens: int = 0
    num_prefills: int = 0
    num_decodes: int = 0

    # Bias used by the decode-phase XFormers kernel (None when not built).
    attn_bias: Optional["AttentionBias"] = None

    # Lazily-built per-phase views; filled on first property access.
    _cached_prefill_metadata: Optional["XFormersAttentionMetadata"] = None
    _cached_decode_metadata: Optional["XFormersAttentionMetadata"] = None

    @property
    def prefill_metadata(self) -> Optional["XFormersAttentionMetadata"]:
        """Metadata restricted to the prefill requests, or None if there are none.

        The result is built once and then served from the cache. Query
        offsets are rebased so that the first prefill request starts at 0.
        """
        if self.num_prefills == 0:
            return None

        if self._cached_prefill_metadata is None:
            first_prefill = self.num_decodes
            prefill_offsets = self.query_start_loc[first_prefill:]
            prefill_query_lens = torch.diff(prefill_offsets)
            prefill_seq_lens = self.seq_lens[first_prefill:]
            self._cached_prefill_metadata = XFormersAttentionMetadata(
                num_actual_tokens=self.num_prefill_tokens,
                max_query_len=int(prefill_query_lens.max().item()),
                query_start_loc=prefill_offsets - prefill_offsets[0],
                max_seq_len=int(prefill_seq_lens.max().item()),
                seq_lens=prefill_seq_lens,
                block_table=self.block_table[first_prefill:],
                slot_mapping=self.slot_mapping[self.num_decode_tokens :],
            )
        return self._cached_prefill_metadata

    @property
    def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]:
        """Metadata restricted to the decode requests, or None if there are none.

        The result is built once and then served from the cache. The
        attention bias of the full batch is carried over unchanged.
        """
        if self.num_decode_tokens == 0:
            return None

        if self._cached_decode_metadata is None:
            decode_count = self.num_decodes
            all_query_lens = torch.diff(self.query_start_loc)
            decode_seq_lens = self.seq_lens[:decode_count]
            self._cached_decode_metadata = XFormersAttentionMetadata(
                num_actual_tokens=self.num_decode_tokens,
                max_query_len=int(all_query_lens[:decode_count].max().item()),
                query_start_loc=self.query_start_loc[: decode_count + 1],
                max_seq_len=int(decode_seq_lens.max().item()),
                seq_lens=decode_seq_lens,
                block_table=self.block_table[:decode_count],
                slot_mapping=self.slot_mapping[: self.num_decode_tokens],
                attn_bias=self.attn_bias,
            )
        return self._cached_decode_metadata
|
||||
|
||||
|
||||
class XFormersAttentionMetadataBuilder(
    AttentionMetadataBuilder[XFormersAttentionMetadata]
):
    """Builds ``XFormersAttentionMetadata`` from the common attention metadata."""

    # Passed as ``decode_threshold`` when splitting the batch in ``build``.
    reorder_batch_threshold: int = 1

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        """Initialise the builder; requires the xformers package to be importable."""
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)

        assert XFORMERS_AVAILABLE
        self.block_size = kv_cache_spec.block_size
        self._num_decodes = 0
        self._num_decode_tokens = 0

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: CommonAttentionMetadata,
        fast_build: bool = False,
    ) -> XFormersAttentionMetadata:
        """Assemble the backend metadata for one batch.

        The batch is split into decodes and prefills; when at least one
        decode request is present, a paged block-diagonal causal bias covering
        the decode requests is constructed for the XFormers kernel.

        Args:
            common_prefix_len: Not used by this builder.
            common_attn_metadata: Backend-agnostic metadata of the batch.
            fast_build: Not used by this builder.

        Returns:
            The metadata consumed by ``XFormersAttentionImpl.forward``.
        """
        split = split_decodes_and_prefills(
            common_attn_metadata, decode_threshold=self.reorder_batch_threshold
        )
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = split

        query_start_loc = common_attn_metadata.query_start_loc
        query_lens = torch.diff(query_start_loc)
        seq_lens = common_attn_metadata.seq_lens
        block_table = common_attn_metadata.block_table_tensor

        decode_bias = None
        if num_decodes > 0:
            # Only the leading `num_decodes` requests take part in the bias.
            decode_bias = PagedBlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
                q_seqlen=query_lens[:num_decodes].tolist(),
                kv_seqlen=seq_lens[:num_decodes].tolist(),
                page_size=self.block_size,
                block_tables=block_table[:num_decodes],
                device=block_table.device,
            )

        return XFormersAttentionMetadata(
            num_actual_tokens=common_attn_metadata.num_actual_tokens,
            num_prefill_tokens=num_prefill_tokens,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            num_decodes=num_decodes,
            max_query_len=common_attn_metadata.max_query_len,
            query_start_loc=query_start_loc,
            max_seq_len=common_attn_metadata.max_seq_len,
            seq_lens=seq_lens,
            block_table=block_table,
            slot_mapping=common_attn_metadata.slot_mapping,
            attn_bias=decode_bias,
        )
|
||||
|
||||
|
||||
class XFormersAttentionImpl(AttentionImpl):
    """Decoder self-attention using XFormers for decode and Triton for prefill.

    Prefill tokens are processed by ``unified_attention``; decode tokens are
    processed by ``xops.memory_efficient_attention_forward`` directly on the
    paged KV cache.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None = None,
        attn_type: AttentionType = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
    ) -> None:
        """Validate the configuration and store the attention hyper-parameters.

        Args:
            num_heads: Number of query heads.
            head_size: Size of each attention head.
            scale: Softmax scaling factor.
            num_kv_heads: Number of key/value heads.
            alibi_slopes: Must be None; ALiBi is not supported.
            sliding_window: Sliding window size, or None to disable it.
            kv_cache_dtype: KV cache dtype string forwarded to the cache op.
            logits_soft_cap: Soft cap for the logits; None means no cap.
            attn_type: Must be ``AttentionType.DECODER``.
            kv_sharing_target_layer_name: Must be None; KV sharing is not
                supported.

        Raises:
            NotImplementedError: If KV sharing, ALiBi slopes or a non-decoder
                attention type is requested.
        """
        if kv_sharing_target_layer_name is not None:
            raise NotImplementedError(
                "KV sharing is not supported in XFormersAttentionImpl."
            )
        if alibi_slopes is not None:
            raise NotImplementedError("XFormers does not support alibi slopes yet.")
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
        # ALiBi was rejected above, so the slopes are always None here.
        self.alibi_slopes = None
        if sliding_window is None:
            # (-1, -1) is the "no window" value passed to unified_attention.
            self.sliding_window = (-1, -1)
        else:
            self.sliding_window = (sliding_window - 1, 0)
        if logits_soft_cap is None:
            # Setting logits_soft_cap to 0 means no soft cap.
            logits_soft_cap = 0
        self.logits_soft_cap = logits_soft_cap

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "XFormersAttentionImpl."
            )

    def forward(
        self,
        layer: torch.nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: XFormersAttentionMetadata,
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass with XFormers.

        The first ``num_decode_tokens`` rows of ``query`` are handled as
        decode tokens and the rows up to ``num_actual_tokens`` as prefill
        tokens. Results are written into ``output`` in place.

        Args:
            layer: Attention layer providing ``_k_scale`` and ``_v_scale``.
            query: shape = [num_tokens, num_heads, head_size]
            key: shape = [num_tokens, num_kv_heads, head_size]
            value: shape = [num_tokens, num_kv_heads, head_size]
            kv_cache: shape =
                [2, num_blocks, block_size, num_kv_heads, head_size]
            attn_metadata: Metadata for attention; None during profiling.
            output: Pre-allocated output buffer (required).
            output_scale: Must be None; fused output quantization is not
                supported.
            output_block_scale: Must be None; fused output quantization is
                not supported.
        Returns:
            The ``output`` buffer that was passed in, filled with the
            attention result.
        Raises:
            NotImplementedError: If fused output quantization is requested.
        """
        assert output is not None, "Output tensor must be provided."

        if output_scale is not None or output_block_scale is not None:
            raise NotImplementedError(
                "fused output quantization is not yet supported"
                " for XFormersAttentionImpl"
            )

        if attn_metadata is None:
            # Profiling run.
            return output.fill_(0)

        # Cache the input KVs.
        key_cache, value_cache = kv_cache.unbind(0)
        if self.kv_sharing_target_layer_name is None:
            # Reshape the input keys and values and store them in the cache.
            # Skip this if sharing KV cache with an earlier attention layer.
            # NOTE(woosuk): Here, key and value are padded while slot_mapping is
            # not padded. However, we don't need to do key[:num_actual_tokens]
            # and value[:num_actual_tokens] because the reshape_and_cache_flash
            # op uses the slot_mapping's shape to determine the number of
            # actual tokens.
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                attn_metadata.slot_mapping,
                self.kv_cache_dtype,
                layer._k_scale,
                layer._v_scale,
            )

        num_actual_tokens = attn_metadata.num_actual_tokens
        num_decode_tokens = attn_metadata.num_decode_tokens
        if prefill_meta := attn_metadata.prefill_metadata:
            # One descale entry per (prefill request, KV head).
            descale_shape = (prefill_meta.query_start_loc.shape[0] - 1, key.shape[1])
            unified_attention(
                q=query[num_decode_tokens:num_actual_tokens],
                k=key_cache,
                v=value_cache,
                out=output[num_decode_tokens:num_actual_tokens],
                cu_seqlens_q=prefill_meta.query_start_loc,
                max_seqlen_q=prefill_meta.max_query_len,
                seqused_k=prefill_meta.seq_lens,
                max_seqlen_k=prefill_meta.max_seq_len,
                softmax_scale=self.scale,
                causal=True,
                alibi_slopes=self.alibi_slopes,
                window_size=self.sliding_window,
                block_table=prefill_meta.block_table,
                softcap=self.logits_soft_cap,
                q_descale=None,  # Not supported
                k_descale=layer._k_scale.expand(descale_shape),
                v_descale=layer._v_scale.expand(descale_shape),
            )

        if decode_meta := attn_metadata.decode_metadata:
            # Query for decode. KV is not needed because it is already cached.
            decode_query = query[:num_decode_tokens]
            # Reshape query to [1, B_T, G, H, D].
            q = decode_query.view(
                1, -1, self.num_kv_heads, self.num_queries_per_kv, self.head_size
            )
            # Reshape the k and v caches to [1, Bkv_T, G, H, D]
            cache_k = key_cache.view(
                1, -1, self.num_kv_heads, 1, self.head_size
            ).expand(
                1,
                -1,
                self.num_kv_heads,
                self.num_queries_per_kv,
                self.head_size,
            )
            cache_v = value_cache.view(
                1, -1, self.num_kv_heads, 1, self.head_size
            ).expand(
                1,
                -1,
                self.num_kv_heads,
                self.num_queries_per_kv,
                self.head_size,
            )

            # NOTE(review): self.sliding_window and self.logits_soft_cap are
            # not passed to this call, unlike the prefill path above --
            # confirm that this is intended for decode.
            attn_bias = decode_meta.attn_bias
            output[:num_decode_tokens] = xops.memory_efficient_attention_forward(
                q,
                cache_k,
                cache_v,
                attn_bias=attn_bias,
                p=0.0,
                scale=self.scale,
            ).view(decode_query.shape)

        # `output` was filled in place by the prefill and decode paths above.
        return output
|
||||
0
v1/core/__init__.py
Normal file
0
v1/core/__init__.py
Normal file
BIN
v1/core/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
v1/core/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/core/__pycache__/block_pool.cpython-312.pyc
Normal file
BIN
v1/core/__pycache__/block_pool.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/core/__pycache__/encoder_cache_manager.cpython-312.pyc
Normal file
BIN
v1/core/__pycache__/encoder_cache_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/core/__pycache__/kv_cache_coordinator.cpython-312.pyc
Normal file
BIN
v1/core/__pycache__/kv_cache_coordinator.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/core/__pycache__/kv_cache_manager.cpython-312.pyc
Normal file
BIN
v1/core/__pycache__/kv_cache_manager.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/core/__pycache__/kv_cache_utils.cpython-312.pyc
Normal file
BIN
v1/core/__pycache__/kv_cache_utils.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/core/__pycache__/single_type_kv_cache_manager.cpython-312.pyc
Normal file
BIN
v1/core/__pycache__/single_type_kv_cache_manager.cpython-312.pyc
Normal file
Binary file not shown.
428
v1/core/block_pool.py
Normal file
428
v1/core/block_pool.py
Normal file
@@ -0,0 +1,428 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any
|
||||
|
||||
from vllm.distributed.kv_events import (
|
||||
MEDIUM_GPU,
|
||||
AllBlocksCleared,
|
||||
BlockRemoved,
|
||||
BlockStored,
|
||||
KVCacheEvent,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_utils import (
|
||||
BlockHash,
|
||||
BlockHashWithGroupId,
|
||||
ExternalBlockHash,
|
||||
FreeKVCacheBlockQueue,
|
||||
KVCacheBlock,
|
||||
get_block_hash,
|
||||
make_block_hash_with_group_id,
|
||||
maybe_convert_block_hash,
|
||||
)
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BlockHashToBlockMap:
    """
    Prefix-caching lookup table from block hash to cached block(s),
    i.e. {block_hash: KVCacheBlocks}.

    - In the common case a hash maps to exactly one KVCacheBlock, and the
      stored value is that KVCacheBlock itself.
    - When several blocks share a hash, the stored value is a dict
      {block_id: KVCacheBlock}.

    A cached block is a full block carrying a block hash that prefix caching
    can reuse. It may currently be in use by running requests, or sit in the
    free_block_queue where it may be evicted.

    NOTE #1: Blocks are not de-duplicated: when a block becomes full and is
    cached, no check is made for an identical block already in the cache.
    This keeps allocated block IDs stable so block tables stay append-only.
    NOTE #2: The single-block / dict union exists to reduce the GC cost of
    the inner dict.
    """

    def __init__(self):
        self._cache: dict[
            BlockHashWithGroupId, KVCacheBlock | dict[int, KVCacheBlock]
        ] = {}

    def get_one_block(self, key: BlockHashWithGroupId) -> KVCacheBlock | None:
        """
        Returns any one block stored under the given hash key, or None.
        """
        entry = self._cache.get(key)
        if entry is None:
            return None
        if isinstance(entry, KVCacheBlock):
            return entry
        if isinstance(entry, dict):
            # Any block will do; take the first one in the dict.
            return next(iter(entry.values()))
        self._unexpected_blocks_type(entry)
        return None

    def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None:
        """
        Adds the KVCacheBlock to the cache under the given key.
        """
        existing = self._cache.get(key)
        if existing is None:
            # First block for this key: store it directly.
            self._cache[key] = block
            return
        if isinstance(existing, KVCacheBlock):
            # Second block for this key: promote the entry to a dict
            # holding both the old and the new block.
            self._cache[key] = {existing.block_id: existing, block.block_id: block}
            return
        if isinstance(existing, dict):
            # Entry is already a dict: just add the block.
            existing[block.block_id] = block
            return
        self._unexpected_blocks_type(existing)

    def pop(self, key: BlockHashWithGroupId, block_id: int) -> KVCacheBlock | None:
        """
        Removes and returns the block with `block_id` stored under `key`,
        or returns None if the key or the block ID is not present.
        """
        entry = self._cache.pop(key, None)
        if entry is None:
            # Key is not in the cache.
            return None
        # TODO(Jialin): If key is found, block_id should always present
        # in blocks. We currently keep the original behaviour for safety.
        #
        # Will add block_id == blocks.block_id assertion and
        # use del blocks[block_id] instead as followup.
        if isinstance(entry, KVCacheBlock):
            if entry.block_id == block_id:
                return entry
            # Block ID mismatch (expected to be rare): restore the entry.
            self._cache[key] = entry
            return None
        if isinstance(entry, dict):
            removed = entry.pop(block_id, None)
            if entry:
                # Other blocks still share this key: keep the dict cached.
                self._cache[key] = entry
            return removed
        self._unexpected_blocks_type(entry)
        return None

    def __len__(self) -> int:
        """Number of distinct hash keys currently cached."""
        return len(self._cache)

    def _unexpected_blocks_type(self, blocks: Any) -> None:
        """Raise for a cache entry that is neither a block nor a dict."""
        raise AssertionError(f"Invalid KV cache block type {type(blocks)}")
|
||||
|
||||
|
||||
class BlockPool:
    """Pool that owns and manages all KVCacheBlocks.

    Offers allocation, freeing and caching of KV cache blocks. Free blocks
    live in `free_block_queue` in eviction order, which serves allocation,
    freeing and cache eviction alike. `cached_block_hash_to_block` maps a
    block hash to its cached block(s) so cached blocks can be found by hash.

    Args:
        num_gpu_blocks: The number of blocks in the pool.
        enable_caching: Whether to enable prefix caching.
        enable_kv_cache_events: Whether to enable kv cache events.
    """

    def __init__(
        self,
        num_gpu_blocks: int,
        enable_caching: bool,
        enable_kv_cache_events: bool = False,
    ):
        assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
        self.num_gpu_blocks = num_gpu_blocks
        self.enable_caching = enable_caching
        # Every kv-cache block of the pool, indexed by block id.
        self.blocks: list[KVCacheBlock] = list(
            map(KVCacheBlock, range(num_gpu_blocks))
        )
        # Doubly linked list of free blocks; with caching enabled it also
        # contains the eviction candidates.
        self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)

        # Hash -> block lookup used for prefix caching.
        self.cached_block_hash_to_block: BlockHashToBlockMap = BlockHashToBlockMap()

        # Placeholder block with block_id=0. Its ref_cnt is not maintained,
        # so it must never be freed.
        self.null_block = self.free_block_queue.popleft()
        self.null_block.is_null = True

        self.enable_kv_cache_events = enable_kv_cache_events
        self.kv_event_queue: list[KVCacheEvent] = []

    def get_cached_block(
        self, block_hash: BlockHash, kv_cache_group_ids: list[int]
    ) -> list[KVCacheBlock] | None:
        """Look up the cached block for `block_hash` in every given group.

        When several blocks share the hash, the first one in the cache is
        returned for that group.

        Args:
            block_hash: The hash value of the block.
            kv_cache_group_ids: The ids of the KV cache groups.

        Returns:
            One cached block per group, or None as soon as any group misses.
        """
        hits = []
        for group_id in kv_cache_group_ids:
            key = make_block_hash_with_group_id(block_hash, group_id)
            hit = self.cached_block_hash_to_block.get_one_block(key)
            if not hit:
                return None
            hits.append(hit)
        return hits

    def cache_full_blocks(
        self,
        request: Request,
        blocks: list[KVCacheBlock],
        num_cached_blocks: int,
        num_full_blocks: int,
        block_size: int,
        kv_cache_group_id: int,
    ) -> None:
        """Register the newly full blocks of a request for prefix caching.

        Each block in `blocks[num_cached_blocks:num_full_blocks]` receives
        its block hash (taken from `request.block_hashes`) and is inserted
        into `cached_block_hash_to_block`. When kv cache events are enabled,
        a single `BlockStored` event describing the new blocks is queued.

        Args:
            request: The request to cache the blocks.
            blocks: All blocks in the request.
            num_cached_blocks: The number of blocks that are already cached.
            num_full_blocks: The number of blocks that are full and should
                be cached after this function.
            block_size: Number of tokens in each block.
            kv_cache_group_id: The id of the KV cache group.
        """
        if num_cached_blocks >= num_full_blocks:
            # Nothing new became full.
            return
        blocks_to_cache = blocks[num_cached_blocks:num_full_blocks]
        assert len(request.block_hashes) >= num_full_blocks
        pending_hashes = request.block_hashes[num_cached_blocks:]

        # Only collected when an event has to be emitted afterwards.
        stored_hashes: list[ExternalBlockHash] | None = None
        if self.enable_kv_cache_events:
            stored_hashes = []

        for idx, full_block in enumerate(blocks_to_cache):
            assert full_block.block_hash is None
            raw_hash = pending_hashes[idx]

            # Stamp the block with its hash and add it to the cache.
            key = make_block_hash_with_group_id(raw_hash, kv_cache_group_id)
            full_block.block_hash = key
            self.cached_block_hash_to_block.insert(key, full_block)
            if stored_hashes is not None:
                stored_hashes.append(maybe_convert_block_hash(raw_hash))

        if not self.enable_kv_cache_events:
            return

        parent_block_hash: ExternalBlockHash | None
        if num_cached_blocks == 0:
            parent_block_hash = None
        else:
            # The parent is the last block that was already cached.
            parent_block = blocks[num_cached_blocks - 1]
            assert parent_block.block_hash is not None
            parent_block_hash = maybe_convert_block_hash(
                get_block_hash(parent_block.block_hash)
            )

        first_token = num_cached_blocks * block_size
        last_token = num_full_blocks * block_size
        stored_token_ids = request.all_token_ids[first_token:last_token]
        lora_id = request.lora_request.adapter_id if request.lora_request else None
        event = BlockStored(
            block_hashes=stored_hashes,
            parent_block_hash=parent_block_hash,
            token_ids=stored_token_ids,
            block_size=block_size,
            lora_id=lora_id,
            medium=MEDIUM_GPU,
        )
        self.kv_event_queue.append(event)

    def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
        """Allocate blocks from the front of the free block queue.

        The block cache is not consulted here.

        Args:
            num_blocks: The number of blocks to allocate.

        Returns:
            The newly allocated blocks, each with ref_cnt set to 1.

        Raises:
            ValueError: If fewer than `num_blocks` free blocks are available.
        """
        if num_blocks > self.get_num_free_blocks():
            raise ValueError(f"Cannot get {num_blocks} free blocks from the pool")

        allocated: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)

        # Read the flag once; the list is traversed a single time.
        evict_cached = self.enable_caching
        for new_block in allocated:
            if evict_cached:
                # The block may still be registered as a cached block.
                self._maybe_evict_cached_block(new_block)
            assert new_block.ref_cnt == 0
            new_block.ref_cnt += 1
        return allocated

    def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
        """
        Evict `block` from `cached_block_hash_to_block` if it is cached
        there, resetting its hash metadata.

        Args:
            block: The block to evict.

        Returns:
            True if the block is evicted, False otherwise.
        """
        cached_hash = block.block_hash
        if cached_hash is None:
            # No hash on the block, so it cannot be cached.
            return False

        evicted = self.cached_block_hash_to_block.pop(cached_hash, block.block_id)
        if evicted is None:
            # The block is not registered in the cache.
            return False

        block.reset_hash()

        if self.enable_kv_cache_events:
            # FIXME (Chen): Not sure whether we should return `hash_value`
            # or `(hash_value, group_id)` here. But it's fine now because
            # we disable hybrid kv cache manager when kv cache event is
            # enabled, so there is only one group.
            external_hash = maybe_convert_block_hash(get_block_hash(cached_hash))
            self.kv_event_queue.append(
                BlockRemoved(block_hashes=[external_hash], medium=MEDIUM_GPU)
            )
        return True

    def touch(self, blocks: tuple[Sequence[KVCacheBlock], ...]) -> None:
        """Increase the reference count of every given block by 1.

        A block that was sitting in the free queue is taken out of it. Used
        when a block is hit by another request with the same prefix.

        Args:
            blocks: The blocks to touch, one sequence per KV cache group.
        """
        for group_blocks in blocks:
            for touched in group_blocks:
                # ref_cnt == 0 on a non-null block means it is in the free
                # list as an eviction candidate.
                is_eviction_candidate = touched.ref_cnt == 0 and not touched.is_null
                if is_eviction_candidate:
                    self.free_block_queue.remove(touched)
                touched.ref_cnt += 1

    def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
        """Drop one reference from each block and queue the unreferenced ones.

        The blocks should be ordered by eviction priority; the first block
        will be evicted first.

        Args:
            ordered_blocks: A list of blocks to free ordered by their eviction
                priority.
        """
        # The iterable is walked twice, so materialize it first.
        released = list(ordered_blocks)
        for blk in released:
            blk.ref_cnt -= 1
        now_unreferenced = [
            blk for blk in released if blk.ref_cnt == 0 and not blk.is_null
        ]
        self.free_block_queue.append_n(now_unreferenced)

    def reset_prefix_cache(self) -> bool:
        """Drop all prefix-cache state of the pool.

        May be used in RLHF flows to invalidate the prefix cache after the
        weights are updated, or to reset prefix caching for benchmarking.
        The reset is refused while any block other than the null block is
        still in use.

        Returns:
            bool: True if the prefix cache is successfully reset,
            False otherwise.
        """
        blocks_in_use = self.num_gpu_blocks - self.get_num_free_blocks()
        # The null block always counts as used, hence the comparison with 1.
        if blocks_in_use != 1:
            logger.warning(
                "Failed to reset prefix cache because some "
                "blocks (%d) are not freed yet",
                blocks_in_use - 1,
            )
            return False

        # Start from an empty lookup table so nothing can hit any more.
        self.cached_block_hash_to_block = BlockHashToBlockMap()

        # Clear the hash stored on every block.
        for blk in self.blocks:
            blk.reset_hash()

        logger.info("Successfully reset prefix cache")

        if self.enable_kv_cache_events:
            self.kv_event_queue.append(AllBlocksCleared())

        return True

    def get_num_free_blocks(self) -> int:
        """Return the number of free blocks reported by the free block queue."""
        return self.free_block_queue.num_free_blocks

    def get_usage(self) -> float:
        """Return the KV cache usage.

        Returns:
            The fraction of non-null blocks that are not free
            (between 0.0 and 1.0).
        """
        # The null block is never available, so leave it out of the total.
        usable_blocks = self.num_gpu_blocks - 1
        if usable_blocks == 0:
            return 0
        free_fraction = self.get_num_free_blocks() / usable_blocks
        return 1.0 - free_fraction

    def take_events(self) -> list[KVCacheEvent]:
        """Hand over all queued events and leave an empty queue behind.

        Returns:
            The queued KV cache events (always empty when events are
            disabled).
        """
        if not self.enable_kv_cache_events:
            return []
        drained, self.kv_event_queue = self.kv_event_queue, []
        return drained
|
||||
343
v1/core/encoder_cache_manager.py
Normal file
343
v1/core/encoder_cache_manager.py
Normal file
@@ -0,0 +1,343 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalRegistry
|
||||
from vllm.v1.request import Request
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, SchedulerConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class EncoderCacheManager:
    """Manages caching of encoder outputs for multimodal models in vLLM V1.

    The EncoderCacheManager handles the lifecycle of multimodal encoder outputs
    (such as vision embeddings from images) during request processing. It
    provides memory-aware caching to avoid recomputing encoder outputs when the
    same multimodal inputs appear in different stages of request processing.

    This manager is particularly important for:
    - Vision-language models (e.g., LLaVA) where image encoder outputs are
      cached
    - Any multimodal model where encoder computation is expensive and
      cacheable

    The cache operates at the granularity of individual multimodal input items
    within requests, allowing for fine-grained memory management and enabling
    chunked processing of multimodal inputs.

    Embeddings of the same multimodal data item (identified by its hash value)
    are shared between different requests. Eviction takes place at allocation
    time, when there is no free space for new embeddings; the oldest cached
    embeddings that no request references are evicted first.

    This class only does bookkeeping: it never holds the encoder outputs
    themselves.

    Args:
        cache_size: Limit the size of the cache, measured by the number of
            tokens from the input sequence.

    Attributes:
        cache_size: Total cache capacity in encoder tokens.
        num_free_slots: Current available cache capacity in encoder tokens.
        num_freeable_slots: Capacity (in encoder tokens) that is available
            once every unreferenced entry is evicted, i.e. the free slots
            plus the slots held by the entries in `freeable`.
        cached: Mapping from mm_hash to a set of request IDs that currently
            reference the cached entry. If the set is empty, the entry exists
            but is not referenced by any request and is eligible for
            reclamation.
        freeable: Ordered mapping from mm_hash to its number of encoder
            tokens, holding the entries that no request currently references.
            Entries are kept in the order they became unreferenced, and the
            oldest one is evicted first when space is needed.
        freed: List of mm_hash strings that were actually evicted since the
            last call to get_freed_mm_hashes(). This list is cleared on return.
    """

    def __init__(self, cache_size: int):
        self.cache_size = cache_size
        # Nothing is allocated yet, so both counters start at full capacity.
        self.num_free_slots = cache_size
        self.num_freeable_slots = cache_size

        # mm_hash of mm_data => ids of requests that reference the mm_data
        self.cached: dict[str, set[str]] = {}

        # mm_hash of mm_data => num_encoder_tokens of the mm_data
        self.freeable: OrderedDict[str, int] = OrderedDict()
        self.freed: list[str] = []

    def check_and_update_cache(self, request: Request, input_id: int) -> bool:
        """Check if encoder output for a specific multimodal input is cached.

        If the encoder output is cached, the request id is added to the set of
        request ids that reference it. If the entry was previously not
        referenced by any request, it is removed from `freeable` and
        `num_freeable_slots` is reduced accordingly.

        Args:
            request: The request containing the multimodal input
            input_id: Index of the multimodal input within the request

        Returns:
            True if the encoder output for this input is already cached
        """
        mm_hash = request.mm_features[input_id].identifier
        referrers = self.cached.get(mm_hash)

        # Not cached at all
        if referrers is None:
            return False

        # Cached but currently not referenced by any request: the entry is
        # about to be in use again, so it is no longer reclaimable.
        if not referrers:
            self.num_freeable_slots -= self.freeable.pop(mm_hash)

        referrers.add(request.request_id)
        return True

    def can_allocate(
        self,
        request: Request,
        input_id: int,
        encoder_compute_budget: int,
        num_tokens_to_schedule: int,
    ) -> bool:
        """Check if there's sufficient cache space for a multimodal input.
        If there is, return True and update EncoderCacheManager state.

        If there is not enough free space in `num_free_slots` but there is
        enough reclaimable space in `num_freeable_slots`, entries will be
        evicted from `freeable` (their mm_hash appended to `freed`) until
        enough space is available, and then this method returns True.
        Older entries are evicted first.

        Returns False if the input alone exceeds the compute budget, or if
        the requested number of tokens exceeds `num_freeable_slots`.

        Args:
            request: The request containing the multimodal input.
            input_id: Index of the multimodal input within the request.
            encoder_compute_budget: Number of encoder tokens allowed to be
                computed when this method is invoked.
            num_tokens_to_schedule: Number of tokens already scheduled to be
                allocated with cache space when this method is invoked.

        Returns:
            True if there's enough capacity to hold the encoder output for this
            input (possibly after reclaiming `freeable` entries); otherwise
            False.

        Note: This method does not allocate physical memory for the encoder
        output but only the state of EncoderCacheManager.
        """
        item_tokens = request.get_num_encoder_tokens(input_id)

        # Not enough compute budget
        if item_tokens > encoder_compute_budget:
            return False

        # Space must cover this item plus everything already scheduled.
        required = item_tokens + num_tokens_to_schedule

        if required > self.num_free_slots:
            # Not enough reclaimable slots
            if required > self.num_freeable_slots:
                return False

            # Not enough free slots but enough reclaimable slots
            # NOTE: Eviction takes place here, but physical memory is not
            # freed until model runner is notified by the scheduler output.
            while self.num_free_slots < required:
                evicted_hash, evicted_tokens = self.freeable.popitem(last=False)
                del self.cached[evicted_hash]
                self.freed.append(evicted_hash)
                self.num_free_slots += evicted_tokens

        return True

    def allocate(self, request: Request, input_id: int) -> None:
        """Allocate cache space for a multimodal input's encoder output.

        This reserves cache space for storing the encoder output of the
        specified multimodal input. The actual encoder output storage happens in
        the model runner; this method updates the manager's bookkeeping.

        Note:
            This method assumes can_allocate() returned True for the same input.
        """
        mm_hash = request.mm_features[input_id].identifier
        request_id = request.request_id
        referrers = self.cached.setdefault(mm_hash, set())

        num_encoder_tokens = request.get_num_encoder_tokens(input_id)

        # NOTE: Encoder cache should always have enough space for encoder inputs
        # that are scheduled since eviction takes place at can_allocate().
        assert self.num_free_slots >= num_encoder_tokens
        assert self.num_freeable_slots >= num_encoder_tokens

        referrers.add(request_id)
        self.num_free_slots -= num_encoder_tokens
        self.num_freeable_slots -= num_encoder_tokens

    def get_cached_input_ids(self, request: Request) -> set[int]:
        """Get all cached multimodal input IDs for a request.

        Returns the set of input IDs whose `mm_hash` exists in the cache map.
        This includes entries that are currently unreferenced (and thus present
        in `freeable`); for such entries, freeing for this request will be a
        no-op.
        """
        return {
            input_id
            for input_id, feature in enumerate(request.mm_features)
            if feature.identifier in self.cached
        }

    def free_encoder_input(self, request: Request, input_id: int) -> None:
        """Free the request's reference to the encoder input (`mm_data`)

        When the reference set for the corresponding `mm_hash` becomes empty,
        the entry is appended to `freeable` and `num_freeable_slots` is
        increased by the number of encoder tokens for that input.

        The entry is NOT physically freed until capacity is needed (e.g., by
        `can_allocate`).
        """
        req_id = request.request_id
        mm_hash = request.mm_features[input_id].identifier
        referrers = self.cached.get(mm_hash)

        # The mm_hash not in cache or the req_id set is empty
        if not referrers:
            return

        referrers.discard(req_id)
        if referrers:
            # Still referenced by some other request; nothing to reclaim.
            return

        num_tokens = request.get_num_encoder_tokens(input_id)
        self.freeable[mm_hash] = num_tokens
        self.num_freeable_slots += num_tokens

    def free(self, request: Request) -> None:
        """Free all encoder input cache reference held by *request*.

        For each cached input ID, `free_encoder_input` is invoked.
        The data stays in memory until eviction is triggered by a future
        allocation attempt through `can_allocate`.

        Typically called when a request is finished, cancelled, or aborted.
        """
        # get_cached_input_ids() builds a fresh set, so it is safe to iterate
        # it directly while free_encoder_input() mutates the manager state.
        for input_id in self.get_cached_input_ids(request):
            self.free_encoder_input(request, input_id)

    def get_freed_mm_hashes(self) -> list[str]:
        """Get and clear the list of recently freed encoder cache entries.

        Returns:
            List of mm_hash strings that were actually evicted since the last
            call to be used by the scheduler to notify workers about which
            encoder outputs can be removed from their caches. The internal
            list is cleared after this call.
        """
        freed, self.freed = self.freed, []
        return freed
|
||||
|
||||
|
||||
def compute_encoder_budget(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
    mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
    """Derive the encoder budgets from the model and scheduler configs.

    Dispatches to the multimodal or the text-only budget computation,
    depending on whether the registry reports multimodal support for the
    model.

    Returns:
        - Compute budget for encoder execution, measured in number of tokens
          from the input sequence.
        - Space budget for encoder cache size, measured in number of tokens
          from the input sequence.
    """
    if not mm_registry.supports_multimodal_inputs(model_config):
        return compute_text_encoder_budget(scheduler_config)

    per_modality_limits = mm_registry.get_max_tokens_per_item_by_modality(
        model_config
    )
    return compute_mm_encoder_budget(scheduler_config, per_modality_limits)
|
||||
|
||||
|
||||
def compute_text_encoder_budget(scheduler_config: "SchedulerConfig") -> tuple[int, int]:
    """Return the encoder budgets for a text-only model.

    Args:
        scheduler_config: Scheduler configuration (currently unused).

    Returns:
        - Compute budget for encoder execution, in unit of number of tokens
          in the input sequence.
        - Space budget for encoder cache size, in unit of number of tokens
          in the input sequence.
    """
    # Currently text-only encoder-decoder models are not supported, so both
    # budgets are zero.
    compute_budget = 0
    cache_budget = 0
    return compute_budget, cache_budget
|
||||
|
||||
|
||||
def compute_mm_encoder_budget(
    scheduler_config: "SchedulerConfig",
    max_tokens_by_modality: Mapping[str, int],
) -> tuple[int, int]:
    """Derive the encoder budgets for a multimodal model.

    Both budgets are raised, if necessary, so that the largest single
    multimodal item always fits.

    Args:
        scheduler_config: Scheduler configuration.
        max_tokens_by_modality: The maximum number of tokens for each
            non-text modality.

    Returns:
        - Compute budget for encoder execution, measured in number of tokens
          from the input sequence.
        - Space budget for encoder cache size, measured in number of tokens
          from the input sequence.

    Raises:
        ValueError: If chunked multimodal input is disabled and the largest
            item does not fit into `max_num_batched_tokens`.
    """
    if not max_tokens_by_modality:
        logger.warning(
            "All non-text modalities supported by the model have been "
            "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
            "not be initialized."
        )
        return 0, 0

    largest_item = max(max_tokens_by_modality.values())

    # Without chunking, a whole item must fit into a single batch.
    if scheduler_config.disable_chunked_mm_input:
        if largest_item > scheduler_config.max_num_batched_tokens:
            raise ValueError(
                "Chunked MM input disabled but max_tokens_per_mm_item "
                f"({largest_item}) is larger than max_num_batched_tokens"
                f" ({scheduler_config.max_num_batched_tokens}). Please increase "
                "max_num_batched_tokens."
            )

    compute_budget = max(
        scheduler_config.max_num_encoder_input_tokens, largest_item
    )
    cache_budget = max(scheduler_config.encoder_cache_size, largest_item)
    return compute_budget, cache_budget
|
||||
480
v1/core/kv_cache_coordinator.py
Normal file
480
v1/core/kv_cache_coordinator.py
Normal file
@@ -0,0 +1,480 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
CrossAttentionManager,
|
||||
FullAttentionManager,
|
||||
get_manager_for_kv_cache_spec,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class KVCacheCoordinator(ABC):
    """
    Coordinate the KV cache of different KV cache groups.

    One single-type manager is created per KV cache group; all of them share
    one block pool. Most methods simply fan a call out to every manager.
    """

    def __init__(
        self,
        kv_cache_config: KVCacheConfig,
        max_model_len: int,
        use_eagle: bool,
        enable_caching: bool,
        enable_kv_cache_events: bool,
        dcp_world_size: int,
    ):
        self.kv_cache_config = kv_cache_config
        self.max_model_len = max_model_len
        self.enable_caching = enable_caching

        self.block_pool = BlockPool(
            kv_cache_config.num_blocks, enable_caching, enable_kv_cache_events
        )

        # Needs special handling for find_longest_cache_hit if eagle is enabled
        self.use_eagle = use_eagle

        # One manager per KV cache group, indexed by the group id.
        managers = []
        for group_id, group in enumerate(self.kv_cache_config.kv_cache_groups):
            managers.append(
                get_manager_for_kv_cache_spec(
                    kv_cache_spec=group.kv_cache_spec,
                    block_pool=self.block_pool,
                    kv_cache_group_id=group_id,
                    dcp_world_size=dcp_world_size,
                )
            )
        self.single_type_managers = tuple(managers)

    def get_num_blocks_to_allocate(
        self,
        request_id: str,
        num_tokens: int,
        new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
        num_encoder_tokens: int,
    ) -> int:
        """
        Get the number of blocks needed to be allocated for the request.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).
            new_computed_blocks: The new computed blocks just hitting the
                prefix caching.
            num_encoder_tokens: The number of encoder tokens for allocating
                blocks for cross-attention.

        Returns:
            The number of blocks, summed over all KV cache groups.
        """
        total = 0
        for group_id, manager in enumerate(self.single_type_managers):
            if isinstance(manager, CrossAttentionManager):
                # For cross-attention, we issue a single static allocation
                # of blocks based on the number of encoder input tokens.
                needed = manager.get_num_blocks_to_allocate(
                    request_id, num_encoder_tokens, []
                )
            else:
                needed = manager.get_num_blocks_to_allocate(
                    request_id, num_tokens, new_computed_blocks[group_id]
                )
            total += needed
        return total

    def save_new_computed_blocks(
        self, request_id: str, new_computed_blocks: tuple[Sequence[KVCacheBlock], ...]
    ) -> None:
        """
        Add the new computed blocks to the request.

        Args:
            request_id: The request ID.
            new_computed_blocks: The new computed blocks just hitting the
                prefix cache, one sequence per KV cache group.
        """
        for group_id, manager in enumerate(self.single_type_managers):
            group_blocks = new_computed_blocks[group_id]
            manager.save_new_computed_blocks(request_id, group_blocks)

    def allocate_new_blocks(
        self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0
    ) -> tuple[list[KVCacheBlock], ...]:
        """
        Allocate new blocks for the request to give it at least `num_tokens`
        token slots.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).
            num_encoder_tokens: The number of encoder tokens for allocating
                blocks for cross-attention.

        Returns:
            The new allocated blocks, one list per KV cache group.
        """
        allocated = []
        for manager in self.single_type_managers:
            # Cross-attention groups are sized by the encoder tokens instead.
            if isinstance(manager, CrossAttentionManager):
                wanted_tokens = num_encoder_tokens
            else:
                wanted_tokens = num_tokens
            allocated.append(manager.allocate_new_blocks(request_id, wanted_tokens))
        return tuple(allocated)

    def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
        """
        Cache the blocks for the request.

        Args:
            request: The request.
            num_computed_tokens: The total number of tokens
                that need to be cached
                (including tokens that are already cached).
        """
        for group_manager in self.single_type_managers:
            group_manager.cache_blocks(request, num_computed_tokens)

    def free(self, request_id: str) -> None:
        """
        Free the blocks for the request.

        Args:
            request_id: The request ID.
        """
        for group_manager in self.single_type_managers:
            group_manager.free(request_id)

    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
        """
        Get the number of common prefix blocks for all requests with allocated
        KV cache for each kv cache group.

        Args:
            running_request_id: The request ID of any running request, used to
                identify the common prefix blocks.

        Returns:
            list[int]: The number of common prefix blocks for each kv cache group.
        """
        counts = []
        for group_manager in self.single_type_managers:
            counts.append(
                group_manager.get_num_common_prefix_blocks(running_request_id)
            )
        return counts

    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
        """
        Remove the blocks that are no longer needed from `blocks` and replace
        the removed blocks with null_block.

        Args:
            request_id: The request ID.
            num_computed_tokens: The number of tokens that have been computed.
        """
        for group_manager in self.single_type_managers:
            group_manager.remove_skipped_blocks(request_id, num_computed_tokens)

    def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
        """
        Get the blocks for the request.

        Returns:
            One block list per KV cache group; a group without blocks for
            this request contributes an empty list.
        """
        per_group = []
        for group_manager in self.single_type_managers:
            per_group.append(group_manager.req_to_blocks.get(request_id) or [])
        return tuple(per_group)

    @abstractmethod
    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        """
        Find the longest prefix cache hit; implemented by subclasses.

        Args:
            block_hashes: The block hashes of the request.
            max_cache_hit_length: The maximum length of the cache hit.

        Returns:
            The hit blocks for each KV cache group and the hit length in
            tokens.
        """
|
||||
|
||||
|
||||
class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
    """
    KV cache coordinator to use if prefix caching is disabled or unsupported.
    In contrast to UnitaryKVCacheCoordinator and HybridKVCacheCoordinator,
    supports arbitrary numbers of KV cache groups (including 0 groups).
    Does not implement any features related to prefix caching.
    """

    def __init__(
        self,
        kv_cache_config: KVCacheConfig,
        max_model_len: int,
        use_eagle: bool,
        enable_kv_cache_events: bool,
        dcp_world_size: int,
    ):
        # Prefix caching is always off for this coordinator.
        super().__init__(
            kv_cache_config=kv_cache_config,
            max_model_len=max_model_len,
            use_eagle=use_eagle,
            enable_caching=False,
            enable_kv_cache_events=enable_kv_cache_events,
            dcp_world_size=dcp_world_size,
        )
        self.num_single_type_manager = len(self.single_type_managers)

    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
        """Return zero common prefix blocks for every KV cache group."""
        return [0 for _ in range(self.num_single_type_manager)]

    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        """Report no cache hit: one empty block list per group and length 0."""
        no_hits: list[list[KVCacheBlock]] = []
        for _ in range(self.num_single_type_manager):
            # A separate list per group, so callers may mutate them freely.
            no_hits.append([])
        return tuple(no_hits), 0
|
||||
|
||||
|
||||
class UnitaryKVCacheCoordinator(KVCacheCoordinator):
    """
    KV cache coordinator for models with only one KV cache group. This is the
    case for models with only one KV cache type, e.g., all attention layers use
    full attention or all attention layers use sliding window attention.
    """

    def __init__(
        self,
        kv_cache_config: KVCacheConfig,
        max_model_len: int,
        use_eagle: bool,
        enable_caching: bool,
        enable_kv_cache_events: bool,
        dcp_world_size: int,
    ):
        super().__init__(
            kv_cache_config,
            max_model_len,
            use_eagle,
            enable_caching,
            enable_kv_cache_events,
            dcp_world_size=dcp_world_size,
        )
        # Validate the group count before indexing into the groups, so a
        # config without groups fails with this message, not an IndexError.
        assert len(self.kv_cache_config.kv_cache_groups) == 1, (
            "UnitaryKVCacheCoordinator assumes only one kv cache group"
        )
        self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
        self.block_size = self.kv_cache_spec.block_size
        self.dcp_world_size = dcp_world_size
        if dcp_world_size > 1:
            # The hit length below is counted in units of
            # block_size * dcp_world_size tokens per hit block.
            self.block_size *= dcp_world_size

    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        """
        Find the longest cache hit for the request in the single group.

        Args:
            block_hashes: The block hashes of the request.
            max_cache_hit_length: The maximum length of the cache hit.

        Returns:
            A tuple containing:
                - The cache hit blocks as returned by the single manager.
                - The number of tokens of the cache hit, i.e. the number of
                  hit blocks times `self.block_size`.
        """
        hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
            block_hashes=block_hashes,
            max_length=max_cache_hit_length,
            kv_cache_group_ids=[0],
            block_pool=self.block_pool,
            kv_cache_spec=self.kv_cache_spec,
            use_eagle=self.use_eagle,
            dcp_world_size=self.dcp_world_size,
        )
        return hit_blocks, len(hit_blocks[0]) * self.block_size
|
||||
|
||||
|
||||
class HybridKVCacheCoordinator(KVCacheCoordinator):
    """
    KV cache coordinator for hybrid models with multiple KV cache types, and
    thus multiple kv cache groups.
    To simplify `find_longest_cache_hit`, it only supports the combination of
    two types of KV cache groups, and one of them must be full attention.
    May extend to more general cases in the future.
    """

    def __init__(
        self,
        kv_cache_config: KVCacheConfig,
        max_model_len: int,
        use_eagle: bool,
        enable_caching: bool,
        enable_kv_cache_events: bool,
        dcp_world_size: int,
    ):
        super().__init__(
            kv_cache_config,
            max_model_len,
            use_eagle,
            enable_caching,
            enable_kv_cache_events,
            dcp_world_size=dcp_world_size,
        )
        assert dcp_world_size == 1, "DCP not support hybrid attn now."
        self.verify_and_split_kv_cache_groups()

    def verify_and_split_kv_cache_groups(self) -> None:
        """
        Verifies that the model has exactly two types of KV cache groups, and
        one of them is full attention. Then, split the kv cache groups into full
        attention groups and other groups.

        Raises:
            AssertionError: If the groups do not consist of exactly one full
                attention spec and one other spec, or (with caching enabled)
                if the other block size is not a multiple of the full
                attention block size.
            ValueError: If full attention group ids and other group ids
                interleave.
        """
        full_attention_spec: FullAttentionSpec | None = None
        other_spec: KVCacheSpec | None = None
        self.full_attention_group_ids: list[int] = []
        self.other_group_ids: list[int] = []
        for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
            if isinstance(g.kv_cache_spec, FullAttentionSpec):
                if full_attention_spec is None:
                    full_attention_spec = g.kv_cache_spec
                else:
                    assert full_attention_spec == g.kv_cache_spec, (
                        "HybridKVCacheCoordinator assumes exactly one type of "
                        "full attention groups now."
                    )
                self.full_attention_group_ids.append(i)
            else:
                if other_spec is None:
                    other_spec = g.kv_cache_spec
                else:
                    assert other_spec == g.kv_cache_spec, (
                        "HybridKVCacheCoordinator assumes "
                        "exactly one other type of groups now."
                    )
                self.other_group_ids.append(i)

        assert full_attention_spec is not None, (
            "HybridKVCacheCoordinator assumes exactly one type of full "
            "attention groups now."
        )
        assert other_spec is not None, (
            "HybridKVCacheCoordinator assumes exactly one type of other groups now."
        )

        # find_longest_cache_hit is invoked on the manager classes rather
        # than on manager instances.
        self.full_attention_manager_cls = FullAttentionManager
        self.other_attention_cls = self.single_type_managers[
            self.other_group_ids[0]
        ].__class__
        self.full_attention_spec = full_attention_spec
        self.other_spec = other_spec
        self.full_attention_block_size = self.full_attention_spec.block_size
        self.other_block_size = self.other_spec.block_size

        if self.enable_caching:
            # this requirement is only needed for the prefix caching logic
            remainder = self.other_block_size % self.full_attention_block_size
            assert remainder == 0, (
                "KVCacheCoordinator assumes the block_size of other "
                "layers is divisible by the block_size of full attention "
                "layers now."
            )

        # Record which kind of group comes first, so the hit blocks can be
        # merged back in group-id order by a simple concatenation.
        if max(self.full_attention_group_ids) < min(self.other_group_ids):
            self.full_attn_first = True
        elif max(self.other_group_ids) < min(self.full_attention_group_ids):
            self.full_attn_first = False
        else:
            raise ValueError(
                "HybridKVCacheCoordinator assumes the full "
                "attention group ids and other attention group ids "
                "do not interleave, either full attention group ids "
                "are before other attention group ids or vice versa. "
                "This is for simplifying merging hit_blocks_full_attn and "
                "hit_blocks_other_attn to hit_blocks."
            )

    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        """
        Find the longest cache hit for the request.

        Args:
            block_hashes: The block hashes of the request.
            max_cache_hit_length: The maximum length of the cache hit.

        Returns:
            A tuple containing:
                - A list of the cache hit blocks for each single type manager.
                - The number of tokens of the longest cache hit.
        """
        # First, find the longest cache hit for full attention.
        hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
            block_hashes=block_hashes,
            max_length=max_cache_hit_length,
            kv_cache_group_ids=self.full_attention_group_ids,
            block_pool=self.block_pool,
            kv_cache_spec=self.full_attention_spec,
            use_eagle=self.use_eagle,
        )
        hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size

        # Next, find the cache hit for the other attention WITHIN
        # the cache hit of full attention.
        hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
            block_hashes=block_hashes,
            max_length=hit_length,
            kv_cache_group_ids=self.other_group_ids,
            block_pool=self.block_pool,
            kv_cache_spec=self.other_spec,
            use_eagle=self.use_eagle,
        )
        hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size

        # NOTE: the prefix cache hit length must be a multiple of block_size as
        # we don't support partial block cache hit yet. The cache hit length
        # of other attention is ensured to be a multiple of the block size of
        # full attention layers in current implementation, because hit_length is
        # a multiple of other attention's block size, and other attention's
        # block size is a multiple of full attention's block size (verified in
        # `verify_and_split_kv_cache_groups`).
        assert hit_length % self.full_attention_block_size == 0

        # Truncate the full attention cache hit to the length of the
        # cache hit of the other attention.
        for group_hit_blocks in hit_blocks_full_attn:
            del group_hit_blocks[hit_length // self.full_attention_block_size :]

        # Merge the hit blocks of full attention and other attention.
        if self.full_attn_first:
            hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
        else:
            hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn
        return hit_blocks, hit_length
|
||||
|
||||
|
||||
def get_kv_cache_coordinator(
    kv_cache_config: KVCacheConfig,
    max_model_len: int,
    use_eagle: bool,
    enable_caching: bool,
    enable_kv_cache_events: bool,
    dcp_world_size: int,
) -> KVCacheCoordinator:
    """Build the KV cache coordinator that matches the configuration.

    - Prefix caching disabled: `KVCacheCoordinatorNoPrefixCache`.
    - Exactly one KV cache group: `UnitaryKVCacheCoordinator`.
    - Otherwise: `HybridKVCacheCoordinator`.

    Returns:
        The newly constructed coordinator.
    """
    if not enable_caching:
        return KVCacheCoordinatorNoPrefixCache(
            kv_cache_config,
            max_model_len,
            use_eagle,
            enable_kv_cache_events,
            dcp_world_size=dcp_world_size,
        )

    # Both caching coordinators take the same constructor arguments.
    single_group = len(kv_cache_config.kv_cache_groups) == 1
    coordinator_cls = (
        UnitaryKVCacheCoordinator if single_group else HybridKVCacheCoordinator
    )
    return coordinator_cls(
        kv_cache_config,
        max_model_len,
        use_eagle,
        enable_caching,
        enable_kv_cache_events,
        dcp_world_size=dcp_world_size,
    )
|
||||
420
v1/core/kv_cache_manager.py
Normal file
420
v1/core/kv_cache_manager.py
Normal file
@@ -0,0 +1,420 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, overload
|
||||
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
|
||||
from vllm.v1.core.kv_cache_utils import KVCacheBlock
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class KVCacheBlocks:
    """
    The allocation result of KVCacheManager. It works as the interface between
    Scheduler and KVCacheManager, to hide KVCacheManager's internal data
    structure from the Scheduler.
    """

    blocks: tuple[Sequence[KVCacheBlock], ...]
    """
    `blocks[i][j]` refers to the i-th kv_cache_group
    and the j-th block of tokens. We don't use block of
    tokens as the outer dimension because it assumes all
    kv_cache_groups have the same number of blocks, which is true for now but
    will be broken if we want to give different block_size to different
    kv_cache_groups in the future.

    Each single type KVCacheBlocks could be represented as:
    - list[KVCacheBlock] for more than one KVCacheBlock
    - an empty tuple for requests without KVCacheBlock
      (a precomputed KVCacheBlocks is in KVCacheManager to avoid GC overhead)
    """

    def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
        """Concatenate two KVCacheBlocks instances group by group."""
        merged = tuple(
            [*mine, *theirs] for mine, theirs in zip(self.blocks, other.blocks)
        )
        return KVCacheBlocks(merged)

    @overload
    def get_block_ids(
        self,
        allow_none: Literal[False] = False,
    ) -> tuple[list[int], ...]: ...

    @overload
    def get_block_ids(
        self,
        allow_none: Literal[True] = True,
    ) -> tuple[list[int], ...] | None: ...

    def get_block_ids(
        self,
        allow_none: bool = False,
    ) -> tuple[list[int], ...] | None:
        """
        Converts the KVCacheBlocks instance to block_ids.

        Args:
            allow_none: When true, return None instead of a tuple of empty
                lists if no group holds any block.

        Returns:
            tuple[list[int], ...]: A tuple of lists where:
                - the outer tuple corresponds to KV cache groups
                - each inner list contains the block_ids of the blocks in that
                  group
        """
        if allow_none and not any(len(group) for group in self.blocks):
            return None
        ids_per_group = []
        for group in self.blocks:
            ids_per_group.append([blk.block_id for blk in group])
        return tuple(ids_per_group)

    def get_unhashed_block_ids(self) -> list[int]:
        """Get block_ids of unhashed blocks from KVCacheBlocks instance."""
        assert len(self.blocks) == 1, "Only one group is supported"
        only_group = self.blocks[0]
        return [blk.block_id for blk in only_group if blk.block_hash is None]

    def new_empty(self) -> "KVCacheBlocks":
        """
        Creates a new KVCacheBlocks instance with no blocks, keeping the same
        number of groups.
        """
        return KVCacheBlocks(((),) * len(self.blocks))
|
||||
|
||||
|
||||
class KVCacheManager:
    """Scheduler-facing entry point for KV cache block management.

    The block bookkeeping itself is delegated to a KV cache coordinator and
    its block pool; this class wires them together, records prefix cache
    stats when requested, and wraps raw block tuples in `KVCacheBlocks`.
    """

    def __init__(
        self,
        kv_cache_config: KVCacheConfig,
        max_model_len: int,
        enable_caching: bool = True,
        use_eagle: bool = False,
        log_stats: bool = False,
        enable_kv_cache_events: bool = False,
        dcp_world_size: int = 1,
    ) -> None:
        """Build the coordinator and the shared empty `KVCacheBlocks`.

        Args:
            kv_cache_config: The KV cache configuration; its
                `kv_cache_groups` define the number of groups.
            max_model_len: Upper bound on the number of tokens for which
                slots are allocated per request.
            enable_caching: Whether prefix caching is enabled.
            use_eagle: Forwarded to the coordinator.
            log_stats: Whether to collect prefix cache stats.
            enable_kv_cache_events: Forwarded to the coordinator.
            dcp_world_size: Forwarded to the coordinator. When greater than
                1 (and caching is enabled), exactly one KV cache group is
                required and `block_size` is scaled by this factor.
        """
        self.max_model_len = max_model_len

        self.enable_caching = enable_caching
        self.use_eagle = use_eagle
        self.log_stats = log_stats
        # Prefix cache stats are only collected when log_stats is enabled.
        self.prefix_cache_stats = PrefixCacheStats() if log_stats else None

        # block_size stays None when prefix caching is disabled.
        self.block_size: int | None = None
        if self.enable_caching:
            assert (
                len(
                    set(
                        g.kv_cache_spec.block_size
                        for g in kv_cache_config.kv_cache_groups
                    )
                )
                == 1
            ), "Only one block size is supported for now"
            self.block_size = kv_cache_config.kv_cache_groups[
                0
            ].kv_cache_spec.block_size

            if dcp_world_size > 1:
                assert len(kv_cache_config.kv_cache_groups) == 1
                # Note(hc): need revisit. When both DCP and any future
                # PCP are enabled, the block_size may need to be scaled
                # by a factor of dcp_size × pcp_size?
                self.block_size *= dcp_world_size

        self.coordinator = get_kv_cache_coordinator(
            kv_cache_config=kv_cache_config,
            max_model_len=self.max_model_len,
            use_eagle=self.use_eagle,
            enable_caching=self.enable_caching,
            enable_kv_cache_events=enable_kv_cache_events,
            dcp_world_size=dcp_world_size,
        )
        self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
        self.block_pool = self.coordinator.block_pool
        self.kv_cache_config = kv_cache_config

        # Pre-constructed KVCacheBlocks with no blocks, callers should use this
        # via create_kv_cache_blocks instead of creating new ones to avoid GC
        # overhead.
        #
        # We use nested tuples to ensure the empty KVCacheBlocks is immutable.
        self.empty_kv_cache_blocks = KVCacheBlocks(
            tuple(() for _ in range(self.num_kv_cache_groups))
        )

    @property
    def usage(self) -> float:
        """Get the KV cache usage.

        Returns:
            The KV cache usage (between 0.0 and 1.0).
        """
        return self.block_pool.get_usage()

    def make_prefix_cache_stats(self) -> PrefixCacheStats | None:
        """Get (and reset) the prefix cache stats.

        Returns:
            The current prefix caching stats, or None if logging is disabled.
        """
        if not self.log_stats:
            return None
        stats = self.prefix_cache_stats
        # Start a fresh stats object so each call reports a new interval.
        self.prefix_cache_stats = PrefixCacheStats()
        return stats

    def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]:
        """Get the computed (cached) blocks for the request.
        Note that the computed blocks must be full.

        Args:
            request: The request to get the computed blocks.

        Returns:
            A tuple containing:
                - A KVCacheBlocks holding the blocks that are computed for
                  the request (the shared empty instance if there are none).
                - The number of computed tokens.
        """
        # We skip finding the prefix cache hit when prefix caching is
        # disabled or the request is marked as skipping kv cache read
        # (which happens when the request requires prompt logprobs
        # or calls a pooling model with all pooling).
        if not self.enable_caching or request.skip_reading_prefix_cache:
            return self.empty_kv_cache_blocks, 0

        # NOTE: When all tokens hit the cache, we must recompute the last token
        # to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
        # This can trigger recomputation of an entire block, rather than just
        # the single last token, because allocate_slots() requires
        # num_computed_tokens to be block-size aligned. Removing this limitation
        # could slightly improve performance in the future.
        max_cache_hit_length = request.num_tokens - 1
        computed_blocks, num_new_computed_tokens = (
            self.coordinator.find_longest_cache_hit(
                request.block_hashes, max_cache_hit_length
            )
        )

        if self.log_stats:
            assert self.prefix_cache_stats is not None
            self.prefix_cache_stats.record(
                num_tokens=request.num_tokens,
                num_hits=num_new_computed_tokens,
                preempted=request.num_preemptions > 0,
            )

        return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens

    def allocate_slots(
        self,
        request: Request,
        num_new_tokens: int,
        num_new_computed_tokens: int = 0,
        new_computed_blocks: KVCacheBlocks | None = None,
        num_lookahead_tokens: int = 0,
        delay_cache_blocks: bool = False,
        num_encoder_tokens: int = 0,
    ) -> KVCacheBlocks | None:
        """Add slots for a request with new tokens to append.

        Args:
            request: The request to allocate slots.
            num_new_tokens: The number of tokens to allocate, including external
                tokens. Note that this does not include tokens that have
                already been computed locally (i.e. new_computed_blocks).
                Must be greater than 0.
            num_new_computed_tokens: The number of new computed tokens just
                hitting the prefix caching, excluding external tokens.
            new_computed_blocks: The cached blocks for the above new computed
                tokens.
            num_lookahead_tokens: The number of speculative tokens to allocate.
                This is used by spec decode proposers with kv-cache such
                as eagle.
            delay_cache_blocks: Whether to skip caching the blocks. This is
                used by P/D when allocating blocks used in a KV transfer
                which will complete in a future step.
            num_encoder_tokens: Number of encoder tokens; forwarded to the
                coordinator when counting and allocating blocks.

        Blocks layout:
        ```
        -----------------------------------------------------------------------
        | < computed > | < new computed > |    < new >    | < pre-allocated > |
        -----------------------------------------------------------------------
        |                  < required >                   |
        --------------------------------------------------
        |                    < full >                  |
        ------------------------------------------------
                                          | <new full> |
                                          --------------
        ```
        The following *_blocks are illustrated in this layout.

        Returns:
            A KVCacheBlocks holding the newly allocated blocks, or None if
            there are not enough free blocks to serve the request.

        Raises:
            ValueError: If num_new_tokens is not greater than 0.
        """
        # Reject negative counts as well as zero, as the message states.
        if num_new_tokens <= 0:
            raise ValueError("num_new_tokens must be greater than 0")

        if new_computed_blocks is not None:
            new_computed_block_list = new_computed_blocks.blocks
        else:
            new_computed_block_list = self.empty_kv_cache_blocks.blocks

        # Free the blocks that are skipped during the attention computation
        # (e.g., tokens outside the sliding window).
        # We can do this even if we cannot schedule this request due to
        # insufficient free blocks.
        # Should call this function before allocating new blocks to reduce
        # the number of evicted blocks.
        self.coordinator.remove_skipped_blocks(
            request.request_id, request.num_computed_tokens
        )

        # The number of computed tokens is the number of computed tokens plus
        # the new prefix caching hits
        num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens
        num_tokens_need_slot = min(
            num_computed_tokens + num_new_tokens + num_lookahead_tokens,
            self.max_model_len,
        )

        num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
            request_id=request.request_id,
            num_tokens=num_tokens_need_slot,
            new_computed_blocks=new_computed_block_list,
            num_encoder_tokens=num_encoder_tokens,
        )

        if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
            # Cannot allocate new blocks
            return None

        # Touch the computed blocks to make sure they won't be evicted.
        if self.enable_caching:
            self.block_pool.touch(new_computed_block_list)
        else:
            assert not any(new_computed_block_list), (
                "Computed blocks should be empty when prefix caching is disabled"
            )

        # Identity check: the shared empty tuple means "no new computed
        # blocks", so there is nothing to save in that case.
        if new_computed_block_list is not self.empty_kv_cache_blocks.blocks:
            # Append the new computed blocks to the request blocks until now to
            # avoid the case where the new blocks cannot be allocated.
            self.coordinator.save_new_computed_blocks(
                request.request_id, new_computed_block_list
            )

        new_blocks = self.coordinator.allocate_new_blocks(
            request.request_id, num_tokens_need_slot, num_encoder_tokens
        )

        # P/D: delay caching blocks if we have to recv from
        # remote. Update state for locally cached blocks.
        if not self.enable_caching or delay_cache_blocks:
            return self.create_kv_cache_blocks(new_blocks)

        # NOTE(woosuk): We want to commit (cache) up to num_computed_tokens +
        # num_new_tokens, but must exclude "non-committable" tokens (e.g.,
        # draft tokens that could be rejected). Therefore, we cap the number
        # at `request.num_tokens`, ensuring only "finalized" tokens are cached.
        num_tokens_to_cache = min(
            num_computed_tokens + num_new_tokens, request.num_tokens
        )
        self.coordinator.cache_blocks(request, num_tokens_to_cache)

        return self.create_kv_cache_blocks(new_blocks)

    def free(self, request: Request) -> None:
        """Free the blocks allocated for the request.
        We free the blocks in reverse order so that the tail blocks are evicted
        first when caching is enabled.

        Args:
            request: The request to free the blocks.
        """
        self.coordinator.free(request.request_id)

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache. This function may be used in RLHF
        flows to invalidate prefix caching after the weights are updated,
        or used for resetting prefix caching status for benchmarking.

        Returns:
            bool: True if the prefix cache is successfully reset,
            False otherwise.
        """
        if not self.block_pool.reset_prefix_cache():
            return False
        if self.log_stats:
            assert self.prefix_cache_stats is not None
            # Flag the reset on the stats object when stats are collected.
            self.prefix_cache_stats.reset = True
        return True

    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
        """Calculate the number of common prefix blocks for each kv cache group.

        The function selects a running request and iterates through its blocks.
        A block is considered a common prefix block if ALL requests with
        allocated KV cache share it (i.e., ref_cnt equals the number of entries
        in req_to_blocks).

        NOTE(woosuk): The number of requests with allocated KV cache is **greater
        than or equal to** the number of requests scheduled in the current step.
        This is because having allocated KV cache only indicates that:
        1. The request has not yet finished, and
        2. The request holds its blocks unfreed.

        While all scheduled requests must have allocated KV cache, the inverse
        is not necessarily true. There may be requests with allocated KV cache
        that are not scheduled in the current step.

        This can result in an edge case where the number of common prefix blocks
        is 0, even though all scheduled requests share a common prefix. This
        occurs because there may be unscheduled requests that do not share the
        common prefix. Currently, this case cannot be easily detected, so the
        function returns 0 in such cases.

        Args:
            running_request_id: The request ID of any running request, used to
                identify the common prefix blocks.

        Returns:
            list[int]: The number of common prefix blocks for each kv cache
            group.
        """
        return self.coordinator.get_num_common_prefix_blocks(running_request_id)

    def take_events(self) -> list[KVCacheEvent]:
        """Take the KV cache events from the block pool.

        Returns:
            A list of KV cache events.
        """
        return self.block_pool.take_events()

    def get_blocks(self, request_id: str) -> KVCacheBlocks:
        """Get the blocks of a request."""
        return self.create_kv_cache_blocks(self.coordinator.get_blocks(request_id))

    def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
        """Get the block ids of a request."""
        return self.get_blocks(request_id).get_block_ids()

    def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
        """Cache the blocks for the request, if enabled."""
        if self.enable_caching:
            self.coordinator.cache_blocks(request, num_computed_tokens)

    def create_kv_cache_blocks(
        self, blocks: tuple[list[KVCacheBlock], ...]
    ) -> KVCacheBlocks:
        """Wrap per-group block lists in a KVCacheBlocks.

        Returns the shared pre-constructed empty instance when every group
        is empty, so no new object is created in that case.
        """
        # Only create new KVCacheBlocks for non-empty blocks
        return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks
|
||||
1356
v1/core/kv_cache_utils.py
Normal file
1356
v1/core/kv_cache_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
0
v1/core/sched/__init__.py
Normal file
0
v1/core/sched/__init__.py
Normal file
BIN
v1/core/sched/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
v1/core/sched/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/core/sched/__pycache__/async_scheduler.cpython-312.pyc
Normal file
BIN
v1/core/sched/__pycache__/async_scheduler.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/core/sched/__pycache__/interface.cpython-312.pyc
Normal file
BIN
v1/core/sched/__pycache__/interface.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/core/sched/__pycache__/output.cpython-312.pyc
Normal file
BIN
v1/core/sched/__pycache__/output.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/core/sched/__pycache__/request_queue.cpython-312.pyc
Normal file
BIN
v1/core/sched/__pycache__/request_queue.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/core/sched/__pycache__/scheduler.cpython-312.pyc
Normal file
BIN
v1/core/sched/__pycache__/scheduler.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/core/sched/__pycache__/utils.cpython-312.pyc
Normal file
BIN
v1/core/sched/__pycache__/utils.cpython-312.pyc
Normal file
Binary file not shown.
62
v1/core/sched/async_scheduler.py
Normal file
62
v1/core/sched/async_scheduler.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AsyncScheduler(Scheduler):
    """Scheduler subclass that keeps a per-request count of output
    placeholders: tokens scheduled for generation whose ids have not yet been
    reported back."""

    def _update_after_schedule(
        self,
        scheduler_output: SchedulerOutput,
    ) -> None:
        """Reserve output placeholders for the requests just scheduled.

        Also records on `scheduler_output` whether any scheduled request
        that uses structured output still has outstanding placeholders.
        """
        super()._update_after_schedule(scheduler_output)

        spec_tokens_by_req = scheduler_output.scheduled_spec_decode_tokens
        has_pending_structured = False
        for req_id in scheduler_output.num_scheduled_tokens:
            req = self.requests[req_id]
            waiting_on_structured = (
                req.use_structured_output and req.num_output_placeholders > 0
            )
            has_pending_structured |= waiting_on_structured

            num_spec = len(spec_tokens_by_req.get(req_id, ()))
            expected_computed = (
                req.num_tokens + req.num_output_placeholders + num_spec
            )
            if req.num_computed_tokens != expected_computed:
                continue

            # The request will generate a new token plus num_spec_tokens
            # in this scheduling step.
            req.num_output_placeholders += 1 + num_spec
            # Add placeholders for the new tokens in spec_token_ids.
            # We will update the actual spec token ids in the worker process.
            req.spec_token_ids = [-1] * self.num_spec_tokens

        scheduler_output.pending_structured_output_tokens = has_pending_structured

    def _update_request_with_output(
        self,
        request: Request,
        new_token_ids: list[int],
    ) -> tuple[list[int], bool]:
        """Apply newly generated tokens and release their placeholders.

        Returns:
            The (new_token_ids, stopped) pair produced by the parent class.
        """
        # Capture the status first: the parent call may change it.
        prior_status = request.status
        accepted_ids, stopped = super()._update_request_with_output(
            request, new_token_ids
        )

        # Update the number of output placeholders.
        request.num_output_placeholders -= len(accepted_ids)
        assert request.num_output_placeholders >= 0

        # Cache the new tokens. Preempted requests should be skipped.
        if prior_status == RequestStatus.RUNNING:
            num_tokens_to_cache = (
                request.num_computed_tokens - request.num_output_placeholders
            )
            self.kv_cache_manager.cache_blocks(request, num_tokens_to_cache)
        return accepted_ids, stopped
|
||||
181
v1/core/sched/interface.py
Normal file
181
v1/core/sched/interface.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.engine import EngineCoreOutputs
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
|
||||
class SchedulerInterface(ABC):
    """Abstract interface that scheduler implementations must provide."""

    @abstractmethod
    def __init__(
        self,
        vllm_config: "VllmConfig",
        kv_cache_config: "KVCacheConfig",
        structured_output_manager: "StructuredOutputManager",
        block_size: int,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        include_finished_set: bool = False,
        log_stats: bool = False,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def schedule(self) -> "SchedulerOutput":
        """Decide which requests to process in this scheduling step.

        Scheduling happens at the iteration level: one scheduling step maps
        to one forward pass of the model, so the engine calls this method
        repeatedly from a busy loop.

        In essence the scheduler produces a {req_id: num_tokens} dictionary
        stating how many tokens to process for each request in this step.
        num_tokens may be as large as the number of prompt tokens for a new
        request, or 1 for a request that is auto-regressively generating
        tokens one by one. With chunked prefills, prefix caching,
        speculative decoding, etc. it can be anywhere in between.

        The scheduler additionally returns useful data about each request
        or about the batch as a whole, which the model runner uses when
        preparing the model inputs.

        Returns:
            A SchedulerOutput object containing information about the
            scheduled requests.
        """
        raise NotImplementedError

    @abstractmethod
    def get_grammar_bitmask(
        self, scheduler_output: "SchedulerOutput"
    ) -> "GrammarOutput | None":
        raise NotImplementedError

    @abstractmethod
    def update_from_output(
        self,
        scheduler_output: "SchedulerOutput",
        model_runner_output: "ModelRunnerOutput",
    ) -> dict[int, "EngineCoreOutputs"]:
        """Update the scheduler state from the model runner output.

        Called once the model runner has processed the scheduled requests.
        The model runner output carries the generated token ids, the draft
        token ids for the next step, etc. The scheduler uses it to update
        its states, check for finished requests, and build the output of
        each request.

        Returns:
            A dict mapping client index to the EngineCoreOutputs object
            holding the outputs of the requests that originate from that
            client.
        """
        raise NotImplementedError

    @abstractmethod
    def update_draft_token_ids(
        self,
        draft_token_ids: "DraftTokenIds",
    ) -> None:
        """Update the draft token ids of the scheduled requests."""
        raise NotImplementedError

    @abstractmethod
    def add_request(self, request: "Request") -> None:
        """Put a new request into the scheduler's internal queue.

        Args:
            request: The new request being added.
        """
        raise NotImplementedError

    @abstractmethod
    def finish_requests(
        self,
        request_ids: str | Iterable[str],
        finished_status: "RequestStatus",
    ) -> None:
        """Finish requests held in the scheduler's internal queue.

        Requests that are not in the queue are ignored.

        This method is called in two cases:
        1. When the request is aborted by the client.
        2. When the frontend process detects a stop string of the request
           after de-tokenizing its generated tokens.

        Args:
            request_ids: A single or a list of request IDs.
            finished_status: The finished status of the given requests.
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_unfinished_requests(self) -> int:
        """Return how many unfinished requests are in the internal queue."""
        raise NotImplementedError

    def has_unfinished_requests(self) -> bool:
        """Return True if the scheduler's internal queue holds at least one
        unfinished request."""
        num_unfinished = self.get_num_unfinished_requests()
        return num_unfinished > 0

    @abstractmethod
    def has_finished_requests(self) -> bool:
        """Return True if there are finished requests that need to be cleared.
        NOTE: This is different from `not self.has_unfinished_requests()`.

        The scheduler keeps an internal list of the requests that finished
        in the previous step. The next call to schedule() returns that list
        so it can be sent to the model runner in the next step, which then
        clears its cached states for these finished requests.

        This method reports whether that internal list is non-empty. The
        information is useful for DP attention.
        """
        raise NotImplementedError

    def has_requests(self) -> bool:
        """Return True if there are unfinished requests, or finished
        requests not yet returned in SchedulerOutputs."""
        if self.has_unfinished_requests():
            return True
        return self.has_finished_requests()

    @abstractmethod
    def reset_prefix_cache(self) -> bool:
        """Reset the prefix cache of the KV cache.

        This is particularly required when the model weights are
        live-updated.
        """
        raise NotImplementedError

    @abstractmethod
    def get_request_counts(self) -> tuple[int, int]:
        """Return the pair (num_running_reqs, num_waiting_reqs)."""
        raise NotImplementedError

    @abstractmethod
    def make_stats(self) -> Optional["SchedulerStats"]:
        """Build a SchedulerStats object for logging.

        A SchedulerStats object is created for every scheduling step.
        """
        raise NotImplementedError

    @abstractmethod
    def shutdown(self) -> None:
        """Shut the scheduler down."""
        raise NotImplementedError

    def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
        """Return the KV connector; this base implementation returns None."""
        return None
|
||||
202
v1/core/sched/output.py
Normal file
202
v1/core/sched/output.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm._bc_linter import bc_linter_include
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
|
||||
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.request import Request
|
||||
else:
|
||||
ECConnectorMetadata = object
|
||||
KVConnectorMetadata = object
|
||||
LoRARequest = object
|
||||
MultiModalFeatureSpec = object
|
||||
PoolingParams = object
|
||||
SamplingParams = object
|
||||
Request = object
|
||||
|
||||
|
||||
@bc_linter_include
@dataclass
class NewRequestData:
    """Data describing a newly scheduled request, built from a `Request`
    plus the block ids assigned to it."""

    req_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    lora_request: LoRARequest | None
    prompt_embeds: "torch.Tensor | None" = None

    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: tuple[list[int], ...],
    ) -> "NewRequestData":
        """Build a NewRequestData by copying the fields of `request`.

        Args:
            request: The request to copy the fields from.
            block_ids: The block ids to store alongside the request data.
        """
        return cls(
            req_id=request.request_id,
            prompt_token_ids=request.prompt_token_ids,
            mm_features=request.mm_features,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            block_ids=block_ids,
            num_computed_tokens=request.num_computed_tokens,
            lora_request=request.lora_request,
            prompt_embeds=request.prompt_embeds,
        )

    def _prompt_embeds_shape(self) -> "torch.Size | None":
        """Return the shape of `prompt_embeds`, or None when it is unset."""
        # Compare against None explicitly: truth-testing a tensor raises for
        # tensors with more than one element and is False for a single zero.
        if self.prompt_embeds is None:
            return None
        return self.prompt_embeds.shape

    def __repr__(self) -> str:
        """Return a representation that shows the shape of `prompt_embeds`
        rather than its contents."""
        prompt_embeds_shape = self._prompt_embeds_shape()
        return (
            f"NewRequestData("
            f"req_id={self.req_id},"
            f"prompt_token_ids={self.prompt_token_ids},"
            f"mm_features={self.mm_features},"
            f"sampling_params={self.sampling_params},"
            f"block_ids={self.block_ids},"
            f"num_computed_tokens={self.num_computed_tokens},"
            f"lora_request={self.lora_request},"
            f"prompt_embeds_shape={prompt_embeds_shape}"
            ")"
        )

    # Version of __repr__ with the prompt data obfuscated
    def anon_repr(self) -> str:
        """Like `__repr__`, but reports only the length of
        `prompt_token_ids` instead of the token ids themselves."""
        prompt_token_ids_len = (
            len(self.prompt_token_ids) if self.prompt_token_ids is not None else None
        )
        prompt_embeds_shape = self._prompt_embeds_shape()
        return (
            f"NewRequestData("
            f"req_id={self.req_id},"
            f"prompt_token_ids_len={prompt_token_ids_len},"
            f"mm_features={self.mm_features},"
            f"sampling_params={self.sampling_params},"
            f"block_ids={self.block_ids},"
            f"num_computed_tokens={self.num_computed_tokens},"
            f"lora_request={self.lora_request},"
            f"prompt_embeds_shape={prompt_embeds_shape}"
            ")"
        )
|
||||
|
||||
|
||||
@bc_linter_include
@dataclass
class CachedRequestData:
    """Per-step data for requests that have been scheduled before, stored
    as parallel lists indexed in the order of `req_ids`."""

    req_ids: list[str]
    # Requests whose ids are NOT in resumed_req_ids get new_block_ids appended
    # to their existing block IDs. For requests in the set, new_block_ids
    # replaces the request's block IDs instead of being appended.
    resumed_req_ids: set[str]
    # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
    # When PP is not used, new_token_ids will be empty.
    new_token_ids: list[list[int]]
    # For requests not scheduled in the last step, propagate the token ids to
    # the connector. Won't contain requests that were scheduled in the prior
    # step.
    all_token_ids: dict[str, list[int]]
    new_block_ids: list[tuple[list[int], ...] | None]
    num_computed_tokens: list[int]
    num_output_tokens: list[int]

    @property
    def num_reqs(self) -> int:
        """Number of requests described by this object."""
        return len(self.req_ids)

    @cached_property
    @deprecated("use resumed_req_ids field")
    def resumed_from_preemption(self) -> list[bool]:
        """One flag per entry of `req_ids`: True if that id is in
        `resumed_req_ids`."""
        resumed = self.resumed_req_ids
        flags = []
        for req_id in self.req_ids:
            flags.append(req_id in resumed)
        return flags

    @cached_property
    @deprecated("use all_token_ids field")
    def resumed_req_token_ids(self) -> list[list[int] | None]:
        """One entry per entry of `req_ids`: the request's token ids from
        `all_token_ids` if the id is in `resumed_req_ids`, otherwise None."""
        token_ids_per_req = []
        for req_id in self.req_ids:
            if req_id in self.resumed_req_ids:
                token_ids_per_req.append(self.all_token_ids[req_id])
            else:
                token_ids_per_req.append(None)
        return token_ids_per_req

    @classmethod
    def make_empty(cls) -> "CachedRequestData":
        """Create an instance in which every container field is empty."""
        empty_fields = dict(
            req_ids=[],
            resumed_req_ids=set(),
            new_token_ids=[],
            all_token_ids={},
            new_block_ids=[],
            num_computed_tokens=[],
            num_output_tokens=[],
        )
        return cls(**empty_fields)
|
||||
|
||||
|
||||
@bc_linter_include
@dataclass
class SchedulerOutput:
    """Scheduling decisions produced for one step; fields are described inline."""

    # Requests that are scheduled for the first time. Their data is cached in
    # each worker process, so it does not have to be re-sent on every
    # scheduling step.
    scheduled_new_reqs: list[NewRequestData]
    # Requests that have been scheduled before. Because the workers already
    # hold their data, only the diff is sent to keep communication cheap.
    scheduled_cached_reqs: CachedRequestData

    # req_id -> number of tokens scheduled for that request.
    num_scheduled_tokens: dict[str, int]
    # Total number of tokens scheduled across all requests.
    # Equal to sum(num_scheduled_tokens.values()).
    total_num_scheduled_tokens: int
    # req_id -> spec_token_ids. Requests without any spec decode tokens are
    # not present in the dictionary.
    scheduled_spec_decode_tokens: dict[str, list[int]]
    # req_id -> indices of the encoder inputs that need processing.
    # E.g., [0, 1] could mean the vision encoder needs to process the
    # request's 0th and 1st images in the current step.
    scheduled_encoder_inputs: dict[str, list[int]]
    # Number of common prefix blocks shared by all requests, one entry per
    # KV cache group. This can be used for cascade attention.
    num_common_prefix_blocks: list[int]

    # IDs of requests that finished between the previous and the current
    # step. Workers use this to free the state they cached for them.
    finished_req_ids: set[str]
    # mm_hash strings of the encoder outputs that should be freed from the
    # encoder cache.
    free_encoder_mm_hashes: list[str]

    # Whether the scheduled requests have all the output tokens they
    # need to perform grammar bitmask computation.
    pending_structured_output_tokens: bool = False

    # KV Cache Connector metadata.
    kv_connector_metadata: KVConnectorMetadata | None = None

    # EC Cache Connector metadata.
    ec_connector_metadata: ECConnectorMetadata | None = None
|
||||
|
||||
|
||||
@dataclass
class GrammarOutput:
    """Grammar bitmask together with the requests it belongs to."""

    # IDs of the structured output requests.
    structured_output_request_ids: list[str]
    # Bitmask, ordered the same way as structured_output_request_ids.
    grammar_bitmask: "npt.NDArray[np.int32]"
|
||||
221
v1/core/sched/request_queue.py
Normal file
221
v1/core/sched/request_queue.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import heapq
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from collections.abc import Iterable, Iterator
|
||||
from enum import Enum
|
||||
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class SchedulingPolicy(Enum):
    """Scheduling policies a request queue can be created for."""

    # First come, first served.
    FCFS = "fcfs"
    # Ordered by request priority.
    PRIORITY = "priority"
|
||||
|
||||
|
||||
class RequestQueue(ABC):
    """Interface every scheduler request queue has to implement."""

    @abstractmethod
    def add_request(self, request: Request) -> None:
        """Insert `request` at the position the policy dictates."""

    @abstractmethod
    def pop_request(self) -> Request:
        """Remove and return the request the policy selects next."""

    @abstractmethod
    def peek_request(self) -> Request:
        """Return the request at the front of the queue, leaving it queued."""

    @abstractmethod
    def prepend_request(self, request: Request) -> None:
        """Put `request` at the front of the queue."""

    @abstractmethod
    def prepend_requests(self, requests: "RequestQueue") -> None:
        """Put every request of another queue at the front of this queue."""

    @abstractmethod
    def remove_request(self, request: Request) -> None:
        """Take one specific request out of the queue."""

    @abstractmethod
    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Take several specific requests out of the queue."""

    @abstractmethod
    def __bool__(self) -> bool:
        """Return whether the queue holds any request."""

    @abstractmethod
    def __len__(self) -> int:
        """Return how many requests are queued."""

    @abstractmethod
    def __iter__(self) -> Iterator[Request]:
        """Iterate over the queued requests in policy order."""

    @abstractmethod
    def __reversed__(self) -> Iterator[Request]:
        """Iterate over the queued requests in reverse order."""
|
||||
|
||||
|
||||
class FCFSRequestQueue(deque[Request], RequestQueue):
    """First-come-first-served request queue built directly on `deque`."""

    def add_request(self, request: Request) -> None:
        """Enqueue `request` at the back (FCFS order)."""
        self.append(request)

    def pop_request(self) -> Request:
        """Dequeue and return the oldest request."""
        return self.popleft()

    def peek_request(self) -> Request:
        """Return the oldest request without dequeuing it.

        Raises:
            IndexError: If the queue is empty.
        """
        if self:
            return self[0]
        raise IndexError("peek from an empty queue")

    def prepend_request(self, request: Request) -> None:
        """Put `request` in front of everything already queued."""
        self.appendleft(request)

    def prepend_requests(self, requests: RequestQueue) -> None:
        """Put all requests of another queue in front of this queue."""
        # extendleft() inserts one element at a time at the left end, which
        # reverses its input; feeding it the reversed queue keeps the
        # incoming requests in their original relative order.
        back_to_front = reversed(requests)
        self.extendleft(back_to_front)

    def remove_request(self, request: Request) -> None:
        """Take one specific request out of the queue."""
        self.remove(request)

    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Take several specific requests out of the queue."""
        doomed = set(requests)
        survivors = [queued for queued in self if queued not in doomed]
        # A deque cannot be filtered in place, so rebuild its content.
        self.clear()
        self.extend(survivors)

    def __bool__(self) -> bool:
        """Return whether the queue holds any request."""
        return len(self) != 0

    def __len__(self) -> int:
        """Return how many requests are queued."""
        return super().__len__()

    def __iter__(self) -> Iterator[Request]:
        """Iterate from the oldest to the newest request."""
        return super().__iter__()

    def __reversed__(self) -> Iterator[Request]:
        """Iterate from the newest to the oldest request."""
        return super().__reversed__()
|
||||
|
||||
|
||||
class PriorityRequestQueue(RequestQueue):
    """
    A priority queue that supports heap operations.

    Requests with a smaller value of `priority` are processed first.
    If multiple requests have the same priority, the one with the earlier
    `arrival_time` is processed first. Requests that tie on both fields are
    processed in the order in which they were added to this queue.
    """

    def __init__(self) -> None:
        # Heap entries are (priority, arrival_time, sequence, request).
        # The sequence number is unique per entry, so tuple comparison is
        # always decided before it would have to compare two Request objects
        # (which would raise TypeError if Request defines no ordering).
        self._heap: list[tuple[int, float, int, Request]] = []
        # Next sequence number to hand out; only ever increases.
        self._next_seq: int = 0

    def add_request(self, request: Request) -> None:
        """Add a request to the queue according to priority policy."""
        entry = (request.priority, request.arrival_time, self._next_seq, request)
        self._next_seq += 1
        heapq.heappush(self._heap, entry)

    def pop_request(self) -> Request:
        """Pop a request from the queue according to priority policy.

        Raises:
            IndexError: If the queue is empty.
        """
        if not self._heap:
            raise IndexError("pop from empty heap")
        return heapq.heappop(self._heap)[-1]

    def peek_request(self) -> Request:
        """Peek at the next request in the queue without removing it.

        Raises:
            IndexError: If the queue is empty.
        """
        if not self._heap:
            raise IndexError("peek from empty heap")
        return self._heap[0][-1]

    def prepend_request(self, request: Request) -> None:
        """Add a request to the queue according to priority policy.

        Note: In a priority queue, there is no concept of prepending to the
        front. Requests are ordered by (priority, arrival_time)."""
        self.add_request(request)

    def prepend_requests(self, requests: RequestQueue) -> None:
        """Add all requests from another queue according to priority policy.

        Note: In a priority queue, there is no concept of prepending to the
        front. Requests are ordered by (priority, arrival_time)."""
        for request in requests:
            self.add_request(request)

    def remove_request(self, request: Request) -> None:
        """Remove a specific request from the queue.

        Requests that are not queued are ignored.
        """
        self._heap = [entry for entry in self._heap if entry[-1] != request]
        # Filtering can break the heap invariant; restore it.
        heapq.heapify(self._heap)

    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Remove multiple specific requests from the queue."""
        requests_to_remove = set(requests)
        self._heap = [
            entry for entry in self._heap if entry[-1] not in requests_to_remove
        ]
        # Filtering can break the heap invariant; restore it.
        heapq.heapify(self._heap)

    def __bool__(self) -> bool:
        """Check if queue has any requests."""
        return bool(self._heap)

    def __len__(self) -> int:
        """Get number of requests in queue."""
        return len(self._heap)

    def __iter__(self) -> Iterator[Request]:
        """Iterate over the queue according to priority policy."""
        # Pop from a copy so that iterating does not consume the queue.
        heap_copy = self._heap[:]
        while heap_copy:
            yield heapq.heappop(heap_copy)[-1]

    def __reversed__(self) -> Iterator[Request]:
        """Iterate over the queue in reverse priority order."""
        return reversed(list(self))
|
||||
|
||||
|
||||
def create_request_queue(policy: SchedulingPolicy) -> RequestQueue:
    """Return a new, empty request queue implementing `policy`.

    Raises:
        ValueError: If `policy` is neither FCFS nor PRIORITY.
    """
    if policy == SchedulingPolicy.FCFS:
        return FCFSRequestQueue()
    if policy == SchedulingPolicy.PRIORITY:
        return PriorityRequestQueue()
    raise ValueError(f"Unknown scheduling policy: {policy}")
|
||||
1617
v1/core/sched/scheduler.py
Normal file
1617
v1/core/sched/scheduler.py
Normal file
File diff suppressed because it is too large
Load Diff
72
v1/core/sched/utils.py
Normal file
72
v1/core/sched/utils.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
|
||||
def remove_all(lst: list, items_to_remove: set) -> list:
    """Drop the members of `items_to_remove` from `lst`.

    Removing exactly one item is treated as the common case and handled
    in place; removing several items builds a filtered copy.

    Args:
        lst: The list to remove items from.
        items_to_remove: Set of items to remove.

    Returns:
        `lst` itself when `items_to_remove` is empty or holds one item
        (the list is mutated in place in the latter case), otherwise a new
        list. Callers should always use the returned value.

    Note:
        The single-item path uses `list.remove`, so it deletes only the
        first occurrence of that item; the multi-item path filters out
        every occurrence.
    """
    if not items_to_remove:
        return lst

    if len(items_to_remove) != 1:
        # Several items: one pass with O(1) membership tests on the set.
        return [entry for entry in lst if entry not in items_to_remove]

    # Single item (most common case): remove in place; a missing item is
    # not an error.
    (target,) = items_to_remove
    with contextlib.suppress(ValueError):
        lst.remove(target)
    return lst
|
||||
|
||||
|
||||
def check_stop(
    request: Request, max_model_len: int, pooler_output: torch.Tensor | None = None
) -> bool:
    """Decide whether `request` is finished, updating its status if so.

    Args:
        request: The request to inspect. Its `status` (and, for a stop
            token, `stop_reason`) is set when the function returns True.
        max_model_len: Token count at which the request is length-capped.
        pooler_output: Pooling result; only consulted for pooling requests.

    Returns:
        True if the request finished, False otherwise.
    """
    if request.pooling_params:
        # A pooling request is done as soon as a pooler output exists.
        if pooler_output is None:
            return False
        request.status = RequestStatus.FINISHED_STOPPED
        return True

    params = request.sampling_params
    assert params is not None

    # No stop condition is evaluated before min_tokens outputs exist.
    if request.num_output_tokens < params.min_tokens:
        return False

    newest_token = request.output_token_ids[-1]

    # EOS token, unless the request asked to ignore it.
    if not params.ignore_eos and newest_token == request.eos_token_id:
        request.status = RequestStatus.FINISHED_STOPPED
        return True

    # Explicit stop token; it is recorded as the stop reason.
    stop_token_ids = params.stop_token_ids or ()
    if newest_token in stop_token_ids:
        request.status = RequestStatus.FINISHED_STOPPED
        request.stop_reason = newest_token
        return True

    # Length limits: total tokens vs. model length, outputs vs. max_tokens.
    if (
        request.num_tokens < max_model_len
        and request.num_output_tokens < request.max_tokens
    ):
        return False
    request.status = RequestStatus.FINISHED_LENGTH_CAPPED
    return True
|
||||
736
v1/core/single_type_kv_cache_manager.py
Normal file
736
v1/core/single_type_kv_cache_manager.py
Normal file
@@ -0,0 +1,736 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
ChunkedLocalAttentionSpec,
|
||||
CrossAttentionSpec,
|
||||
FullAttentionSpec,
|
||||
KVCacheSpec,
|
||||
MambaSpec,
|
||||
MLAAttentionSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class SingleTypeKVCacheManager(ABC):
    """
    An abstract base class for a manager that handles the kv cache management
    logic of one specific type of attention layer.
    """

    def __init__(
        self,
        kv_cache_spec: KVCacheSpec,
        block_pool: BlockPool,
        kv_cache_group_id: int,
        dcp_world_size: int = 1,
    ) -> None:
        """
        Initializes the SingleTypeKVCacheManager.

        Args:
            kv_cache_spec: The kv_cache_spec for this manager.
            block_pool: The block pool.
            kv_cache_group_id: The id of the kv cache group of this manager.
            dcp_world_size: When greater than 1, the manager's block size is
                the spec's block size multiplied by this value.
        """
        self.block_size = kv_cache_spec.block_size
        self.dcp_world_size = dcp_world_size
        if self.dcp_world_size > 1:
            self.block_size *= dcp_world_size
        self.kv_cache_spec = kv_cache_spec
        self.block_pool = block_pool

        # Mapping from request ID to blocks to track the blocks allocated
        # for each request, so that we can free the blocks when the request
        # is finished.
        self.req_to_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list)

        # {req_id: The number of cached blocks for this given request}
        # This is used to track the number of cached blocks for each request.
        # This is only used to track the RUNNING requests, we do not track the
        # data for preempted ones.
        self.num_cached_block: dict[str, int] = {}

        self.kv_cache_group_id = kv_cache_group_id
        self._null_block = block_pool.null_block

    def get_num_blocks_to_allocate(
        self,
        request_id: str,
        num_tokens: int,
        new_computed_blocks: Sequence[KVCacheBlock],
    ) -> int:
        """
        Get the number of blocks needed to be allocated for the request.

        This is a pure query: it does not modify `req_to_blocks`.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).
            new_computed_blocks: The new computed blocks just hitting the
                prefix caching.

        Returns:
            The number of blocks.
        """

        num_required_blocks = cdiv(num_tokens, self.block_size)
        # Use .get() rather than indexing: req_to_blocks is a defaultdict, so
        # indexing would insert an empty entry for a request that may never
        # be allocated and would inflate len(self.req_to_blocks).
        num_allocated_blocks = len(self.req_to_blocks.get(request_id, ()))
        num_new_blocks = (
            num_required_blocks - len(new_computed_blocks) - num_allocated_blocks
        )
        # If a computed block of a request is an eviction candidate (in the
        # free queue and ref_cnt == 0), it will be changed from a free block
        # to a computed block when the request is allocated, so we also count
        # it as needed to be allocated.
        num_evictable_computed_blocks = sum(
            blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks
        )
        return num_new_blocks + num_evictable_computed_blocks

    def save_new_computed_blocks(
        self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
    ) -> None:
        """
        Add the new computed blocks to the request.

        Args:
            request_id: The request ID.
            new_computed_blocks: The new computed blocks just hitting the
                prefix cache.
        """
        if request_id not in self.num_cached_block:
            # A new request.
            req_blocks = self.req_to_blocks[request_id]
            assert len(req_blocks) == 0
            req_blocks.extend(new_computed_blocks)
            self.num_cached_block[request_id] = len(new_computed_blocks)
        else:
            # A running request. Should not have new computed blocks.
            assert len(new_computed_blocks) == 0

    def allocate_new_blocks(
        self, request_id: str, num_tokens: int
    ) -> list[KVCacheBlock]:
        """
        Allocate new blocks for the request to give it at least `num_tokens`
        token slots.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).

        Returns:
            The new allocated blocks (empty if the request already owns
            enough blocks).
        """
        req_blocks = self.req_to_blocks[request_id]
        num_required_blocks = cdiv(num_tokens, self.block_size)
        num_new_blocks = num_required_blocks - len(req_blocks)
        if num_new_blocks <= 0:
            return []
        new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
        req_blocks.extend(new_blocks)
        return new_blocks

    def cache_blocks(self, request: Request, num_tokens: int) -> None:
        """
        Cache the blocks for the request.

        Args:
            request: The request.
            num_tokens: The total number of tokens that need to be cached
                (including tokens that are already cached).
        """
        num_cached_blocks = self.num_cached_block.get(request.request_id, 0)
        # Only completely filled blocks are cached.
        num_full_blocks = num_tokens // self.block_size

        if num_cached_blocks >= num_full_blocks:
            # Nothing new to cache.
            return

        self.block_pool.cache_full_blocks(
            request=request,
            blocks=self.req_to_blocks[request.request_id],
            num_cached_blocks=num_cached_blocks,
            num_full_blocks=num_full_blocks,
            block_size=self.block_size,
            kv_cache_group_id=self.kv_cache_group_id,
        )

        self.num_cached_block[request.request_id] = num_full_blocks

    def free(self, request_id: str) -> None:
        """
        Free the blocks for the request.

        Args:
            request_id: The request ID.
        """
        # Default to [] in case a request is freed (aborted) before alloc.
        req_blocks = self.req_to_blocks.pop(request_id, [])

        # Free blocks in reverse order so that the tail blocks are
        # freed first.
        ordered_blocks = reversed(req_blocks)

        self.block_pool.free_blocks(ordered_blocks)
        self.num_cached_block.pop(request_id, None)

    @abstractmethod
    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """
        Get the number of common prefix blocks for all requests with allocated
        KV cache.

        Args:
            running_request_id: The request ID.

        Returns:
            The number of common prefix blocks for all requests with allocated
            KV cache.
        """

        raise NotImplementedError

    @classmethod
    @abstractmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
        dcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        """
        Get the longest cache hit prefix of the blocks that is not longer than
        `max_length`. The prefix should be a common prefix hit for all the
        kv cache groups in `kv_cache_group_ids`. If no cache hit is found,
        return an empty list.
        If eagle is enabled, drop the last matched block to force recompute the
        last block to get the required hidden states for eagle drafting head.
        Need to be customized for each attention type.

        Args:
            block_hashes: The block hashes of the request.
            max_length: The maximum length of the cache hit prefix.
            kv_cache_group_ids: The ids of the kv cache groups.
            block_pool: The block pool.
            kv_cache_spec: The kv cache spec.
            use_eagle: Whether to use eagle.
            dcp_world_size: World size used to scale the block size; see the
                concrete implementations for how it is applied.

        Returns:
            A list of cached blocks with skipped blocks replaced by null block
            for each kv cache group in `kv_cache_group_ids`.
            Return a list of length `len(kv_cache_group_ids)`, where the i-th
            element is a list of cached blocks for the i-th kv cache group
            in `kv_cache_group_ids`.
            For example, sliding window manager should return a list like
            ([NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]) for block size 4
            and sliding window 8 and len(kv_cache_group_ids) = 1.
        """

        raise NotImplementedError

    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
        """
        Remove and free the blocks that are no longer needed for attention
        computation. The removed blocks are replaced by the null block.

        This function depends on `get_num_skipped_tokens`, which needs to be
        implemented differently for each attention type.

        Args:
            request_id: The request ID.
            num_computed_tokens: The number of tokens that have been computed.
        """
        # Remove the blocks that will be skipped during attention computation.
        num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
        if num_skipped_tokens <= 0:
            # This indicates that ALL tokens are inside attention window.
            # Thus we do not need to free any blocks outside attention window.
            # A typical case is full attention that we never free any token
            # before the request is finished.
            return
        num_skipped_blocks = num_skipped_tokens // self.block_size
        blocks = self.req_to_blocks[request_id]
        removed_blocks: list[KVCacheBlock] = []
        # Because the block starts from index 0, the num_skipped_block-th block
        # corresponds to index num_skipped_blocks - 1.
        for i in range(num_skipped_blocks - 1, -1, -1):
            if blocks[i] == self._null_block:
                # If the block is already a null block, the blocks before it
                # should also have been set to null blocks by the previous calls
                # to this function.
                break
            removed_blocks.append(blocks[i])
            blocks[i] = self._null_block
        self.block_pool.free_blocks(removed_blocks)

    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
        """
        Get the number of tokens that will be skipped for attention computation.

        Args:
            num_computed_tokens: The number of tokens that have been computed.

        Returns:
            The number of tokens that will be skipped for attention computation.
        """
        # The default behavior is to not skip any tokens.
        return 0
|
||||
|
||||
|
||||
class FullAttentionManager(SingleTypeKVCacheManager):
    """KV cache manager used for full attention and chunked local attention specs."""

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
        dcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        """Collect the leading run of cached blocks, one list per group.

        Block hashes are looked up front to back and the scan stops at the
        first miss. At most `max_length // block_size` blocks are examined,
        where the block size is scaled by `dcp_world_size` when that is
        greater than 1. With `use_eagle`, the last matched block is dropped
        from every group.
        """
        assert isinstance(
            kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
        ), (
            "FullAttentionManager can only be used for full attention "
            "and chunked local attention groups"
        )
        hits_per_group: tuple[list[KVCacheBlock], ...] = tuple(
            [] for _ in kv_cache_group_ids
        )

        effective_block_size = kv_cache_spec.block_size
        if dcp_world_size > 1:
            effective_block_size *= dcp_world_size
        block_limit = max_length // effective_block_size

        for block_hash in itertools.islice(block_hashes, block_limit):
            hit = block_pool.get_cached_block(block_hash, kv_cache_group_ids)
            if not hit:
                # block_hashes is a chain of hashes: once one of them is not
                # cached, none of the following ones can be computed yet.
                break
            for group_hits, block in zip(hits_per_group, hit):
                group_hits.append(block)

        if use_eagle and hits_per_group[0]:
            # Drop the last matched block so that it gets recomputed.
            for group_hits in hits_per_group:
                group_hits.pop()
        return hits_per_group

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """Count leading blocks whose ref_cnt equals the number of tracked requests."""
        request_blocks = self.req_to_blocks[running_request_id]
        num_tracked_requests = len(self.req_to_blocks)
        common = 0
        for block in request_blocks:
            if block.ref_cnt != num_tracked_requests:
                # First block not referenced by every tracked request.
                break
            common += 1
        return common
|
||||
|
||||
|
||||
class SlidingWindowManager(SingleTypeKVCacheManager):
    """KV cache manager for sliding window attention groups."""

    def __init__(
        self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, **kwargs
    ) -> None:
        """Initialize the base manager and remember the window size."""
        super().__init__(kv_cache_spec, block_pool, **kwargs)
        self.sliding_window = kv_cache_spec.sliding_window
        self._null_block = block_pool.null_block

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
        dcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        """Find the right-most run of cached blocks that covers the window.

        Blocks are probed from the last candidate towards the first. The
        search stops as soon as enough contiguous cached blocks are found;
        everything after that run is trimmed and everything before it stays
        a null block. If no such run exists, only the cached run that
        starts at block 0 (possibly empty) is returned.
        """
        assert isinstance(kv_cache_spec, SlidingWindowSpec), (
            "SlidingWindowManager can only be used for sliding window groups"
        )
        assert dcp_world_size == 1, "DCP not support sliding window attn now."

        block_size = kv_cache_spec.block_size

        # Number of contiguous blocks required for a prefix cache hit.
        # -1 because the input token itself is also part of the window.
        required_run = cdiv(kv_cache_spec.sliding_window - 1, block_size)
        if use_eagle:
            # With eagle the last matched block has to be dropped. For a
            # sliding window layer this is done by requiring one more
            # contiguous block and popping the last match at the end.
            required_run += 1

        # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
        # optimize the time complexity from O(max_num_blocks) to
        # O(max_num_blocks / sliding_window_contiguous_blocks +
        # sliding_window_contiguous_blocks),
        # which is good for low cache hit rate scenarios.
        max_num_blocks = max_length // block_size
        hits_per_group = tuple(
            [block_pool.null_block] * max_num_blocks
            for _ in range(len(kv_cache_group_ids))
        )

        run_length = 0
        # Search from right to left and stop early once a match is found.
        for idx in reversed(range(max_num_blocks)):
            hit = block_pool.get_cached_block(block_hashes[idx], kv_cache_group_ids)
            if not hit:
                # A miss interrupts the current contiguous run.
                run_length = 0
                continue
            for group_hits, block in zip(hits_per_group, hit):
                group_hits[idx] = block
            run_length += 1
            if run_length >= required_run:
                # Trim the trailing blocks.
                # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
                # when sliding_window_contiguous_blocks=2.
                for group_hits in hits_per_group:
                    del group_hits[idx + run_length :]
                break
        else:
            # No sufficiently long run was found. The first `run_length`
            # blocks are still a cache hit even though the run is shorter
            # than required.
            for group_hits in hits_per_group:
                del group_hits[run_length:]

        if use_eagle and hits_per_group[0]:
            for group_hits in hits_per_group:
                group_hits.pop()
        return hits_per_group

    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
        """
        Get the number of tokens that will be skipped for attention computation.

        For sliding window, these are the tokens that lie before the window
        of the next token to be computed.

        Example:
            sliding_window=4, num_computed_tokens=7

            Tokens:   [ 0  1  2  3  4  5  6  7 ]
                      | ---- computed ---- |
                                            ^ next token to be computed
                                  |--------| sliding window for next token
                      |-skipped-|

            The current window contains tokens 4~7. Tokens 0~3 will be
            skipped for attention computation since they are outside the
            sliding window. Thus, get_num_skipped_tokens(7) == 4.

        Args:
            num_computed_tokens: The number of tokens that have been computed.

        Returns:
            The number of tokens that will be skipped for attention computation.
        """
        skipped = num_computed_tokens - self.sliding_window + 1
        return skipped

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """
        NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
        So it's not correct to count ref_cnt like FullAttentionManager. Return
        0 here for correctness. Need to support cascade attention + sliding
        window in the future.
        """
        return 0
|
||||
|
||||
|
||||
class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
    """KV cache manager for chunked local attention groups.

    Everything to the left of the chunk that contains the next token is
    treated as skipped: cache lookups stand in null blocks for it and
    `get_num_skipped_tokens` reports it as skippable.
    """

    def __init__(
        self, kv_cache_spec: ChunkedLocalAttentionSpec, block_pool: BlockPool, **kwargs
    ) -> None:
        """Store the chunk size and the pool's null block.

        Args:
            kv_cache_spec: Spec providing `attention_chunk_size`.
            block_pool: Pool providing `null_block`.
            **kwargs: Forwarded unchanged to the base-class constructor.
        """
        super().__init__(kv_cache_spec, block_pool, **kwargs)
        self.attention_chunk_size = kv_cache_spec.attention_chunk_size
        self._null_block = block_pool.null_block

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
        dcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        """
        Find the longest cache-hit prefix that is not longer than
        `max_length` and is shared by every group in `kv_cache_group_ids`.

        Blocks lying entirely before the current local window count as
        computed and are represented by the pool's null block. Only blocks
        from the window start onwards are looked up in the cache, and the
        scan stops at the first miss.

        Examples:

        1. Attention chunk size 8, block size 4, max length 15.
           The next token is the 15th (zero-indexed); tokens 8-14 are in
           the window and need a lookup, tokens 0-7 are outside it and are
           marked as computed. Only the complete block holding tokens 8-11
           is checked: on a hit the result is [null, null, <that block>],
           otherwise [null, null].

        2. Attention chunk size 8, block size 4, max length 16.
           The next token is the 16th (zero-indexed); tokens 0-15 are all
           outside the window, so the result is [null, null, null, null].

        Args:
            block_hashes: The block hashes of the request.
            max_length: The maximum length of the cache hit prefix.
            kv_cache_group_ids: The ids of the kv cache groups.
            block_pool: The block pool.
            kv_cache_spec: The kv cache spec; must be a
                ChunkedLocalAttentionSpec.
            use_eagle: Whether to use eagle; must be False.
            dcp_world_size: Must be 1.

        Returns:
            A tuple with one list of blocks per kv cache group.
        """
        assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), (
            "ChunkedLocalAttentionManager can only be used for "
            "chunked local attention groups"
        )
        assert use_eagle is False, (
            "Hybrid KV cache is not supported for eagle + chunked local attention."
        )
        assert dcp_world_size == 1, "DCP not support chunked local attn now."

        block_size = kv_cache_spec.block_size
        chunk_size = kv_cache_spec.attention_chunk_size
        max_num_blocks = max_length // block_size

        # Token index at which the chunk containing the next token starts.
        window_start_token = (
            (max_length // chunk_size) * chunk_size if max_length > 0 else 0
        )
        window_start_block = window_start_token // block_size

        # Layout of each per-group list:
        #   [null] ... [null] [hit block 1] [hit block 2] ... [hit block x]
        # where "hit block 1" is the first block of the last window.
        computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
            [block_pool.null_block] * window_start_block
            for _ in kv_cache_group_ids
        )

        for block_idx in range(window_start_block, max_num_blocks):
            hit = block_pool.get_cached_block(
                block_hashes[block_idx], kv_cache_group_ids
            )
            if not hit:
                # First miss ends the contiguous hit prefix.
                break
            for group_blocks, group_hit in zip(computed_blocks, hit):
                group_blocks.append(group_hit)

        return computed_blocks

    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
        """
        Return the number of tokens skipped for attention computation,
        i.e. the tokens to the left of the chunk holding the next token.

        Example 1:
        chunk size = 8, num_computed_tokens = 13
        Tokens:  [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
                 | ----- computed ----------------|
                                                  ^^ next token
                                   |--------------| <-- attention window
                 |--- skipped -----|
        Output: get_num_skipped_tokens(13) == 8

        Example 2:
        chunk size = 8, num_computed_tokens = 8
        Tokens:  [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
                 | --- computed ---|
                                     ^ next token
                                   |--| <-- attention window
                 | --- skipped ----|
        Output: get_num_skipped_tokens(8) == 8

        Example 3:
        chunk size = 8, num_computed_tokens = 7
        Tokens:  [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
                 |---computed---|
                                 ^ next token
                 |-----------------| <-- attention window
        No token should be skipped.
        Output: get_num_skipped_tokens(7) == 0

        Args:
            num_computed_tokens: The number of tokens that have been computed.

        Returns:
            `num_computed_tokens` rounded down to a multiple of the
            attention chunk size.
        """
        tokens_into_chunk = num_computed_tokens % self.attention_chunk_size
        return num_computed_tokens - tokens_into_chunk

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """
        Always return 0: cascade attention is not supported by chunked
        local attention.
        """
        return 0
|
||||
|
||||
|
||||
class MambaManager(SingleTypeKVCacheManager):
    """KV cache manager for mamba (`MambaSpec`) groups."""

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
        dcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        """Find the right-most cached block within `max_length` tokens.

        Block hashes are probed from the last full block backwards. On the
        first hit at index `i`, every per-group list is filled with `i`
        null blocks followed by the cached block. On no hit the lists stay
        empty.

        Args:
            block_hashes: The block hashes of the request.
            max_length: The maximum length of the cache hit prefix.
            kv_cache_group_ids: The ids of the kv cache groups.
            block_pool: The block pool.
            kv_cache_spec: The kv cache spec; must be a MambaSpec.
            use_eagle: Not read by this implementation.
            dcp_world_size: Must be 1.

        Returns:
            A tuple with one list of blocks per kv cache group.
        """
        assert isinstance(kv_cache_spec, MambaSpec), (
            "MambaManager can only be used for mamba groups"
        )
        assert dcp_world_size == 1, "DCP not support mamba now."

        computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
            [] for _ in kv_cache_group_ids
        )
        max_num_blocks = max_length // kv_cache_spec.block_size

        # Walk right to left; only the last match is needed.
        for block_idx in reversed(range(max_num_blocks)):
            hit = block_pool.get_cached_block(
                block_hashes[block_idx], kv_cache_group_ids
            )
            if not hit:
                continue
            for group_blocks, group_hit in zip(computed_blocks, hit):
                # The hit length logic later assumes:
                #   hit_length = len(hit_blocks_other_attn[0])
                #                * self.other_block_size
                # so dummy (null) blocks are inserted in front of the hit.
                group_blocks.extend([block_pool.null_block] * block_idx)
                group_blocks.append(group_hit)
            break  # early stop on the first (right-most) match

        return computed_blocks

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """
        Always return 0: cascade attention is not supported by mamba.
        """
        return 0

    def get_num_blocks_to_allocate(
        self,
        request_id: str,
        num_tokens: int,
        new_computed_blocks: Sequence[KVCacheBlock],
    ) -> int:
        """Delegate to the base class with room for speculative blocks.

        When the spec asks for `num_speculative_blocks` extra blocks for
        speculative decoding (MTP/EAGLE) with linear attention,
        `num_tokens` is raised by that many blocks' worth of tokens first.
        """
        spec = self.kv_cache_spec
        assert isinstance(spec, MambaSpec)
        if spec.num_speculative_blocks > 0:
            num_tokens += spec.block_size * spec.num_speculative_blocks
        return super().get_num_blocks_to_allocate(
            request_id, num_tokens, new_computed_blocks
        )

    def allocate_new_blocks(
        self, request_id: str, num_tokens: int
    ) -> list[KVCacheBlock]:
        """Delegate to the base class with room for speculative blocks.

        Applies the same `num_speculative_blocks` padding of `num_tokens`
        as `get_num_blocks_to_allocate` (speculative decoding, MTP/EAGLE,
        with linear attention).
        """
        spec = self.kv_cache_spec
        assert isinstance(spec, MambaSpec)
        if spec.num_speculative_blocks > 0:
            num_tokens += spec.block_size * spec.num_speculative_blocks
        return super().allocate_new_blocks(request_id, num_tokens)
|
||||
|
||||
|
||||
class CrossAttentionManager(SingleTypeKVCacheManager):
    """Manager for cross-attention KV cache in encoder-decoder models."""

    def save_new_computed_blocks(
        self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
    ) -> None:
        """Check that no computed blocks are handed in.

        Cross-attention blocks are not cached for sharing between
        requests, so `new_computed_blocks` should always be empty.
        """
        assert len(new_computed_blocks) == 0

    def cache_blocks(self, request: Request, num_tokens: int) -> None:
        """Always raise: caching is not relevant for cross-attention.

        Raises:
            ValueError: unconditionally, because blocks are not cached for
                sharing between requests.
        """
        raise ValueError("Should not be called as prefix caching is disabled.")

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """Always return 0.

        Cross-attention blocks contain request-specific encoder states and
        are not shared between different requests.
        """
        return 0

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: list[BlockHash],
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
        dcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        """Reject cache lookups for cross-attention groups.

        Raises:
            NotImplementedError: always, once the spec type check passes.
        """
        assert isinstance(kv_cache_spec, CrossAttentionSpec), (
            "CrossAttentionManager can only be used for cross-attention groups"
        )
        # Cross-attention does not benefit from prefix caching since:
        # 1. Encoder states are unique per request (different audio/image
        #    inputs)
        # 2. Encoder states are computed once per request, not incrementally
        # 3. No reusable prefix exists between different multimodal inputs
        # The lookup is therefore rejected outright rather than answered
        # with an empty result.
        raise NotImplementedError("CrossAttentionManager does not support caching")
|
||||
|
||||
|
||||
# Registry from KV cache spec class to the manager class that handles it.
# `get_manager_for_kv_cache_spec` looks entries up by exact spec type.
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
    # Full attention and MLA attention share one manager.
    FullAttentionSpec: FullAttentionManager,
    MLAAttentionSpec: FullAttentionManager,
    SlidingWindowSpec: SlidingWindowManager,
    ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
    MambaSpec: MambaManager,
    CrossAttentionSpec: CrossAttentionManager,
}
|
||||
|
||||
|
||||
def get_manager_for_kv_cache_spec(
    kv_cache_spec: KVCacheSpec, **kwargs
) -> SingleTypeKVCacheManager:
    """Build the single-type KV cache manager registered for a spec.

    The manager class is looked up in `spec_manager_map` by the exact type
    of `kv_cache_spec`. A subclass of a registered spec is not matched
    unless it has its own entry.

    Args:
        kv_cache_spec: The spec to build a manager for.
        **kwargs: Forwarded unchanged to the manager constructor.

    Returns:
        A new manager instance constructed with `kv_cache_spec` and
        `kwargs`.

    Raises:
        KeyError: if no manager is registered for the spec's type. The
            message names the offending type and the supported ones.
    """
    spec_type = type(kv_cache_spec)
    try:
        manager_class = spec_manager_map[spec_type]
    except KeyError:
        supported = ", ".join(known.__name__ for known in spec_manager_map)
        # Same exception type as the bare lookup, but with a message that
        # says what was requested and what is available.
        raise KeyError(
            "No KV cache manager is registered for spec type "
            f"{spec_type.__name__!r}; supported spec types: {supported}"
        ) from None
    return manager_class(kv_cache_spec, **kwargs)
|
||||
148
v1/cudagraph_dispatcher.py
Normal file
148
v1/cudagraph_dispatcher.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from itertools import product
|
||||
|
||||
from vllm.config import CUDAGraphMode, VllmConfig
|
||||
from vllm.forward_context import BatchDescriptor
|
||||
|
||||
|
||||
class CudagraphDispatcher:
    """
    Runtime dispatcher that picks which set of cudagraphs a batch may use.

    Two sets of dispatch keys are kept, one for the PIECEWISE and one for
    the FULL cudagraph runtime mode. They are filled in depending on
    attention support and on the cudagraph mode set in CompilationConfig,
    and they are the only source of truth for which cudagraphs can be
    dispatched at runtime.

    At runtime, `dispatch` turns an input key into a runtime cudagraph mode
    (FULL, PIECEWISE, or NONE for no cudagraph) and a valid key (batch
    descriptor). After dispatching (communicated via forward context), the
    cudagraph wrappers trust the dispatch key to either capture or replay
    (if the mode matches), or pass through to the underlying runnable
    without cudagraph (if the mode does not match or the mode is NONE).
    """

    def __init__(self, vllm_config: VllmConfig):
        """Create a dispatcher with empty key sets.

        Args:
            vllm_config: Configuration whose `compilation_config` supplies
                the cudagraph mode and capture settings.
        """
        self.vllm_config = vllm_config
        self.compilation_config = vllm_config.compilation_config
        self.cudagraph_mode = self.compilation_config.cudagraph_mode

        # Valid cudagraph dispatching keys, per runtime mode.
        self.cudagraph_keys: dict[CUDAGraphMode, set[BatchDescriptor]] = {
            CUDAGraphMode.PIECEWISE: set(),
            CUDAGraphMode.FULL: set(),
        }

        # Piecewise cudagraphs are only valid when attention is compiled
        # piecewise; the second check runs only if the first demands it.
        needs_piecewise = self.cudagraph_mode.requires_piecewise_compilation()
        assert (
            not needs_piecewise
            or self.compilation_config.is_attention_compiled_piecewise()
        ), (
            "Compilation mode should be CompilationMode.VLLM_COMPILE when "
            "cudagraph_mode piecewise cudagraphs is used, "
            "and attention should be in splitting_ops or "
            "inductor splitting should be used. "
            f"cudagraph_mode={self.cudagraph_mode}, "
            f"compilation_mode={self.compilation_config.mode}, "
            f"splitting_ops={self.compilation_config.splitting_ops}"
        )

        self.keys_initialized = False

    def add_cudagraph_key(
        self, runtime_mode: CUDAGraphMode, batch_descriptor: BatchDescriptor
    ):
        """Register `batch_descriptor` as a valid key for `runtime_mode`.

        Args:
            runtime_mode: Must be PIECEWISE or FULL.
            batch_descriptor: The key to add to that mode's set.
        """
        assert runtime_mode in (CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL), (
            f"Invalid cudagraph runtime mode for keys: {runtime_mode}"
        )
        self.cudagraph_keys[runtime_mode].add(batch_descriptor)

    def initialize_cudagraph_keys(
        self, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int
    ):
        """Populate the key sets and mark the dispatcher as initialized.

        This should be called only after the attention backend is
        initialized.

        Args:
            cudagraph_mode: Mode whose mixed/decode sub-modes decide which
                keys are created.
            uniform_decode_query_len: Query length of a uniform decode
                batch; bounds the capture sizes used for decode keys.
        """
        # LoRA activation cases to specialize the cuda graphs on.
        if not self.vllm_config.lora_config:
            lora_cases = [False]
        elif self.compilation_config.cudagraph_specialize_lora:
            lora_cases = [True, False]
        else:
            lora_cases = [True]

        # Note: all valid keys for cudagraph are created here, without any
        # guarantee that every key is used. For example, if lazy capturing
        # is allowed in a future PR, some keys may never be triggered.
        mixed_mode = cudagraph_mode.mixed_mode()
        if mixed_mode != CUDAGraphMode.NONE:
            for bs in self.compilation_config.cudagraph_capture_sizes:
                for has_lora in lora_cases:
                    self.add_cudagraph_key(
                        mixed_mode,
                        BatchDescriptor(
                            num_tokens=bs, uniform_decode=False, has_lora=has_lora
                        ),
                    )

        # If decode cudagraph mode is FULL, and we don't already have mixed
        # mode full cudagraphs, then add them here.
        wants_separate_full_decode = (
            cudagraph_mode.decode_mode() == CUDAGraphMode.FULL
            and cudagraph_mode.separate_routine()
        )
        if wants_separate_full_decode:
            max_num_tokens = (
                uniform_decode_query_len
                * self.vllm_config.scheduler_config.max_num_seqs
            )
            decode_sizes = [
                size
                for size in self.compilation_config.cudagraph_capture_sizes
                if uniform_decode_query_len <= size <= max_num_tokens
            ]
            for bs in decode_sizes:
                for has_lora in lora_cases:
                    self.add_cudagraph_key(
                        CUDAGraphMode.FULL,
                        BatchDescriptor(
                            num_tokens=bs, uniform_decode=True, has_lora=has_lora
                        ),
                    )

        self.keys_initialized = True

    def dispatch(
        self, batch_descriptor: BatchDescriptor, use_cascade_attn: bool = False
    ) -> tuple[CUDAGraphMode, BatchDescriptor | None]:
        """
        Given conditions (e.g. the batch descriptor and whether cascade
        attention is used), dispatch to a cudagraph runtime mode and the
        valid batch descriptor.

        A new batch descriptor is returned because a uniform batch might be
        dispatched to a graph that supports a more general batch (uniform
        to non-uniform).

        Returns:
            `(mode, key)`; `(CUDAGraphMode.NONE, None)` when keys are not
            initialized yet or nothing matches.
        """
        # If not initialized, just skip dispatching.
        if not self.keys_initialized:
            return CUDAGraphMode.NONE, None

        relaxed_key = batch_descriptor.non_uniform

        # A batch using cascade attention bypasses the full cudagraphs.
        if not use_cascade_attn:
            full_keys = self.cudagraph_keys[CUDAGraphMode.FULL]
            # Prefer the exact key, then fall back to the non-uniform one.
            for candidate in (batch_descriptor, relaxed_key):
                if candidate in full_keys:
                    return CUDAGraphMode.FULL, candidate

        # Also check the non-uniform key against the more "general"
        # piecewise cudagraphs.
        if relaxed_key in self.cudagraph_keys[CUDAGraphMode.PIECEWISE]:
            return CUDAGraphMode.PIECEWISE, relaxed_key

        # Finally, just return no cudagraphs.
        return CUDAGraphMode.NONE, None
|
||||
206
v1/engine/__init__.py
Normal file
206
v1/engine/__init__.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import LogprobsLists, LogprobsTensors
|
||||
from vllm.v1.serial_utils import UtilityResult
|
||||
|
||||
# Possible values of RequestOutput.finish_reason; these strings are part
# of the external API. FinishReason.__str__ indexes this tuple by enum
# value, so the order here must match the FinishReason member values.
FINISH_REASON_STRINGS = ("stop", "length", "abort")


class FinishReason(enum.IntEnum):
    """
    Reason a request finished - stop, length, or abort.

    Int rather than Str for more compact serialization.

    stop - a stop string was emitted
    length - max_tokens was consumed, or max_model_len was reached
    abort - aborted for another reason
    """

    STOP = 0
    LENGTH = 1
    ABORT = 2

    def __str__(self):
        """Return the external string form ("stop", "length", "abort")."""
        reason_text = FINISH_REASON_STRINGS[self.value]
        return reason_text
|
||||
|
||||
|
||||
class EngineCoreRequest(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """A single request as sent to the engine core.

    NOTE(review): with `array_like=True` msgspec presumably encodes fields
    by position, so field order is part of the wire format - confirm
    before reordering or inserting fields.
    """

    # Required fields (no defaults).
    request_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec] | None
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    eos_token_id: int | None
    arrival_time: float
    lora_request: LoRARequest | None
    cache_salt: str | None
    data_parallel_rank: int | None

    # Optional fields.
    prompt_embeds: torch.Tensor | None = None

    # Index of the client, used to ensure outputs are sent back to the same
    # client for this request when scaling out the front-end.
    client_index: int = 0

    # Used in the DP case to indicate which wave of requests this one is
    # expected to belong to. Covers a race condition where the request is
    # sent before a wave-finished notification is received.
    current_wave: int = 0
    priority: int = 0

    trace_headers: Mapping[str, str] | None = None
|
||||
|
||||
|
||||
class EngineCoreEventType(enum.IntEnum):
    """The type of engine core request event."""

    # Values start at 1; there is no member with value 0.
    QUEUED = 1
    SCHEDULED = 2
    PREEMPTED = 3
|
||||
|
||||
|
||||
class EngineCoreEvent(msgspec.Struct):
    """A timestamped engine core event associated with a request.

    The timestamp is monotonic and is used by the engine frontend to
    calculate intervals between engine core events. These timestamps
    should not be compared with timestamps from other processes.
    """

    type: EngineCoreEventType
    timestamp: float

    @classmethod
    def new_event(
        cls, event_type: EngineCoreEventType, timestamp: float | None = None
    ) -> "EngineCoreEvent":
        """Create an event, stamping it with the current time if needed.

        Args:
            event_type: The kind of event.
            timestamp: Timestamp to record; when None, `time.monotonic()`
                is used.

        Returns:
            The new event.
        """
        if timestamp is None:
            timestamp = time.monotonic()
        return cls(event_type, timestamp)
|
||||
|
||||
|
||||
class EngineCoreOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Output produced by the engine core for one request.

    NOTE(review): with `array_like=True` msgspec presumably encodes fields
    by position, so field order is part of the wire format - confirm
    before reordering or inserting fields.
    """

    request_id: str
    new_token_ids: list[int]

    new_logprobs: LogprobsLists | None = None
    new_prompt_logprobs_tensors: LogprobsTensors | None = None

    pooling_output: torch.Tensor | None = None

    # None while the request is still running (see `finished`).
    finish_reason: FinishReason | None = None
    stop_reason: int | str | None = None
    events: list[EngineCoreEvent] | None = None
    kv_transfer_params: dict[str, Any] | None = None

    trace_headers: Mapping[str, str] | None = None
    # The number of tokens with prefix cache hits.
    num_cached_tokens: int = 0

    # The number of NaNs in logits.
    # A value greater than 0 indicates that the output is corrupted.
    num_nans_in_logits: int = 0

    @property
    def finished(self) -> bool:
        """Whether a finish reason has been set on this output."""
        has_finish_reason = self.finish_reason is not None
        return has_finish_reason
|
||||
|
||||
|
||||
class UtilityOutput(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """Result (or failure) of a utility call, identified by `call_id`."""

    call_id: int

    # Non-None implies the call failed; `result` should then be None.
    failure_message: str | None = None
    result: UtilityResult | None = None
|
||||
|
||||
|
||||
class EngineCoreOutputs(
    msgspec.Struct,
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,
):  # type: ignore[call-arg]
    """A batch of engine core outputs plus engine-level signals."""

    # NOTE(Nick): We could consider ways to make this more compact,
    # e.g. columnwise layout

    engine_index: int = 0

    # [num_reqs]
    # NOTE(review): msgspec is documented to copy an empty-list default per
    # instance, so this `[]` should not be shared between instances -
    # confirm against the msgspec Struct docs.
    outputs: list[EngineCoreOutput] = []
    scheduler_stats: SchedulerStats | None = None
    # 0.0 means "not set"; __post_init__ then fills in the current time.
    timestamp: float = 0.0

    utility_output: UtilityOutput | None = None
    finished_requests: set[str] | None = None

    # In the DP case, signals that the current wave of requests has
    # finished and the engines are paused.
    wave_complete: int | None = None
    # In the DP case, signals that a request was received for an "old"
    # wave, so the next wave needs to be started in other engines.
    start_wave: int | None = None

    def __post_init__(self):
        """Stamp the outputs with `time.monotonic()` if no time was given."""
        timestamp_unset = self.timestamp == 0.0
        if timestamp_unset:
            self.timestamp = time.monotonic()
|
||||
|
||||
|
||||
class EngineCoreRequestType(enum.Enum):
    """
    Request types whose values are single-byte strings, so they can be sent
    over sockets without a separate encoding step.
    """

    ADD = b"\x00"
    ABORT = b"\x01"
    START_DP_WAVE = b"\x02"
    UTILITY = b"\x03"
    # Sentinel used within EngineCoreProc.
    EXECUTOR_FAILED = b"\x04"
|
||||
|
||||
|
||||
class ReconfigureDistributedRequest(msgspec.Struct):
    """New data-parallel settings carried by a reconfigure request."""

    new_data_parallel_size: int
    # NOTE(review): the rank fields presumably also accept the
    # ReconfigureRankType sentinel values - confirm against callers.
    new_data_parallel_rank: int
    new_data_parallel_rank_local: int
    new_data_parallel_master_ip: str
    new_data_parallel_master_port: int
|
||||
|
||||
|
||||
class ReconfigureRankType(enum.IntEnum):
    """
    Rank type for reconfiguring distributed request.

    Both members are negative sentinel values.
    """

    KEEP_CURRENT_RANK = -1
    SHUTDOWN_CURRENT_RANK = -2
|
||||
BIN
v1/engine/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
v1/engine/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
BIN
v1/engine/__pycache__/async_llm.cpython-312.pyc
Normal file
BIN
v1/engine/__pycache__/async_llm.cpython-312.pyc
Normal file
Binary file not shown.
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user