421 lines
17 KiB
Python
421 lines
17 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
|
||
import itertools
|
||
from collections.abc import Sequence
|
||
from dataclasses import dataclass
|
||
from typing import Literal, overload
|
||
|
||
from vllm.distributed.kv_events import KVCacheEvent
|
||
from vllm.logger import init_logger
|
||
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
|
||
from vllm.v1.core.kv_cache_utils import KVCacheBlock
|
||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||
from vllm.v1.request import Request
|
||
|
||
# Module-level logger for this file, created through vLLM's `init_logger`.
logger = init_logger(__name__)
|
||
|
||
|
||
@dataclass
|
||
class KVCacheBlocks:
|
||
"""
|
||
The allocation result of KVCacheManager, work as the interface between
|
||
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
|
||
structure from the Scheduler.
|
||
"""
|
||
|
||
blocks: tuple[Sequence[KVCacheBlock], ...]
|
||
"""
|
||
`blocks[i][j]` refers to the i-th kv_cache_group
|
||
and the j-th block of tokens.We don't use block of
|
||
tokens as the outer dimension because it assumes all
|
||
kv_cache_groups have the same number of blocks, which is true for now but
|
||
will be broken if we want to give different block_size to different
|
||
kv_cache_groups in the future.
|
||
|
||
Each single type KVCacheBlocks could be represented as:
|
||
- list[KVCacheBlock] for more than one KVCacheBlock
|
||
- an empty tuple for requests without KVCacheBlock
|
||
(a precomputed KVCacheBlocks is in KVCacheManager to avoid GC overhead)
|
||
"""
|
||
|
||
def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
|
||
"""Adds two KVCacheBlocks instances."""
|
||
return KVCacheBlocks(
|
||
tuple(
|
||
list(itertools.chain(blk1, blk2))
|
||
for blk1, blk2 in zip(self.blocks, other.blocks)
|
||
)
|
||
)
|
||
|
||
@overload
|
||
def get_block_ids(
|
||
self,
|
||
allow_none: Literal[False] = False,
|
||
) -> tuple[list[int], ...]: ...
|
||
|
||
@overload
|
||
def get_block_ids(
|
||
self,
|
||
allow_none: Literal[True] = True,
|
||
) -> tuple[list[int], ...] | None: ...
|
||
|
||
def get_block_ids(
|
||
self,
|
||
allow_none: bool = False,
|
||
) -> tuple[list[int], ...] | None:
|
||
"""
|
||
Converts the KVCacheBlocks instance to block_ids.
|
||
|
||
Returns:
|
||
tuple[list[int], ...]: A tuple of lists where:
|
||
- the outer tuple corresponds to KV cache groups
|
||
- each inner list contains the block_ids of the blocks in that
|
||
group
|
||
"""
|
||
if allow_none and all(len(group) == 0 for group in self.blocks):
|
||
return None
|
||
return tuple([blk.block_id for blk in group] for group in self.blocks)
|
||
|
||
def get_unhashed_block_ids(self) -> list[int]:
|
||
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
|
||
assert len(self.blocks) == 1, "Only one group is supported"
|
||
return [block.block_id for block in self.blocks[0] if block.block_hash is None]
|
||
|
||
def new_empty(self) -> "KVCacheBlocks":
|
||
"""
|
||
Creates a new KVCacheBlocks instance with no blocks.
|
||
"""
|
||
return KVCacheBlocks(tuple(() for _ in range(len(self.blocks))))
|
||
|
||
|
||
class KVCacheManager:
|
||
def __init__(
|
||
self,
|
||
kv_cache_config: KVCacheConfig,
|
||
max_model_len: int,
|
||
enable_caching: bool = True,
|
||
use_eagle: bool = False,
|
||
log_stats: bool = False,
|
||
enable_kv_cache_events: bool = False,
|
||
dcp_world_size: int = 1,
|
||
) -> None:
|
||
self.max_model_len = max_model_len
|
||
|
||
self.enable_caching = enable_caching
|
||
self.use_eagle = use_eagle
|
||
self.log_stats = log_stats
|
||
# FIXME: make prefix cache stats conditional on log_stats
|
||
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
||
|
||
self.block_size: int | None = None
|
||
if self.enable_caching:
|
||
assert (
|
||
len(
|
||
set(
|
||
g.kv_cache_spec.block_size
|
||
for g in kv_cache_config.kv_cache_groups
|
||
)
|
||
)
|
||
== 1
|
||
), "Only one block size is supported for now"
|
||
self.block_size = kv_cache_config.kv_cache_groups[
|
||
0
|
||
].kv_cache_spec.block_size
|
||
|
||
if dcp_world_size > 1:
|
||
assert len(kv_cache_config.kv_cache_groups) == 1
|
||
# Note(hc): need revisit. When both DCP and any future
|
||
# PCP are enabled, the block_size may need to be scaled
|
||
# by a factor of dcp_size × pcp_size?
|
||
self.block_size *= dcp_world_size
|
||
|
||
self.coordinator = get_kv_cache_coordinator(
|
||
kv_cache_config=kv_cache_config,
|
||
max_model_len=self.max_model_len,
|
||
use_eagle=self.use_eagle,
|
||
enable_caching=self.enable_caching,
|
||
enable_kv_cache_events=enable_kv_cache_events,
|
||
dcp_world_size=dcp_world_size,
|
||
)
|
||
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
||
self.block_pool = self.coordinator.block_pool
|
||
self.kv_cache_config = kv_cache_config
|
||
|
||
# Pre-constructed KVCacheBlocks with no blocks, callers should use this
|
||
# via create_kv_cache_blocks instead of creating new ones to avoid GC
|
||
# overhead.
|
||
#
|
||
# We use nested tuples to ensure the empty KVCacheBlocks is immutable.
|
||
self.empty_kv_cache_blocks = KVCacheBlocks(
|
||
tuple(() for _ in range(self.num_kv_cache_groups))
|
||
)
|
||
|
||
@property
|
||
def usage(self) -> float:
|
||
"""Get the KV cache usage.
|
||
|
||
Returns:
|
||
The KV cache usage (between 0.0 and 1.0).
|
||
"""
|
||
return self.block_pool.get_usage()
|
||
|
||
def make_prefix_cache_stats(self) -> PrefixCacheStats | None:
|
||
"""Get (and reset) the prefix cache stats.
|
||
|
||
Returns:
|
||
The current prefix caching stats, or None if logging is disabled.
|
||
"""
|
||
if not self.log_stats:
|
||
return None
|
||
stats = self.prefix_cache_stats
|
||
self.prefix_cache_stats = PrefixCacheStats()
|
||
return stats
|
||
|
||
def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]:
|
||
"""Get the computed (cached) blocks for the request.
|
||
Note that the computed blocks must be full.
|
||
|
||
Args:
|
||
request: The request to get the computed blocks.
|
||
|
||
Returns:
|
||
A tuple containing:
|
||
- A list of blocks that are computed for the request.
|
||
- The number of computed tokens.
|
||
"""
|
||
# We skip finding the prefix cache hit when prefix caching is
|
||
# disabled or the request is marked as skipping kv cache read
|
||
# (which happens when the request requires prompt logprobs
|
||
# or calls a pooling model with all pooling).
|
||
if not self.enable_caching or request.skip_reading_prefix_cache:
|
||
return self.empty_kv_cache_blocks, 0
|
||
|
||
# NOTE: When all tokens hit the cache, we must recompute the last token
|
||
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
|
||
# This can trigger recomputation of an entire block, rather than just
|
||
# the single last token, because allocate_slots() requires
|
||
# num_computed_tokens to be block-size aligned. Removing this limitation
|
||
# could slightly improve performance in the future.
|
||
max_cache_hit_length = request.num_tokens - 1
|
||
computed_blocks, num_new_computed_tokens = (
|
||
self.coordinator.find_longest_cache_hit(
|
||
request.block_hashes, max_cache_hit_length
|
||
)
|
||
)
|
||
|
||
if self.log_stats:
|
||
assert self.prefix_cache_stats is not None
|
||
self.prefix_cache_stats.record(
|
||
num_tokens=request.num_tokens,
|
||
num_hits=num_new_computed_tokens,
|
||
preempted=request.num_preemptions > 0,
|
||
)
|
||
|
||
return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens
|
||
|
||
def allocate_slots(
|
||
self,
|
||
request: Request,
|
||
num_new_tokens: int,
|
||
num_new_computed_tokens: int = 0,
|
||
new_computed_blocks: KVCacheBlocks | None = None,
|
||
num_lookahead_tokens: int = 0,
|
||
delay_cache_blocks: bool = False,
|
||
num_encoder_tokens: int = 0,
|
||
) -> KVCacheBlocks | None:
|
||
"""Add slots for a request with new tokens to append.
|
||
|
||
Args:
|
||
request: The request to allocate slots.
|
||
num_new_tokens: The number of tokens to allocate, including external
|
||
tokens. Note that this does not include tokens that have
|
||
already been computed locally (i.e. new_computed_blocks).
|
||
num_new_computed_tokens: The number of new computed tokens just
|
||
hitting the prefix caching, excluding external tokens.
|
||
new_computed_blocks: The cached blocks for the above new computed
|
||
tokens.
|
||
num_lookahead_tokens: The number of speculative tokens to allocate.
|
||
This is used by spec decode proposers with kv-cache such
|
||
as eagle.
|
||
delay_cache_blocks: Whether to skip caching the blocks. This is
|
||
used by P/D when allocating blocks used in a KV transfer
|
||
which will complete in a future step.
|
||
|
||
Blocks layout:
|
||
```
|
||
-----------------------------------------------------------------------
|
||
| < computed > | < new computed > | < new > | < pre-allocated > |
|
||
-----------------------------------------------------------------------
|
||
| < required > |
|
||
--------------------------------------------------
|
||
| < full > |
|
||
------------------------------------------------
|
||
| <new full> |
|
||
--------------
|
||
```
|
||
The following *_blocks are illustrated in this layout.
|
||
|
||
Returns:
|
||
A list of new allocated blocks.
|
||
"""
|
||
if num_new_tokens == 0:
|
||
raise ValueError("num_new_tokens must be greater than 0")
|
||
|
||
if new_computed_blocks is not None:
|
||
new_computed_block_list = new_computed_blocks.blocks
|
||
else:
|
||
new_computed_block_list = self.empty_kv_cache_blocks.blocks
|
||
|
||
# Free the blocks that are skipped during the attention computation
|
||
# (e.g., tokens outside the sliding window).
|
||
# We can do this even if we cannot schedule this request due to
|
||
# insufficient free blocks.
|
||
# Should call this function before allocating new blocks to reduce
|
||
# the number of evicted blocks.
|
||
self.coordinator.remove_skipped_blocks(
|
||
request.request_id, request.num_computed_tokens
|
||
)
|
||
|
||
# The number of computed tokens is the number of computed tokens plus
|
||
# the new prefix caching hits
|
||
num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens
|
||
num_tokens_need_slot = min(
|
||
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
|
||
self.max_model_len,
|
||
)
|
||
|
||
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
|
||
request_id=request.request_id,
|
||
num_tokens=num_tokens_need_slot,
|
||
new_computed_blocks=new_computed_block_list,
|
||
num_encoder_tokens=num_encoder_tokens,
|
||
)
|
||
|
||
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
||
# Cannot allocate new blocks
|
||
return None
|
||
|
||
# Touch the computed blocks to make sure they won't be evicted.
|
||
if self.enable_caching:
|
||
self.block_pool.touch(new_computed_block_list)
|
||
else:
|
||
assert not any(new_computed_block_list), (
|
||
"Computed blocks should be empty when prefix caching is disabled"
|
||
)
|
||
|
||
if new_computed_block_list is not self.empty_kv_cache_blocks.blocks:
|
||
# Append the new computed blocks to the request blocks until now to
|
||
# avoid the case where the new blocks cannot be allocated.
|
||
self.coordinator.save_new_computed_blocks(
|
||
request.request_id, new_computed_block_list
|
||
)
|
||
|
||
new_blocks = self.coordinator.allocate_new_blocks(
|
||
request.request_id, num_tokens_need_slot, num_encoder_tokens
|
||
)
|
||
|
||
# P/D: delay caching blocks if we have to recv from
|
||
# remote. Update state for locally cached blocks.
|
||
if not self.enable_caching or delay_cache_blocks:
|
||
return self.create_kv_cache_blocks(new_blocks)
|
||
|
||
# NOTE(woosuk): We want to commit (cache) up to num_computed_tokens +
|
||
# num_new_tokens, but must exclude "non-committable" tokens (e.g.,
|
||
# draft tokens that could be rejected). Therefore, we cap the number
|
||
# at `request.num_tokens`, ensuring only "finalized" tokens are cached.
|
||
num_tokens_to_cache = min(
|
||
num_computed_tokens + num_new_tokens, request.num_tokens
|
||
)
|
||
self.coordinator.cache_blocks(request, num_tokens_to_cache)
|
||
|
||
return self.create_kv_cache_blocks(new_blocks)
|
||
|
||
def free(self, request: Request) -> None:
|
||
"""Free the blocks allocated for the request.
|
||
We free the blocks in reverse order so that the tail blocks are evicted
|
||
first when caching is enabled.
|
||
|
||
Args:
|
||
request: The request to free the blocks.
|
||
"""
|
||
self.coordinator.free(request.request_id)
|
||
|
||
def reset_prefix_cache(self) -> bool:
|
||
"""Reset prefix cache. This function may be used in RLHF
|
||
flows to invalidate prefix caching after the weights are updated,
|
||
or used for resetting prefix caching status for benchmarking.
|
||
|
||
Returns:
|
||
bool: True if the prefix cache is successfully reset,
|
||
False otherwise.
|
||
"""
|
||
if not self.block_pool.reset_prefix_cache():
|
||
return False
|
||
if self.log_stats:
|
||
assert self.prefix_cache_stats is not None
|
||
self.prefix_cache_stats.reset = True
|
||
return True
|
||
|
||
def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
|
||
"""Calculate the number of common prefix blocks for each kv cache group.
|
||
|
||
The function selects a running request and iterates through its blocks.
|
||
A block is considered a common prefix block if ALL requests with
|
||
allocated KV cache share it (i.e., ref_cnt equals the number of entries
|
||
in req_to_blocks).
|
||
|
||
NOTE(woosuk): The number of requests with allocated KV cache is **greater
|
||
than or equal to** the number of requests scheduled in the current step.
|
||
This is because having allocated KV cache only indicates that:
|
||
1. The request has not yet finished, and
|
||
2. The request holds its blocks unfreed.
|
||
|
||
While all scheduled requests must have allocated KV cache, the inverse
|
||
is not necessarily true. There may be requests with allocated KV cache
|
||
that are not scheduled in the current step.
|
||
|
||
This can result in an edge case where the number of common prefix blocks
|
||
is 0, even though all scheduled requests share a common prefix. This
|
||
occurs because there may be unscheduled requests that do not share the
|
||
common prefix. Currently, this case cannot be easily detected, so the
|
||
function returns 0 in such cases.
|
||
|
||
Args:
|
||
running_request_id: The request ID of any running request, used to
|
||
identify the common prefix blocks.
|
||
|
||
Returns:
|
||
list[int]: The number of common prefix blocks for each kv cache
|
||
group.
|
||
"""
|
||
return self.coordinator.get_num_common_prefix_blocks(running_request_id)
|
||
|
||
def take_events(self) -> list[KVCacheEvent]:
|
||
"""Take the KV cache events from the block pool.
|
||
|
||
Returns:
|
||
A list of KV cache events.
|
||
"""
|
||
return self.block_pool.take_events()
|
||
|
||
def get_blocks(self, request_id: str) -> KVCacheBlocks:
|
||
"""Get the blocks of a request."""
|
||
return self.create_kv_cache_blocks(self.coordinator.get_blocks(request_id))
|
||
|
||
def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
|
||
"""Get the block ids of a request."""
|
||
return self.get_blocks(request_id).get_block_ids()
|
||
|
||
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
|
||
"""Cache the blocks for the request, if enabled."""
|
||
if self.enable_caching:
|
||
self.coordinator.cache_blocks(request, num_computed_tokens)
|
||
|
||
def create_kv_cache_blocks(
|
||
self, blocks: tuple[list[KVCacheBlock], ...]
|
||
) -> KVCacheBlocks:
|
||
# Only create new KVCacheBlocks for non-empty blocks
|
||
return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks
|