Sync from v0.13
This commit is contained in:
404
vllm/v1/kv_cache_interface.py
Normal file
404
vllm/v1/kv_cache_interface.py
Normal file
@@ -0,0 +1,404 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from math import prod
|
||||
|
||||
import torch
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import get_dtype_size
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class KVCacheSpec:
|
||||
"""
|
||||
A base class for specifying the KV cache format of one layer.
|
||||
"""
|
||||
|
||||
# number of tokens in a block
|
||||
block_size: int
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
"""
|
||||
The size of a page with `block_size` tokens in bytes.
|
||||
|
||||
Returns:
|
||||
The page size
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
"""
|
||||
The maximum possible memory usage of this KV cache in bytes.
|
||||
|
||||
Returns:
|
||||
The KV cache size in bytes
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def copy_with_new_block_size(self, block_size: int) -> Self:
|
||||
"""
|
||||
Create a new KVCacheSpec from self but replacing the block size.
|
||||
"""
|
||||
return replace(self, block_size=block_size)
|
||||
|
||||
@classmethod
|
||||
def merge(cls, specs: list[Self]) -> Self:
|
||||
"""
|
||||
Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
|
||||
"""
|
||||
assert all(spec == specs[0] for spec in specs[1:]), (
|
||||
"All layers in the same KV cache group must be the same."
|
||||
)
|
||||
return copy.deepcopy(specs[0])
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AttentionSpec(KVCacheSpec):
|
||||
num_kv_heads: int
|
||||
head_size: int
|
||||
dtype: torch.dtype
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
return (
|
||||
2
|
||||
* self.block_size
|
||||
* self.num_kv_heads
|
||||
* self.head_size
|
||||
* get_dtype_size(self.dtype)
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FullAttentionSpec(AttentionSpec):
|
||||
sliding_window: int | None = None
|
||||
attention_chunk_size: int | None = None
|
||||
"""
|
||||
When hybrid allocator is disabled and the model contains both full
|
||||
attention layers and sliding window attention layers, sliding
|
||||
window attention are regarded as full attention in KV cache manager
|
||||
(blocks are allocated for all tokens), while computed as sliding window
|
||||
attention in model runner.
|
||||
In this case, we use FullAttentionSpec and record the sliding window size.
|
||||
Default to None for not using sliding window attention.
|
||||
"""
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||
pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||
# Note(hc): each dcp rank only need save
|
||||
# (max_model_len//dcp_world_size) tokens locally.
|
||||
if dcp_world_size * pcp_world_size > 1:
|
||||
max_model_len = cdiv(max_model_len, dcp_world_size * pcp_world_size)
|
||||
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
||||
|
||||
@classmethod
|
||||
def merge_window_sizes(cls, window_sizes: set[int]) -> int | None:
|
||||
if len(window_sizes) == 0:
|
||||
return None
|
||||
elif len(window_sizes) == 1:
|
||||
return window_sizes.pop()
|
||||
else:
|
||||
raise ValueError(
|
||||
"All attention layers in the same KV cache group must have the "
|
||||
"same window size."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def merge(cls, specs: list[Self]) -> Self:
|
||||
"""
|
||||
Merge a list of FullAttentionSpec objects into a single
|
||||
FullAttentionSpec object.
|
||||
"""
|
||||
assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
|
||||
"All attention layers in the same KV cache group must be FullAttentionSpec."
|
||||
)
|
||||
|
||||
sliding_window = set(
|
||||
spec.sliding_window for spec in specs if spec.sliding_window is not None
|
||||
)
|
||||
attention_chunk_size = set(
|
||||
spec.attention_chunk_size
|
||||
for spec in specs
|
||||
if spec.attention_chunk_size is not None
|
||||
)
|
||||
assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), (
|
||||
"MLAAttentionSpec should be merged in MLAAttentionSpec.merge"
|
||||
)
|
||||
merged_spec = cls(
|
||||
block_size=specs[0].block_size,
|
||||
num_kv_heads=specs[0].num_kv_heads,
|
||||
head_size=specs[0].head_size,
|
||||
dtype=specs[0].dtype,
|
||||
sliding_window=cls.merge_window_sizes(sliding_window),
|
||||
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
|
||||
)
|
||||
for spec in specs:
|
||||
for f in fields(AttentionSpec):
|
||||
assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
|
||||
"All attention layers in the same KV cache group must have "
|
||||
"the same attention spec."
|
||||
)
|
||||
assert (merged_spec.sliding_window is not None) + (
|
||||
merged_spec.attention_chunk_size is not None
|
||||
) <= 1, (
|
||||
"Model with both sliding window layers and chunked local attention "
|
||||
"layers is not supported."
|
||||
)
|
||||
return merged_spec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MLAAttentionSpec(FullAttentionSpec):
|
||||
# TODO(Lucas/Chen): less hacky way to do this
|
||||
cache_dtype_str: str | None = None
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
if self.cache_dtype_str == "fp8_ds_mla":
|
||||
# See `vllm/v1/attention/backends/mla/flashmla_sparse.py`
|
||||
# for details.
|
||||
return self.block_size * 656
|
||||
return (
|
||||
self.block_size
|
||||
* self.num_kv_heads
|
||||
* self.head_size
|
||||
* get_dtype_size(self.dtype)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def merge(cls, specs: list[Self]) -> Self:
|
||||
assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), (
|
||||
"All attention layers in the same KV cache group must be MLAAttentionSpec."
|
||||
)
|
||||
cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs)
|
||||
assert len(cache_dtype_str_set) == 1, (
|
||||
"All attention layers in the same KV cache group must use the same "
|
||||
"quantization method."
|
||||
)
|
||||
return cls(
|
||||
block_size=specs[0].block_size,
|
||||
num_kv_heads=specs[0].num_kv_heads,
|
||||
head_size=specs[0].head_size,
|
||||
dtype=specs[0].dtype,
|
||||
cache_dtype_str=cache_dtype_str_set.pop(),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ChunkedLocalAttentionSpec(AttentionSpec):
|
||||
attention_chunk_size: int
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
|
||||
# During chunked prefill, we allocate KV cache for at most
|
||||
# `self.attention_chunk_size` computed tokens plus the newly scheduled
|
||||
# tokens. And we won't allocate KV cache for more than `max_model_len`
|
||||
# tokens.
|
||||
num_tokens = min(
|
||||
self.attention_chunk_size + max_num_batched_tokens, max_model_len
|
||||
)
|
||||
|
||||
return cdiv(num_tokens, self.block_size) * self.page_size_bytes
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SlidingWindowSpec(AttentionSpec):
|
||||
sliding_window: int
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
assert vllm_config.parallel_config.decode_context_parallel_size == 1, (
|
||||
"DCP not support sliding window."
|
||||
)
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
|
||||
# During chunked prefill, we allocate KV cache for the last
|
||||
# `self.sliding_window-1` computed tokens plus the newly scheduled
|
||||
# tokens. And we won't allocate KV cache for more than `max_model_len`
|
||||
# tokens.
|
||||
num_tokens = min(
|
||||
self.sliding_window - 1 + max_num_batched_tokens, max_model_len
|
||||
)
|
||||
|
||||
# +1 here because the sliding window may not start from the beginning
|
||||
# of the block. For example, if the block size is 4 and num_token
|
||||
# is 4, we need two blocks [XXCD] [EF] to store the sliding
|
||||
# window [CDEF] of 6 tokens.
|
||||
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MambaSpec(KVCacheSpec):
|
||||
shapes: tuple[tuple[int, ...], ...]
|
||||
dtypes: tuple[torch.dtype]
|
||||
page_size_padded: int | None = None
|
||||
mamba_type: str = "mamba2"
|
||||
num_speculative_blocks: int = 0
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
page_size = sum(
|
||||
prod(shape) * get_dtype_size(dtype)
|
||||
for (shape, dtype) in zip(self.shapes, self.dtypes)
|
||||
)
|
||||
if self.page_size_padded is not None:
|
||||
assert self.page_size_padded >= page_size
|
||||
return self.page_size_padded
|
||||
return page_size
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_model_len = vllm_config.model_config.max_model_len
|
||||
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class EncoderOnlyAttentionSpec(AttentionSpec):
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
# Encoder-only layers do not need KV cache
|
||||
return 0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CrossAttentionSpec(AttentionSpec):
|
||||
"""
|
||||
KV cache spec for cross-attention layers in encoder-decoder models.
|
||||
"""
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
# For cross-attention, we need to cache encoder states
|
||||
# Get encoder length (e.g., 1500 for Whisper).
|
||||
max_encoder_len = vllm_config.scheduler_config.max_num_encoder_input_tokens
|
||||
return cdiv(max_encoder_len, self.block_size) * self.page_size_bytes
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UniformTypeKVCacheSpecs(KVCacheSpec):
|
||||
"""
|
||||
A KV cache spec for multiple layers with the same type of attention. Here,
|
||||
same types means always need the same number of token slots. For example,
|
||||
sliding window attentions with different window sizes are not the same type
|
||||
and should not be merged into one UniformTypeKVCacheSpecs.
|
||||
"""
|
||||
|
||||
kv_cache_specs: dict[str, KVCacheSpec]
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
return sum(spec.page_size_bytes for spec in self.kv_cache_specs.values())
|
||||
|
||||
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
|
||||
max_num_pages = max(
|
||||
cdiv(spec.max_memory_usage_bytes(vllm_config), spec.page_size_bytes)
|
||||
for spec in self.kv_cache_specs.values()
|
||||
)
|
||||
return max_num_pages * self.page_size_bytes
|
||||
|
||||
@classmethod
|
||||
def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool:
|
||||
"""
|
||||
Whether all layers have the same type of KV cache spec.
|
||||
"""
|
||||
block_sizes = set(spec.block_size for spec in kv_cache_specs.values())
|
||||
if len(block_sizes) > 1:
|
||||
# Different block sizes, not uniform.
|
||||
return False
|
||||
one_spec = next(iter(kv_cache_specs.values()))
|
||||
if isinstance(one_spec, FullAttentionSpec):
|
||||
return all(
|
||||
isinstance(spec, FullAttentionSpec) for spec in kv_cache_specs.values()
|
||||
)
|
||||
elif isinstance(one_spec, CrossAttentionSpec):
|
||||
return all(
|
||||
isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values()
|
||||
)
|
||||
elif isinstance(one_spec, SlidingWindowSpec):
|
||||
return all(
|
||||
isinstance(spec, SlidingWindowSpec)
|
||||
and spec.sliding_window == one_spec.sliding_window
|
||||
for spec in kv_cache_specs.values()
|
||||
)
|
||||
elif isinstance(one_spec, ChunkedLocalAttentionSpec):
|
||||
return all(
|
||||
isinstance(spec, ChunkedLocalAttentionSpec)
|
||||
and spec.attention_chunk_size == one_spec.attention_chunk_size
|
||||
for spec in kv_cache_specs.values()
|
||||
)
|
||||
elif isinstance(one_spec, MambaSpec):
|
||||
return all(
|
||||
isinstance(spec, MambaSpec)
|
||||
and spec.num_speculative_blocks == one_spec.num_speculative_blocks
|
||||
for spec in kv_cache_specs.values()
|
||||
)
|
||||
else:
|
||||
# NOTE(Chen): Please add new branches for new KV cache spec types.
|
||||
raise NotImplementedError(
|
||||
f"Unsupported KV cache spec type: {type(one_spec)}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_specs(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> Self | None:
|
||||
"""
|
||||
Return a SameTypeKVCacheSpecs object if all layers have the same type
|
||||
of KV cache spec. Return None if not.
|
||||
"""
|
||||
if cls.is_uniform_type(kv_cache_specs):
|
||||
block_size = next(iter(kv_cache_specs.values())).block_size
|
||||
return cls(block_size=block_size, kv_cache_specs=kv_cache_specs)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class KVCacheTensor:
|
||||
"""
|
||||
A class for specifying how the workers should initialize the KV cache.
|
||||
"""
|
||||
|
||||
size: int # size of the KV cache tensor in bytes
|
||||
shared_by: list[str] # layer names that share the same KV cache tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class KVCacheGroupSpec:
|
||||
"""
|
||||
Represents a group of model layers that share the same KV cache block table.
|
||||
These layers are regarded as one layer in the KV cache manager.
|
||||
"""
|
||||
|
||||
# The names of model layers in this group
|
||||
layer_names: list[str]
|
||||
# The KV cache spec of this manager layer
|
||||
kv_cache_spec: KVCacheSpec
|
||||
|
||||
|
||||
@dataclass
|
||||
class KVCacheConfig:
|
||||
"""
|
||||
The KV cache configuration of a model.
|
||||
"""
|
||||
|
||||
"""The number of KV cache blocks"""
|
||||
num_blocks: int
|
||||
"""How should model runner initialize the KV cache tensors for each layer"""
|
||||
kv_cache_tensors: list[KVCacheTensor]
|
||||
"""
|
||||
The kv cache groups of the model.
|
||||
For models with only one type of attention, there is only one group that
|
||||
contains all layers.
|
||||
For models with multiple types of attention, there will be multiple groups,
|
||||
see `_get_kv_cache_config_uniform_page_size` for more details.
|
||||
"""
|
||||
kv_cache_groups: list[KVCacheGroupSpec]
|
||||
Reference in New Issue
Block a user