Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

0
vllm/v1/core/__init__.py Normal file
View File

485
vllm/v1/core/block_pool.py Normal file
View File

@@ -0,0 +1,485 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Sequence
from typing import Any
from vllm.distributed.kv_events import (
MEDIUM_GPU,
AllBlocksCleared,
BlockRemoved,
BlockStored,
KVCacheEvent,
)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
from vllm.v1.core.kv_cache_utils import (
BlockHash,
BlockHashList,
BlockHashListWithBlockSize,
BlockHashWithGroupId,
ExternalBlockHash,
FreeKVCacheBlockQueue,
KVCacheBlock,
get_block_hash,
make_block_hash_with_group_id,
maybe_convert_block_hash,
)
from vllm.v1.request import Request
logger = init_logger(__name__)
class BlockHashToBlockMap:
"""
Cache of blocks that are used for prefix caching. It caches blocks
from hash directly to a block or multiple blocks
(i.e. {block_hash: KVCacheBlocks})
- Mostly block_hash maps to a single KVCacheBlock, and KVCacheBlocks
would simply be a KVCacheBlock.
- Otherwise, KVCacheBlocks is a dict from {block_id: KVCacheBlock}
A cached block is a full block with a block hash that can be used
for prefix caching.
The cached block may be used by running requests or in the
free_block_queue that could potentially be evicted.
NOTE #1: We currently don't de-duplicate the blocks in the cache,
meaning that if a block becomes full and is cached, we don't check
if there is already an identical block in the cache. This is because
we want to make sure the allocated block IDs won't change so that
block tables are append-only.
NOTE #2: The union type is introduced in order to reduce GC costs
from the inner dict.
"""
def __init__(self):
self._cache: dict[
BlockHashWithGroupId, KVCacheBlock | dict[int, KVCacheBlock]
] = {}
def get_one_block(self, key: BlockHashWithGroupId) -> KVCacheBlock | None:
"""
Gets any block with the given block hash key.
"""
blocks = self._cache.get(key)
if blocks is not None:
if isinstance(blocks, KVCacheBlock):
return blocks
if isinstance(blocks, dict):
return next(iter(blocks.values()))
self._unexpected_blocks_type(blocks)
return None
def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None:
"""
Inserts the KVCacheBlock to the cache
"""
blocks = self._cache.get(key)
if blocks is None:
# When key is not found, attach a single block to the key
self._cache[key] = block
elif isinstance(blocks, KVCacheBlock):
# If there's a block with the same key, merge the original block
# and the new block into a dict
self._cache[key] = {blocks.block_id: blocks, block.block_id: block}
elif isinstance(blocks, dict):
# If it's already a dict, simply insert the block
blocks[block.block_id] = block
else:
self._unexpected_blocks_type(blocks)
def pop(self, key: BlockHashWithGroupId, block_id: int) -> KVCacheBlock | None:
"""
Checks if block_hash exists and pop block_id from the cache
"""
blocks = self._cache.pop(key, None)
if blocks is None:
# block_hash not found in the cache
return None
# TODO(Jialin): If key is found, block_id should always present
# in blocks. We currently keep the original behaviour for safety.
#
# Will add block_id == blocks.block_id assertion and
# use del blocks[block_id] instead as followup.
if isinstance(blocks, KVCacheBlock):
if blocks.block_id == block_id:
return blocks
# If the single block ID doesn't match, we should put the
# block back (it should happen rarely)
self._cache[key] = blocks
return None
if isinstance(blocks, dict):
# Try to pop block_id from the block dict, and if dict still
# contain blocks, put back to the cache.
block = blocks.pop(block_id, None)
if len(blocks) > 0:
self._cache[key] = blocks
return block
self._unexpected_blocks_type(blocks)
return None
def __len__(self) -> int:
return len(self._cache)
def _unexpected_blocks_type(self, blocks: Any) -> None:
raise AssertionError(f"Invalid KV cache block type {type(blocks)}")
class BlockPool:
"""BlockPool that manages KVCacheBlocks.
It provides methods to allocate, free and cache the kv cache blocks. The
free_block_queue stores the free blocks in eviction order to enable
allocation, free, and cache eviction. The cached_block_hash_to_block
maps between block hash and cached block to support finding cached blocks
by their block hash.
Args:
num_gpu_blocks: The number of blocks in the pool.
enable_caching: Whether to enable prefix caching.
hash_block_size: The block size of which the block hashes are computed.
The actual block size usually equals hash_block_size, but in cases
where different KV cache groups have different block sizes, the
actual block size can be a multiple of hash_block_size.
enable_kv_cache_events: Whether to enable kv cache events.
metrics_collector: Optional metrics collector for tracking block residency.
"""
def __init__(
self,
num_gpu_blocks: int,
enable_caching: bool,
hash_block_size: int,
enable_kv_cache_events: bool = False,
metrics_collector: KVCacheMetricsCollector | None = None,
):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching
self.hash_block_size = hash_block_size
# All kv-cache blocks.
self.blocks: list[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
]
# Free block queue that constructs and manipulates a doubly linked
# list of free blocks (including eviction candidates when caching is
# enabled).
self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)
# Cache for block lookup
self.cached_block_hash_to_block: BlockHashToBlockMap = BlockHashToBlockMap()
# To represent a placeholder block with block_id=0.
# The ref_cnt of null_block is not maintained, needs special care to
# avoid freeing it.
self.null_block = self.free_block_queue.popleft()
self.null_block.is_null = True
self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue: list[KVCacheEvent] = []
self.metrics_collector = metrics_collector
def get_cached_block(
self, block_hash: BlockHash, kv_cache_group_ids: list[int]
) -> list[KVCacheBlock] | None:
"""Get the cached block by the block hash for each group in
`kv_cache_group_ids`, or None if cache miss for any group.
If there are duplicated blocks, we return the first block in the cache.
Args:
block_hash: The hash value of the block.
kv_cache_group_ids: The ids of the KV cache groups.
Returns:
The cached blocks if exists, or None.
"""
cached_blocks = []
for group_id in kv_cache_group_ids:
block_hash_with_group_id = make_block_hash_with_group_id(
block_hash, group_id
)
block = self.cached_block_hash_to_block.get_one_block(
block_hash_with_group_id
)
if not block:
return None
cached_blocks.append(block)
return cached_blocks
def cache_full_blocks(
self,
request: Request,
blocks: list[KVCacheBlock],
num_cached_blocks: int,
num_full_blocks: int,
block_size: int,
kv_cache_group_id: int,
) -> None:
"""Cache a list of full blocks for prefix caching.
This function takes a list of blocks that will have their block hash
metadata to be updated and cached. Given a request, it updates the
metadata for each block and caching it in the
`cached_block_hash_to_block`.
The block hashes values are computed by the Request object immediately
when it is created and when new tokens are appended.
Args:
request: The request to cache the blocks.
blocks: All blocks in the request.
num_cached_blocks: The number of blocks that are already cached.
num_full_blocks: The number of blocks that are full and should
be cached after this function.
block_size: Number of tokens in each block.
kv_cache_group_id: The id of the KV cache group.
"""
if num_cached_blocks >= num_full_blocks:
return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(request.block_hashes) >= num_full_blocks
if block_size == self.hash_block_size:
# Common case.
block_hashes: BlockHashList = request.block_hashes
else:
# block_size is a multiple of hash_block_size. This happens when
# different KV cache groups have different block sizes.
assert block_size % self.hash_block_size == 0
# Recalculate block_hashes at the granularity of block_size, using
# the original block_hashes (at the granularity of hash_block_size).
block_hashes = BlockHashListWithBlockSize(
request.block_hashes, self.hash_block_size, block_size
)
new_block_hashes = block_hashes[num_cached_blocks:]
new_hashes: list[ExternalBlockHash] | None = (
[] if self.enable_kv_cache_events else None
)
for i, blk in enumerate(new_full_blocks):
assert blk.block_hash is None
block_hash = new_block_hashes[i]
# Update and added the full block to the cache.
block_hash_with_group_id = make_block_hash_with_group_id(
block_hash, kv_cache_group_id
)
blk.block_hash = block_hash_with_group_id
self.cached_block_hash_to_block.insert(block_hash_with_group_id, blk)
if new_hashes is not None:
new_hashes.append(maybe_convert_block_hash(block_hash))
if self.enable_kv_cache_events:
if num_cached_blocks == 0:
parent_block_hash: ExternalBlockHash | None = None
else:
parent_block = blocks[num_cached_blocks - 1]
assert parent_block.block_hash is not None
parent_block_hash = maybe_convert_block_hash(
get_block_hash(parent_block.block_hash)
)
self.kv_event_queue.append(
BlockStored(
block_hashes=new_hashes,
parent_block_hash=parent_block_hash,
token_ids=request.all_token_ids[
num_cached_blocks * block_size : num_full_blocks * block_size
],
block_size=block_size,
lora_id=request.lora_request.adapter_id
if request.lora_request
else None,
medium=MEDIUM_GPU,
)
)
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
"""Get new blocks from the free block pool.
Note that we do not check block cache in this function.
Args:
num_blocks: The number of blocks to allocate.
Returns:
A list of new block.
"""
if num_blocks > self.get_num_free_blocks():
raise ValueError(f"Cannot get {num_blocks} free blocks from the pool")
ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)
# In order to only iterate the list once, we duplicated code a bit
if self.enable_caching:
for block in ret:
self._maybe_evict_cached_block(block)
assert block.ref_cnt == 0
block.ref_cnt += 1
if self.metrics_collector:
self.metrics_collector.on_block_allocated(block)
else:
for block in ret:
assert block.ref_cnt == 0
block.ref_cnt += 1
if self.metrics_collector:
self.metrics_collector.on_block_allocated(block)
return ret
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
"""
If a block is cached in `cached_block_hash_to_block`, we reset its hash
metadata and evict it from the cache.
Args:
block: The block to evict.
Returns:
True if the block is evicted, False otherwise.
"""
# Clean up metrics tracking first to prevent leaks
if self.metrics_collector:
self.metrics_collector.on_block_evicted(block)
block_hash = block.block_hash
if block_hash is None:
# The block doesn't have hash, eviction is not needed
return False
if self.cached_block_hash_to_block.pop(block_hash, block.block_id) is None:
# block not found in cached_block_hash_to_block,
# eviction is not needed
return False
block.reset_hash()
if self.enable_kv_cache_events:
# FIXME (Chen): Not sure whether we should return `hash_value`
# or `(hash_value, group_id)` here. But it's fine now because
# we disable hybrid kv cache manager when kv cache event is
# enabled, so there is only one group.
self.kv_event_queue.append(
BlockRemoved(
block_hashes=[maybe_convert_block_hash(get_block_hash(block_hash))],
medium=MEDIUM_GPU,
)
)
return True
def touch(self, blocks: tuple[Sequence[KVCacheBlock], ...]) -> None:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
Args:
blocks: A list of blocks to touch.
"""
for blocks_per_group in blocks:
for block in blocks_per_group:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block)
block.ref_cnt += 1
if self.metrics_collector:
self.metrics_collector.on_block_accessed(block)
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
"""Free a list of blocks. The blocks should be ordered by their
eviction priority, where the first block will be evicted first.
Args:
ordered_blocks: A list of blocks to free ordered by their eviction
priority.
"""
# Materialize the iterable to allow multiple passes.
blocks_list = list(ordered_blocks)
for block in blocks_list:
block.ref_cnt -= 1
self.free_block_queue.append_n(
[block for block in blocks_list if block.ref_cnt == 0 and not block.is_null]
)
def evict_blocks(self, block_ids: set[int]) -> None:
"""evict blocks from the prefix cache by their block IDs.
only evicts blocks that are currently cached (have a hash). blocks
with ref_cnt > 0 are not freed from the block pool, only evicted
from the prefix cache hash table.
Args:
block_ids: Set of block IDs to evict from cache.
"""
for block_id in block_ids:
assert block_id < len(self.blocks), (
f"Invalid block_id {block_id} >= {len(self.blocks)}. "
f"This indicates a bug in the KV connector - workers should "
f"only report block IDs that were allocated by the scheduler."
)
block = self.blocks[block_id]
self._maybe_evict_cached_block(block)
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalid prefix caching after the weights are updated,
or used for resetting prefix caching status for benchmarking.
Returns:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
if num_used_blocks != 1: # The null block is always marked as used
logger.warning(
"Failed to reset prefix cache because some "
"blocks (%d) are not freed yet",
num_used_blocks - 1,
)
return False
# Remove all hashes so that no new blocks will hit.
self.cached_block_hash_to_block = BlockHashToBlockMap()
# Remove all hashes from all blocks.
for block in self.blocks:
block.reset_hash()
if self.metrics_collector:
self.metrics_collector.reset()
logger.info("Successfully reset prefix cache")
if self.enable_kv_cache_events:
self.kv_event_queue.append(AllBlocksCleared())
return True
def get_num_free_blocks(self) -> int:
"""Get the number of free blocks in the pool.
Returns:
The number of free blocks.
"""
return self.free_block_queue.num_free_blocks
def get_usage(self) -> float:
"""Get the KV cache usage.
Returns:
The KV cache usage (between 0.0 and 1.0).
"""
# Subtract 1 to account for null block.
total_gpu_blocks = self.num_gpu_blocks - 1
if not total_gpu_blocks:
return 0
return 1.0 - (self.get_num_free_blocks() / total_gpu_blocks)
def take_events(self) -> list[KVCacheEvent]:
"""Atomically takes all events and clears the queue.
Returns:
A list of KV cache events.
"""
if not self.enable_kv_cache_events:
return []
events = self.kv_event_queue
self.kv_event_queue = []
return events

View File

@@ -0,0 +1,402 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import OrderedDict
from collections.abc import Mapping
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.multimodal import MultiModalRegistry
from vllm.v1.request import Request
if TYPE_CHECKING:
from vllm.config import ModelConfig, SchedulerConfig
logger = init_logger(__name__)
class EncoderCacheManager:
"""Manages caching of encoder outputs for multimodal models in vLLM V1.
The EncoderCacheManager handles the lifecycle of multimodal encoder outputs
(such as vision embeddings from images) during request processing. It
provides memory-aware caching to avoid recomputing encoder outputs when the
same multimodal inputs appear in different stages of request processing.
This manager is particularly important for:
- Vision-language models (e.g., LLaVA) where image encoder outputs are
cached
- Any multimodal model where encoder computation is expensive and
cacheable
The cache operates at the granularity of individual multimodal input items
within requests, allowing for fine-grained memory management and enabling
chunked processing of multimodal inputs.
Cache is enabled to share embeddings of same multimodal data
item (identified by their hash value) between different requests,
and eviction takes place at allocation time when there's no free
space for new embeddings.
Oldest cached embeddings with no request referenced will be first evicted.
NOTE: The EncoderCacheManager operates on the level of multimodal embeddings
instead of encoder tokens (i.e. all tokens that represent the multimodal data
in the input sequence). This means all break/text tokens in-between multimodal
embeddings are not considered with respect to the cache size and the number
of free slots.
Args:
cache_size: Limit the size of the cache, measured by the number of
encoder embeddings from the input sequence.
Attributes:
cache_size: Total cache capacity in encoder embeddings.
num_free_slots: Current available cache capacity in encoder embeddings.
num_freeable_slots: Capacity that can be immediately reclaimed by
evicting entries with zero references (in encoder embeddings).
cached: Mapping from mm_hash to a set of request IDs that currently
reference the cached entry. If the set is empty, the entry exists
but is not referenced by any request and is eligible for
reclamation.
freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries
whose no current running request is needed and that can be freed to
make space when needed.
freed: List of mm_hash strings that were actually evicted since the
last call to get_freed_mm_hashes(). This list is cleared on return.
"""
def __init__(self, cache_size: int):
self.cache_size = cache_size
self.num_free_slots = cache_size
self.num_freeable_slots = cache_size
# mm_hash of mm_data => ids of requests that reference the mm_data
self.cached: dict[str, set[str]] = {}
# mm_hash of mm_data => num_encoder_embeds of the mm_data
self.freeable: OrderedDict[str, int] = OrderedDict()
self.freed: list[str] = []
def check_and_update_cache(self, request: Request, input_id: int) -> bool:
"""Check if encoder output for a specific multimodal input is cached.
If the encoder output is cached, update `cached` to add the request id
to the set of request ids that reference the cached encoder output.
If the encoder output was previously not referenced by any request,
update `freeable` and `num_freeable_slots` accordingly.
Args:
request: The request containing the multimodal input
input_id: Index of the multimodal input within the request
Returns:
True if the encoder output for this input is already cached
"""
mm_hash = request.mm_features[input_id].identifier
# Not cached at all
if mm_hash not in self.cached:
return False
# Cached but currently not referenced by any request
if not self.cached[mm_hash]:
num_encoder_embeds = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_encoder_embeds
self.cached[mm_hash].add(request.request_id)
return True
def can_allocate(
self,
request: Request,
input_id: int,
encoder_compute_budget: int,
num_embeds_to_schedule: int,
) -> bool:
"""Check if there's sufficient cache space for a multimodal input.
If there is, return True and update EncoderCacheManager state.
If there is not enough free space in `num_free_slots` but there is
enough reclaimable space in `num_freeable_slots`, entries will be
evicted from `freeable` (their mm_hash appended to `freed`) until
enough space is available, and then this method returns True.
Older entries are evicted first.
Returns False only if the requested number of tokens exceeds both
the free and reclaimable capacities combined.
Args:
request: The request containing the multimodal input.
input_id: Index of the multimodal input within the request.
encoder_compute_budget: Number of encoder embeddings allowed to be
computed when this method is invoked.
num_embeds_to_schedule: Number of encoder embeddings already scheduled to be
allocated with cache space when this method is invoked.
Returns:
True if there's enough capacity to hold the encoder output for this
input (possibly after reclaiming `freeable` entries); otherwise
False.
Note: This method does not allocate physical memory for the encoder
output but only the state of EncoderCacheManager.
"""
num_embeds = request.get_num_encoder_embeds(input_id)
# Not enough compute budget
if num_embeds > encoder_compute_budget:
return False
num_embeds += num_embeds_to_schedule
# Enough free slots
if num_embeds <= self.num_free_slots:
return True
# Not enough reclaimable slots
if num_embeds > self.num_freeable_slots:
return False
# Not enough free slots but enough reclaimable slots
# NOTE: Eviction takes place here, but physical memory is not freed
# until model runner is notified by the scheduler output.
while num_embeds > self.num_free_slots:
mm_hash, num_free_embeds = self.freeable.popitem(last=False)
del self.cached[mm_hash]
self.freed.append(mm_hash)
self.num_free_slots += num_free_embeds
return True
def allocate(self, request: Request, input_id: int) -> None:
"""Allocate cache space for a multimodal input's encoder output.
This reserves cache space for storing the encoder output of the
specified multimodal input. The actual encoder output storage happens in
the model runner; this method updates the manager's bookkeeping.
Note:
This method assumes can_allocate() returned True for the same input.
"""
mm_hash = request.mm_features[input_id].identifier
request_id = request.request_id
if mm_hash not in self.cached:
self.cached[mm_hash] = set()
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
# NOTE: Encoder cache should always have enough space for encoder inputs
# that are scheduled since eviction takes place at can_allocate().
assert self.num_free_slots >= num_encoder_embeds
assert self.num_freeable_slots >= num_encoder_embeds
self.cached[mm_hash].add(request_id)
self.num_free_slots -= num_encoder_embeds
self.num_freeable_slots -= num_encoder_embeds
def get_cached_input_ids(self, request: Request) -> set[int]:
"""Get all cached multimodal input IDs for a request.
Returns the set of input IDs whose `mm_hash` exists in the cache map.
This includes entries that are currently unreferenced (and thus present
in `freeable`); for such entries, freeing for this request will be a
no-op.
"""
return {
input_id
for input_id in range(len(request.mm_features))
if request.mm_features[input_id].identifier in self.cached
}
def free_encoder_input(self, request: Request, input_id: int) -> None:
"""Free the request's reference to the encoder input (`mm_data`)
When the reference set for the corresponding `mm_hash` becomes empty,
the entry is appended to `freeable` and `num_freeable_slots` is
increased by the number of encoder embeddings for that input.
The entry is NOT physically freed until capacity is needed (e.g., by
`can_allocate`).
"""
req_id = request.request_id
mm_hash = request.mm_features[input_id].identifier
# The mm_hash not in cache or the req_id set is empty
if not self.cached.get(mm_hash, None):
return
self.cached[mm_hash].discard(req_id)
if not self.cached[mm_hash]:
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.freeable[mm_hash] = num_encoder_embeds
self.num_freeable_slots += num_encoder_embeds
def free(self, request: Request) -> None:
"""Free all encoder input cache reference held by *request*.
For each cached input ID, `free_encoder_input` is invoked.
The data stays in memory until eviction is triggered by a future
attempt allocation called by 'can_allocate'.
Typically called when a request is finished, cancelled, or aborted.
"""
input_ids = self.get_cached_input_ids(request).copy()
for input_id in input_ids:
self.free_encoder_input(request, input_id)
def get_freed_mm_hashes(self) -> list[str]:
"""Get and clear the list of recently freed encoder cache entries.
Returns:
List of mm_hash strings that were actually evicted since the last
call to be used by the scheduler to notify workers about which
encoder outputs can be removed from their caches. The internal
list is cleared after this call.
"""
freed = self.freed
self.freed = []
return freed
def compute_encoder_budget(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations.
Returns:
- Compute budget for encoder execution, measured in number of tokens
from the input sequence.
- Space budget for encoder cache size, measured in number of tokens
from the input sequence.
"""
if mm_registry.supports_multimodal_inputs(model_config):
max_tokens_by_modality = mm_registry.get_max_tokens_per_item_by_modality(
model_config
)
return compute_mm_encoder_budget(
scheduler_config,
max_tokens_by_modality,
)
return compute_text_encoder_budget(scheduler_config)
def compute_text_encoder_budget(scheduler_config: "SchedulerConfig") -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations for a text-only model.
Args:
scheduler_config: Scheduler configuration.
Returns:
- Compute budget for encoder execution, in unit of number of tokens
in the input sequence.
- Space budget for encoder cache size, in unit of number of tokens
in the input sequence.
"""
# Currently text-only encoder-decoder models are not supported
return 0, 0
def compute_mm_encoder_budget(
scheduler_config: "SchedulerConfig",
max_tokens_by_modality: Mapping[str, int],
) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations for a multimodal model.
Args:
scheduler_config: Scheduler configuration.
max_tokens_by_modality: The maximum number of tokens for each
non-text modality.
Returns:
- Compute budget for encoder execution, measured in number of tokens
from the input sequence.
- Space budget for encoder cache size, measured in number of tokens
from the input sequence.
"""
if not max_tokens_by_modality:
logger.warning(
"All non-text modalities supported by the model have been "
"explicitly disabled via limit_mm_per_prompt. Encoder cache will "
"not be initialized."
)
return 0, 0
max_tokens_per_mm_item = max(max_tokens_by_modality.values())
if (
scheduler_config.disable_chunked_mm_input
and max_tokens_per_mm_item > scheduler_config.max_num_batched_tokens
):
raise ValueError(
"Chunked MM input disabled but max_tokens_per_mm_item "
f"({max_tokens_per_mm_item}) is larger than max_num_batched_tokens"
f" ({scheduler_config.max_num_batched_tokens}). Please increase "
"max_num_batched_tokens."
)
encoder_compute_budget = max(
scheduler_config.max_num_encoder_input_tokens, max_tokens_per_mm_item
)
encoder_cache_size = max(
scheduler_config.encoder_cache_size, max_tokens_per_mm_item
)
return encoder_compute_budget, encoder_cache_size
# NOTE (NickLucche): Temporary implementation for encoder-decoder models that only
# use the manager for scheduling purposes. Encoder-decoder models will eventually
# utilize the cache and this class will fold into EncoderCacheManager, as
# differences with MM models shrink.
class EncoderDecoderCacheManager(EncoderCacheManager):
def __init__(self, cache_size: int):
self.cache_size = cache_size
self.num_free_slots = cache_size
self.freed: list[str] = []
def check_and_update_cache(self, request: Request, input_id: int) -> bool:
return False
def can_allocate(
self,
request: Request,
input_id: int,
encoder_compute_budget: int,
num_embeds_to_schedule: int,
) -> bool:
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
# Not enough compute budget
if num_encoder_embeds > encoder_compute_budget:
return False
num_encoder_embeds += num_embeds_to_schedule
# Enough free slots
return num_encoder_embeds <= self.num_free_slots
def allocate(self, request: Request, input_id: int) -> None:
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.num_free_slots -= num_encoder_embeds
mm_hash = request.mm_features[input_id].identifier
self.freed.append(mm_hash)
def free(self, request: Request) -> None:
for input_id in range(len(request.mm_features)):
self.free_encoder_input(request, input_id)
def get_cached_input_ids(self, request: Request) -> set[int]:
return set(range(len(request.mm_features)))
def get_freed_mm_hashes(self) -> list[str]:
freed = self.freed
self.freed = []
return freed
def free_encoder_input(self, request: Request, input_id: int) -> None:
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.num_free_slots += num_encoder_embeds

View File

@@ -0,0 +1,570 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
from math import lcm
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
from vllm.v1.core.kv_cache_utils import (
BlockHash,
BlockHashList,
BlockHashListWithBlockSize,
KVCacheBlock,
)
from vllm.v1.core.single_type_kv_cache_manager import (
CrossAttentionManager,
FullAttentionManager,
get_manager_for_kv_cache_spec,
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheConfig,
KVCacheSpec,
)
from vllm.v1.request import Request
class KVCacheCoordinator(ABC):
"""
Coordinate the KV cache of different KV cache groups.
"""
def __init__(
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
self.enable_caching = enable_caching
self.block_pool = BlockPool(
kv_cache_config.num_blocks,
enable_caching,
hash_block_size,
enable_kv_cache_events,
metrics_collector,
)
# Needs special handling for find_longest_cache_hit if eagle is enabled
self.use_eagle = use_eagle
self.single_type_managers = tuple(
get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
)
for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups)
)
def get_num_blocks_to_allocate(
self,
request_id: str,
num_tokens: int,
new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
num_encoder_tokens: int,
) -> int:
"""
Get the number of blocks needed to be allocated for the request.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.
Returns:
The number of blocks.
"""
num_blocks_to_allocate = 0
for i, manager in enumerate(self.single_type_managers):
if isinstance(manager, CrossAttentionManager):
# For cross-attention, we issue a single static allocation
# of blocks based on the number of encoder input tokens.
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_encoder_tokens, []
)
else:
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks[i]
)
return num_blocks_to_allocate
def save_new_computed_blocks(
self, request_id: str, new_computed_blocks: tuple[Sequence[KVCacheBlock], ...]
) -> None:
"""
Add the new computed blocks to the request.
Args:
request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the
prefix cache.
"""
for i, manager in enumerate(self.single_type_managers):
manager.save_new_computed_blocks(request_id, new_computed_blocks[i])
def allocate_new_blocks(
self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0
) -> tuple[list[KVCacheBlock], ...]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
num_encoder_tokens: The number of encoder tokens for allocating
blocks for cross-attention.
Returns:
The new allocated blocks.
"""
return tuple(
manager.allocate_new_blocks(
request_id,
num_encoder_tokens
if isinstance(manager, CrossAttentionManager)
else num_tokens,
)
for manager in self.single_type_managers
)
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
num_computed_tokens: The total number of tokens
that need to be cached
(including tokens that are already cached).
"""
for manager in self.single_type_managers:
manager.cache_blocks(request, num_computed_tokens)
def free(self, request_id: str) -> None:
"""
Free the blocks for the request.
Args:
request_id: The request ID.
"""
for manager in self.single_type_managers:
manager.free(request_id)
def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
"""
Get the number of common prefix blocks for all requests with allocated
KV cache for each kv cache group.
Args:
running_request_id: The request ID of any running request, used to
identify the common prefix blocks.
Returns:
list[int]: The number of common prefix blocks for each kv cache group.
"""
return [
manager.get_num_common_prefix_blocks(running_request_id)
for manager in self.single_type_managers
]
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
"""
Remove the blocks that are no longer needed from `blocks` and replace
the removed blocks with null_block.
Args:
request_id: The request ID.
num_computed_tokens: The number of tokens that have been computed.
"""
for manager in self.single_type_managers:
manager.remove_skipped_blocks(request_id, num_computed_tokens)
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
"""
Get the blocks for the request.
"""
return tuple(
manager.req_to_blocks.get(request_id) or []
for manager in self.single_type_managers
)
@abstractmethod
def find_longest_cache_hit(
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
pass
class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
"""
KV cache coordinator to use if prefix caching is disabled or unsupported.
In contrast to UnitaryKVCacheCoordinator and HybridKVCacheCoordinator,
supports arbitrary numbers of KV cache groups (including 0 groups).
Does not implement any features related to prefix caching.
"""
def __init__(
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
use_eagle: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
):
super().__init__(
kv_cache_config,
max_model_len,
use_eagle,
False,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
self.num_single_type_manager = len(self.single_type_managers)
def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
return [0] * self.num_single_type_manager
def find_longest_cache_hit(
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(self.num_single_type_manager)
)
return blocks, 0
class UnitaryKVCacheCoordinator(KVCacheCoordinator):
"""
KV cache coordinator for models with only one KV cache group. This is the
case for models with only one KV cache type, e.g., all attention layers use
full attention or all attention layers use sliding window attention.
"""
def __init__(
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
):
super().__init__(
kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size
self.pcp_world_size = pcp_world_size
if dcp_world_size > 1:
self.block_size *= dcp_world_size
if pcp_world_size > 1:
self.block_size *= pcp_world_size
# For models using only Mamba, block_size is set to max_model_len when
# prefix caching is disabled, and hash_block_size validation is skipped.
assert not enable_caching or (hash_block_size == self.block_size), (
"UnitaryKVCacheCoordinator assumes hash_block_size == block_size"
)
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group"
)
def find_longest_cache_hit(
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
block_hashes=block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=[0],
block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle,
alignment_tokens=self.block_size,
dcp_world_size=self.dcp_world_size,
pcp_world_size=self.pcp_world_size,
)
return hit_blocks, len(hit_blocks[0]) * self.block_size
class HybridKVCacheCoordinator(KVCacheCoordinator):
"""
KV cache coordinator for hybrid models with multiple KV cache types, and
thus multiple kv cache groups.
To simplify `find_longest_cache_hit`, it only supports the combination of
two types of KV cache groups, and one of them must be full attention.
May extend to more general cases in the future.
"""
def __init__(
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
):
super().__init__(
kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
# hash_block_size: the block size used to compute block hashes.
# The actual block size usually equals hash_block_size, but in cases where
# different KV cache groups have different block sizes, the actual block size
# can be a multiple of hash_block_size.
self.hash_block_size = hash_block_size
assert all(
g.kv_cache_spec.block_size % hash_block_size == 0
for g in kv_cache_config.kv_cache_groups
), "block_size must be divisible by hash_block_size"
assert dcp_world_size == 1, "DCP not support hybrid attn now."
assert pcp_world_size == 1, "PCP not support hybrid attn now."
self.verify_and_split_kv_cache_groups()
def verify_and_split_kv_cache_groups(self) -> None:
"""
Verifies that the model has exactly two types of KV cache groups, and
one of them is full attention. Then, split the kv cache groups into full
attention groups and other groups.
"""
full_attention_spec: FullAttentionSpec | None = None
other_spec: KVCacheSpec | None = None
self.full_attention_group_ids: list[int] = []
self.other_group_ids: list[int] = []
for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
if isinstance(g.kv_cache_spec, FullAttentionSpec):
if full_attention_spec is None:
full_attention_spec = g.kv_cache_spec
else:
assert full_attention_spec == g.kv_cache_spec, (
"HybridKVCacheCoordinator assumes exactly one type of "
"full attention groups now."
)
self.full_attention_group_ids.append(i)
else:
if other_spec is None:
other_spec = g.kv_cache_spec
else:
assert other_spec == g.kv_cache_spec, (
"HybridKVCacheCoordinator assumes "
"exactly one other type of groups now."
)
self.other_group_ids.append(i)
assert full_attention_spec is not None, (
"HybridKVCacheCoordinator assumes exactly one type of full "
"attention groups now."
)
assert other_spec is not None, (
"HybridKVCacheCoordinator assumes exactly one type of other groups now."
)
self.full_attention_manager_cls = FullAttentionManager
self.other_attention_cls = self.single_type_managers[
self.other_group_ids[0]
].__class__
self.full_attention_spec = full_attention_spec
self.other_spec = other_spec
self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size
# The LCM of the block sizes of full attention and other attention.
# The cache hit length must be a multiple of the LCM of the block sizes
# to make sure the cache hit length is a multiple of the block size of
# each attention type. Requiring this because we don't support partial
# block cache hit yet.
self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size)
if max(self.full_attention_group_ids) < min(self.other_group_ids):
self.full_attn_first = True
elif max(self.other_group_ids) < min(self.full_attention_group_ids):
self.full_attn_first = False
else:
raise ValueError(
"HybridKVCacheCoordinator assumes the full "
"attention group ids and other attention group ids "
"do not interleave, either full attention group ids "
"are before other attention group ids or vice versa."
"This is for simplifying merging hit_blocks_full_attn and "
"hit_blocks_other_attn to hit_blocks."
)
def find_longest_cache_hit(
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
"""
Find the longest cache hit for the request.
Args:
block_hashes: The block hashes of the request.
max_cache_hit_length: The maximum length of the cache hit.
Returns:
A tuple containing:
- A list of the cache hit blocks for each single type manager.
- The number of tokens of the longest cache hit.
"""
# First, find the longest cache hit for full attention.
if self.full_attention_spec.block_size == self.hash_block_size:
# Common case.
full_attention_block_hashes: BlockHashList = block_hashes
else:
# block_size is a multiple of hash_block_size. This happens when different
# KV cache groups have different block sizes. In this case, we need to
# recalculate block_hashes at the granularity of block_size, using the
# original block_hashes (at the granularity of hash_block_size).
full_attention_block_hashes = BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, self.full_attention_spec.block_size
)
hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
block_hashes=full_attention_block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=self.full_attention_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.full_attention_spec,
use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
)
hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size
# Next, find the cache hit for the other attention WITHIN
# the cache hit of full attention.
if self.other_spec.block_size == self.hash_block_size:
# Common case.
other_block_hashes: BlockHashList = block_hashes
else:
# Similar to the full attention case, here we need to recalculate
# block_hashes at the granularity of block_size, using the original
# block_hashes (at the granularity of hash_block_size).
other_block_hashes = BlockHashListWithBlockSize(
block_hashes, self.hash_block_size, self.other_spec.block_size
)
hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
block_hashes=other_block_hashes,
max_length=hit_length,
kv_cache_group_ids=self.other_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.other_spec,
use_eagle=self.use_eagle,
alignment_tokens=self.lcm_block_size,
)
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
# NOTE: the prefix cache hit length must be a multiple of block_size as
# we don't support partial block cache hit yet. The cache hit length
# of other attention is ensured to be a multiple of the block size of
# full attention layers in current implementation, because hit_length is
# a multiple of other attention's block size, and other attention's
# block size is a multiple of full attention's block size (verified in
# `verify_and_split_kv_cache_groups`).
assert hit_length % self.full_attention_block_size == 0
# Truncate the full attention cache hit to the length of the
# cache hit of the other attention.
for group_hit_blocks in hit_blocks_full_attn:
del group_hit_blocks[hit_length // self.full_attention_block_size :]
# Merge the hit blocks of full attention and other attention.
if self.full_attn_first:
hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
else:
hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn
return hit_blocks, hit_length
def get_kv_cache_coordinator(
kv_cache_config: KVCacheConfig,
max_model_len: int,
use_eagle: bool,
enable_caching: bool,
enable_kv_cache_events: bool,
dcp_world_size: int,
pcp_world_size: int,
hash_block_size: int,
metrics_collector: KVCacheMetricsCollector | None = None,
) -> KVCacheCoordinator:
if not enable_caching:
return KVCacheCoordinatorNoPrefixCache(
kv_cache_config,
max_model_len,
use_eagle,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(
kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)
return HybridKVCacheCoordinator(
kv_cache_config,
max_model_len,
use_eagle,
enable_caching,
enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
metrics_collector=metrics_collector,
)

View File

@@ -0,0 +1,419 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Literal, overload
from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
logger = init_logger(__name__)
@dataclass
class KVCacheBlocks:
"""
The allocation result of KVCacheManager, work as the interface between
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
structure from the Scheduler.
"""
blocks: tuple[Sequence[KVCacheBlock], ...]
"""
`blocks[i][j]` refers to the i-th kv_cache_group
and the j-th block of tokens.We don't use block of
tokens as the outer dimension because it assumes all
kv_cache_groups have the same number of blocks, which is true for now but
will be broken if we want to give different block_size to different
kv_cache_groups in the future.
Each single type KVCacheBlocks could be represented as:
- list[KVCacheBlock] for more than one KVCacheBlock
- an empty tuple for requests without KVCacheBlock
(a precomputed KVCacheBlocks is in KVCacheManager to avoid GC overhead)
"""
def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
"""Adds two KVCacheBlocks instances."""
return KVCacheBlocks(
tuple(
list(itertools.chain(blk1, blk2))
for blk1, blk2 in zip(self.blocks, other.blocks)
)
)
@overload
def get_block_ids(
self,
allow_none: Literal[False] = False,
) -> tuple[list[int], ...]: ...
@overload
def get_block_ids(
self,
allow_none: Literal[True] = True,
) -> tuple[list[int], ...] | None: ...
def get_block_ids(
self,
allow_none: bool = False,
) -> tuple[list[int], ...] | None:
"""
Converts the KVCacheBlocks instance to block_ids.
Returns:
tuple[list[int], ...]: A tuple of lists where:
- the outer tuple corresponds to KV cache groups
- each inner list contains the block_ids of the blocks in that
group
"""
if allow_none and all(len(group) == 0 for group in self.blocks):
return None
return tuple([blk.block_id for blk in group] for group in self.blocks)
def get_unhashed_block_ids(self) -> list[int]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
assert len(self.blocks) == 1, "Only one group is supported"
return [block.block_id for block in self.blocks[0] if block.block_hash is None]
def new_empty(self) -> "KVCacheBlocks":
"""
Creates a new KVCacheBlocks instance with no blocks.
"""
return KVCacheBlocks(tuple(() for _ in range(len(self.blocks))))
class KVCacheManager:
def __init__(
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
hash_block_size: int,
enable_caching: bool = True,
use_eagle: bool = False,
log_stats: bool = False,
enable_kv_cache_events: bool = False,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
metrics_collector: KVCacheMetricsCollector | None = None,
) -> None:
self.max_model_len = max_model_len
self.enable_caching = enable_caching
self.use_eagle = use_eagle
self.log_stats = log_stats
self.metrics_collector = metrics_collector
# FIXME: make prefix cache stats conditional on log_stats. We still need
# this comment because when the log stats is enabled there are still
# potential configs we could expose in the future.
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
use_eagle=self.use_eagle,
enable_caching=self.enable_caching,
enable_kv_cache_events=enable_kv_cache_events,
dcp_world_size=dcp_world_size,
pcp_world_size=pcp_world_size,
hash_block_size=hash_block_size,
metrics_collector=self.metrics_collector,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool
self.kv_cache_config = kv_cache_config
# Pre-constructed KVCacheBlocks with no blocks, callers should use this
# via create_kv_cache_blocks instead of creating new ones to avoid GC
# overhead.
#
# We use nested tuples to ensure the empty KVCacheBlocks is immutable.
self.empty_kv_cache_blocks = KVCacheBlocks(
tuple(() for _ in range(self.num_kv_cache_groups))
)
@property
def usage(self) -> float:
"""Get the KV cache usage.
Returns:
The KV cache usage (between 0.0 and 1.0).
"""
return self.block_pool.get_usage()
def make_prefix_cache_stats(self) -> PrefixCacheStats | None:
"""Get (and reset) the prefix cache stats.
Returns:
The current prefix caching stats, or None if logging is disabled.
"""
if not self.log_stats:
return None
stats = self.prefix_cache_stats
self.prefix_cache_stats = PrefixCacheStats()
return stats
def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]:
"""Get the computed (cached) blocks for the request.
Note that the computed blocks must be full.
Args:
request: The request to get the computed blocks.
Returns:
A tuple containing:
- A list of blocks that are computed for the request.
- The number of computed tokens.
"""
# We skip finding the prefix cache hit when prefix caching is
# disabled or the request is marked as skipping kv cache read
# (which happens when the request requires prompt logprobs
# or calls a pooling model with all pooling).
if not self.enable_caching or request.skip_reading_prefix_cache:
return self.empty_kv_cache_blocks, 0
# NOTE: When all tokens hit the cache, we must recompute the last token
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
# This can trigger recomputation of an entire block, rather than just
# the single last token, because allocate_slots() requires
# num_computed_tokens to be block-size aligned. Removing this limitation
# could slightly improve performance in the future.
max_cache_hit_length = request.num_tokens - 1
computed_blocks, num_new_computed_tokens = (
self.coordinator.find_longest_cache_hit(
request.block_hashes, max_cache_hit_length
)
)
if self.log_stats:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.record(
num_tokens=request.num_tokens,
num_hits=num_new_computed_tokens,
preempted=request.num_preemptions > 0,
)
return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens
def allocate_slots(
self,
request: Request,
num_new_tokens: int,
num_new_computed_tokens: int = 0,
new_computed_blocks: KVCacheBlocks | None = None,
num_lookahead_tokens: int = 0,
delay_cache_blocks: bool = False,
num_encoder_tokens: int = 0,
) -> KVCacheBlocks | None:
"""Add slots for a request with new tokens to append.
Args:
request: The request to allocate slots.
num_new_tokens: The number of tokens to allocate, including external
tokens. Note that this does not include tokens that have
already been computed locally (i.e. new_computed_blocks).
num_new_computed_tokens: The number of new computed tokens just
hitting the prefix caching, excluding external tokens.
new_computed_blocks: The cached blocks for the above new computed
tokens.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
delay_cache_blocks: Whether to skip caching the blocks. This is
used by P/D when allocating blocks used in a KV transfer
which will complete in a future step.
num_encoder_tokens: The number of encoder tokens to allocate for
cross-attention in encoder-decoder models(e.g., Whisper).
For decoder-only models, this should be 0.
Blocks layout:
```
-----------------------------------------------------------------------
| < computed > | < new computed > | < new > | < pre-allocated > |
-----------------------------------------------------------------------
| < required > |
--------------------------------------------------
| < full > |
------------------------------------------------
| <new full> |
--------------
```
The following *_blocks are illustrated in this layout.
Returns:
A list of new allocated blocks.
"""
if num_new_tokens == 0:
raise ValueError("num_new_tokens must be greater than 0")
if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
else:
new_computed_block_list = self.empty_kv_cache_blocks.blocks
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
# We can do this even if we cannot schedule this request due to
# insufficient free blocks.
# Should call this function before allocating new blocks to reduce
# the number of evicted blocks.
self.coordinator.remove_skipped_blocks(
request.request_id, request.num_computed_tokens
)
# The number of computed tokens is the number of computed tokens plus
# the new prefix caching hits
num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens
num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len,
)
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
request_id=request.request_id,
num_tokens=num_tokens_need_slot,
new_computed_blocks=new_computed_block_list,
num_encoder_tokens=num_encoder_tokens,
)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
# Cannot allocate new blocks
return None
# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self.block_pool.touch(new_computed_block_list)
else:
assert not any(new_computed_block_list), (
"Computed blocks should be empty when prefix caching is disabled"
)
if new_computed_block_list is not self.empty_kv_cache_blocks.blocks:
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
self.coordinator.save_new_computed_blocks(
request.request_id, new_computed_block_list
)
new_blocks = self.coordinator.allocate_new_blocks(
request.request_id, num_tokens_need_slot, num_encoder_tokens
)
# P/D: delay caching blocks if we have to recv from
# remote. Update state for locally cached blocks.
if not self.enable_caching or delay_cache_blocks:
return self.create_kv_cache_blocks(new_blocks)
# NOTE(woosuk): We want to commit (cache) up to num_computed_tokens +
# num_new_tokens, but must exclude "non-committable" tokens (e.g.,
# draft tokens that could be rejected). Therefore, we cap the number
# at `request.num_tokens`, ensuring only "finalized" tokens are cached.
num_tokens_to_cache = min(
num_computed_tokens + num_new_tokens, request.num_tokens
)
self.coordinator.cache_blocks(request, num_tokens_to_cache)
return self.create_kv_cache_blocks(new_blocks)
def free(self, request: Request) -> None:
"""Free the blocks allocated for the request.
We free the blocks in reverse order so that the tail blocks are evicted
first when caching is enabled.
Args:
request: The request to free the blocks.
"""
self.coordinator.free(request.request_id)
def evict_blocks(self, block_ids: set[int]) -> None:
"""evict blocks from the prefix cache by their block IDs.
Args:
block_ids: Set of block IDs to evict from cache.
"""
self.block_pool.evict_blocks(block_ids)
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalidate prefix caching after the weights are updated,
or used for resetting prefix caching status for benchmarking.
Returns:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
if not self.block_pool.reset_prefix_cache():
return False
if self.log_stats:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.reset = True
return True
def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
"""Calculate the number of common prefix blocks for each kv cache group.
The function selects a running request and iterates through its blocks.
A block is considered a common prefix block if ALL requests with
allocated KV cache share it (i.e., ref_cnt equals the number of entries
in req_to_blocks).
NOTE(woosuk): The number of requests with allocated KV cache is **greater
than or equal to** the number of requests scheduled in the current step.
This is because having allocated KV cache only indicates that:
1. The request has not yet finished, and
2. The request holds its blocks unfreed.
While all scheduled requests must have allocated KV cache, the inverse
is not necessarily true. There may be requests with allocated KV cache
that are not scheduled in the current step.
This can result in an edge case where the number of common prefix blocks
is 0, even though all scheduled requests share a common prefix. This
occurs because there may be unscheduled requests that do not share the
common prefix. Currently, this case cannot be easily detected, so the
function returns 0 in such cases.
Args:
running_request_id: The request ID of any running request, used to
identify the common prefix blocks.
Returns:
list[int]: The number of common prefix blocks for each kv cache
group.
"""
return self.coordinator.get_num_common_prefix_blocks(running_request_id)
def take_events(self) -> list[KVCacheEvent]:
"""Take the KV cache events from the block pool.
Returns:
A list of KV cache events.
"""
return self.block_pool.take_events()
def get_blocks(self, request_id: str) -> KVCacheBlocks:
"""Get the blocks of a request."""
return self.create_kv_cache_blocks(self.coordinator.get_blocks(request_id))
def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
"""Get the block ids of a request."""
return self.get_blocks(request_id).get_block_ids()
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled."""
if self.enable_caching:
self.coordinator.cache_blocks(request, num_computed_tokens)
def create_kv_cache_blocks(
self, blocks: tuple[list[KVCacheBlock], ...]
) -> KVCacheBlocks:
# Only create new KVCacheBlocks for non-empty blocks
return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks

View File

@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV cache metrics tracking."""
import random
import time
from collections import deque
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vllm.v1.core.kv_cache_utils import KVCacheBlock
from vllm.v1.metrics.stats import KVCacheEvictionEvent
class BlockMetricsState:
"""Tracks lifecycle metrics for a single KV cache block."""
def __init__(self):
now_ns = time.monotonic_ns()
self.birth_time_ns = now_ns
self.last_access_ns = now_ns
# Bounded to prevent unbounded growth if a block is accessed many times.
self.access_history: deque[int] = deque(maxlen=4)
def record_access(self) -> None:
now_ns = time.monotonic_ns()
self.last_access_ns = now_ns
self.access_history.append(now_ns)
def get_lifetime_seconds(self) -> float:
now_ns = time.monotonic_ns()
return (now_ns - self.birth_time_ns) / 1e9
def get_idle_time_seconds(self) -> float:
now_ns = time.monotonic_ns()
return (now_ns - self.last_access_ns) / 1e9
def get_reuse_gaps_seconds(self) -> list[float]:
if len(self.access_history) < 2:
return []
history = list(self.access_history)
return [(history[i] - history[i - 1]) / 1e9 for i in range(1, len(history))]
class KVCacheMetricsCollector:
"""Collects KV cache residency metrics with sampling."""
def __init__(self, sample_rate: float = 0.01):
assert 0 < sample_rate <= 1.0, (
f"sample_rate must be in (0, 1.0], got {sample_rate}"
)
self.sample_rate = sample_rate
self.block_metrics: dict[int, BlockMetricsState] = {}
self._eviction_events: list[KVCacheEvictionEvent] = []
def should_sample_block(self) -> bool:
return random.random() < self.sample_rate
def on_block_allocated(self, block: "KVCacheBlock") -> None:
if self.should_sample_block():
self.block_metrics[block.block_id] = BlockMetricsState()
def on_block_accessed(self, block: "KVCacheBlock") -> None:
metrics = self.block_metrics.get(block.block_id)
if metrics:
metrics.record_access()
def on_block_evicted(self, block: "KVCacheBlock") -> None:
metrics = self.block_metrics.pop(block.block_id, None)
if not metrics:
return
lifetime = metrics.get_lifetime_seconds()
idle_time = metrics.get_idle_time_seconds()
reuse_gaps = tuple(metrics.get_reuse_gaps_seconds())
self._eviction_events.append(
KVCacheEvictionEvent(
lifetime_seconds=lifetime,
idle_seconds=idle_time,
reuse_gaps_seconds=reuse_gaps,
)
)
def reset(self) -> None:
"""Clear all state on cache reset."""
self.block_metrics.clear()
self._eviction_events.clear()
def drain_events(self) -> list[KVCacheEvictionEvent]:
events = self._eviction_events
self._eviction_events = []
return events

File diff suppressed because it is too large Load Diff

View File

View File

@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.logger import init_logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__)
class AsyncScheduler(Scheduler):
def _update_after_schedule(
self,
scheduler_output: SchedulerOutput,
) -> None:
super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False
spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0
)
cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
if (
request.num_computed_tokens
== request.num_tokens
+ request.num_output_placeholders
+ cur_num_spec_tokens
):
# The request will generate a new token plus num_spec_tokens
# in this scheduling step.
request.num_output_placeholders += 1 + cur_num_spec_tokens
# Add placeholders for the new tokens in spec_token_ids.
# We will update the actual spec token ids in the worker process.
request.spec_token_ids = [-1] * self.num_spec_tokens
scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens
)
def _update_request_with_output(
self,
request: Request,
new_token_ids: list[int],
) -> tuple[list[int], bool]:
if request.discard_latest_async_tokens:
# If the request is force preempted in reset_prefix_cache, we
# should discard the latest async token.
request.discard_latest_async_tokens = False
return [], False
status_before_update = request.status
new_token_ids, stopped = super()._update_request_with_output(
request, new_token_ids
)
# Update the number of output placeholders.
request.num_output_placeholders -= len(new_token_ids)
assert request.num_output_placeholders >= 0
# Cache the new tokens. Preempted requests should be skipped.
if status_before_update == RequestStatus.RUNNING:
self.kv_cache_manager.cache_blocks(
request, request.num_computed_tokens - request.num_output_placeholders
)
return new_token_ids, stopped

View File

@@ -0,0 +1,189 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
class SchedulerInterface(ABC):
@abstractmethod
def __init__(
self,
vllm_config: "VllmConfig",
kv_cache_config: "KVCacheConfig",
structured_output_manager: "StructuredOutputManager",
block_size: int,
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
include_finished_set: bool = False,
log_stats: bool = False,
) -> None:
raise NotImplementedError
@abstractmethod
def schedule(self) -> "SchedulerOutput":
"""Schedule the requests to process in this scheduling step.
The scheduling decision is made at the iteration level. Each scheduling
step corresponds to a single forward pass of the model. Therefore, this
method is called repeatedly by a busy loop in the engine.
Essentially, the scheduler produces a dictionary of {req_id: num_tokens}
that specifies how many tokens to process for each request in this
scheduling step. For example, num_tokens can be as large as the number
of prompt tokens for new requests, or it can be 1 for the requests that
are auto-regressively generating new tokens one by one. Otherwise, it
can be somewhere in between in case of chunked prefills, prefix caching,
speculative decoding, etc.
Additionally, the scheduler also returns useful data about each request
or the batch as a whole. The model runner will use this information in
preparing inputs to the model.
Returns:
A SchedulerOutput object containing information about the scheduled
requests.
"""
raise NotImplementedError
@abstractmethod
def get_grammar_bitmask(
self, scheduler_output: "SchedulerOutput"
) -> "GrammarOutput | None":
raise NotImplementedError
@abstractmethod
def update_from_output(
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> dict[int, "EngineCoreOutputs"]:
"""Update the scheduler state based on the model runner output.
This method is called after the model runner has processed the scheduled
requests. The model runner output includes generated token ids, draft
token ids for next step, etc. The scheduler uses this information to
update its states, checks the finished requests, and returns the output
for each request.
Returns:
A dict of client index to EngineCoreOutputs object containing the
outputs for each request originating from that client.
"""
raise NotImplementedError
@abstractmethod
def update_draft_token_ids(
self,
draft_token_ids: "DraftTokenIds",
) -> None:
"""Update the draft token ids for the scheduled requests."""
raise NotImplementedError
@abstractmethod
def add_request(self, request: "Request") -> None:
"""Add a new request to the scheduler's internal queue.
Args:
request: The new request being added.
"""
raise NotImplementedError
@abstractmethod
def finish_requests(
self,
request_ids: str | Iterable[str],
finished_status: "RequestStatus",
) -> None:
"""Finish the requests in the scheduler's internal queue. If the request
is not in the queue, this method will do nothing.
This method is called in two cases:
1. When the request is aborted by the client.
2. When the frontend process detects a stop string of the request after
de-tokenizing its generated tokens.
Args:
request_ids: A single or a list of request IDs.
finished_status: The finished status of the given requests.
"""
raise NotImplementedError
@abstractmethod
def get_num_unfinished_requests(self) -> int:
"""Number of unfinished requests in the scheduler's internal queue."""
raise NotImplementedError
def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests in the scheduler's
internal queue."""
return self.get_num_unfinished_requests() > 0
@abstractmethod
def has_finished_requests(self) -> bool:
"""Returns True if there are finished requests that need to be cleared.
NOTE: This is different from `not self.has_unfinished_requests()`.
The scheduler maintains an internal list of the requests finished in the
previous step. This list is returned from the next call to schedule(),
to be sent to the model runner in the next step to clear cached states
for these finished requests.
This method checks if this internal list of finished requests is
non-empty. This information is useful for DP attention.
"""
raise NotImplementedError
def has_requests(self) -> bool:
"""Returns True if there are unfinished requests, or finished requests
not yet returned in SchedulerOutputs."""
return self.has_unfinished_requests() or self.has_finished_requests()
@abstractmethod
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
"""Reset the prefix cache for KV cache.
This is particularly required when the model weights are live-updated.
Args:
reset_running_requests: If True, all the running requests will be
preempted and moved to the waiting queue. Otherwise, this method
will only reset the KV prefix cache when there is no running request
taking KV cache.
"""
raise NotImplementedError
@abstractmethod
def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
raise NotImplementedError
@abstractmethod
def make_stats(self) -> Optional["SchedulerStats"]:
"""Make a SchedulerStats object for logging.
The SchedulerStats object is created for every scheduling step.
"""
raise NotImplementedError
@abstractmethod
def shutdown(self) -> None:
"""Shutdown the scheduler."""
raise NotImplementedError
def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
return None

View File

@@ -0,0 +1,230 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING
from typing_extensions import deprecated
from vllm._bc_linter import bc_linter_include
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
import torch
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalFeatureSpec
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
else:
ECConnectorMetadata = object
KVConnectorMetadata = object
LoRARequest = object
MultiModalFeatureSpec = object
PoolingParams = object
SamplingParams = object
Request = object
@bc_linter_include
@dataclass
class NewRequestData:
req_id: str
prompt_token_ids: list[int] | None
mm_features: list[MultiModalFeatureSpec]
sampling_params: SamplingParams | None
pooling_params: PoolingParams | None
block_ids: tuple[list[int], ...]
num_computed_tokens: int
lora_request: LoRARequest | None
prompt_embeds: "torch.Tensor | None" = None
# Only used for v2 model runner.
prefill_token_ids: list[int] | None = None
@classmethod
def from_request(
cls,
request: Request,
block_ids: tuple[list[int], ...],
prefill_token_ids: list[int] | None = None,
) -> "NewRequestData":
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
mm_features=request.mm_features,
sampling_params=request.sampling_params,
pooling_params=request.pooling_params,
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
prompt_embeds=request.prompt_embeds,
prefill_token_ids=prefill_token_ids,
)
def __repr__(self) -> str:
prompt_embeds_shape = (
self.prompt_embeds.shape if self.prompt_embeds is not None else None
)
return (
f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids={self.prompt_token_ids},"
f"prefill_token_ids={self.prefill_token_ids},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request},"
f"prompt_embeds_shape={prompt_embeds_shape}"
")"
)
# Version of __repr__ with the prompt data obfuscated
def anon_repr(self) -> str:
prompt_token_ids_len = (
len(self.prompt_token_ids) if self.prompt_token_ids is not None else None
)
prompt_embeds_shape = (
self.prompt_embeds.shape if self.prompt_embeds is not None else None
)
return (
f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids_len={prompt_token_ids_len},"
f"mm_features={self.mm_features},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request},"
f"prompt_embeds_shape={prompt_embeds_shape}"
")"
)
@bc_linter_include
@dataclass
class CachedRequestData:
req_ids: list[str]
# For request ids not in resumed_req_ids, new_block_ids will be appended to
# the request's block IDs. For those in the set, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
resumed_req_ids: set[str]
# NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
# When PP is not used, new_token_ids will be empty.
new_token_ids: list[list[int]]
# For requests not scheduled in the last step, propagate the token ids to the
# connector. Won't contain requests that were scheduled in the prior step.
all_token_ids: dict[str, list[int]]
new_block_ids: list[tuple[list[int], ...] | None]
num_computed_tokens: list[int]
num_output_tokens: list[int]
@property
def num_reqs(self) -> int:
return len(self.req_ids)
@cached_property
@deprecated("This will be removed in v0.14, use `resumed_req_ids` instead.")
def resumed_from_preemption(self) -> list[bool]:
return [req_id in self.resumed_req_ids for req_id in self.req_ids]
@cached_property
@deprecated("This will be removed in v0.14, use `all_token_ids` instead.")
def resumed_req_token_ids(self) -> list[list[int] | None]:
return [
self.all_token_ids[req_id] if req_id in self.resumed_req_ids else None
for req_id in self.req_ids
]
@classmethod
def make_empty(cls) -> "CachedRequestData":
return cls(
req_ids=[],
resumed_req_ids=set(),
new_token_ids=[],
all_token_ids={},
new_block_ids=[],
num_computed_tokens=[],
num_output_tokens=[],
)
@bc_linter_include
@dataclass
class SchedulerOutput:
# list of the requests that are scheduled for the first time.
# We cache the request's data in each worker process, so that we don't
# need to re-send it every scheduling step.
scheduled_new_reqs: list[NewRequestData]
# list of the requests that have been scheduled before.
# Since the request's data is already cached in the worker processes,
# we only send the diff to minimize the communication cost.
scheduled_cached_reqs: CachedRequestData
# req_id -> num_scheduled_tokens
# Number of tokens scheduled for each request.
num_scheduled_tokens: dict[str, int]
# Total number of tokens scheduled for all requests.
# Equal to sum(num_scheduled_tokens.values())
total_num_scheduled_tokens: int
# req_id -> spec_token_ids
# If a request does not have any spec decode tokens, it will not be
# included in the dictionary.
scheduled_spec_decode_tokens: dict[str, list[int]]
# req_id -> encoder input indices that need processing.
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# to process that the request's 0-th and 1-th images in the current step.
scheduled_encoder_inputs: dict[str, list[int]]
# Number of common prefix blocks for all requests in each KV cache group.
# This can be used for cascade attention.
num_common_prefix_blocks: list[int]
# Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests
# so that they can free the cached states for those requests.
finished_req_ids: set[str]
# list of mm_hash strings associated with the encoder outputs to be
# freed from the encoder cache.
free_encoder_mm_hashes: list[str]
# Request IDs that are preempted in this step.
# Only used for v2 model runner.
preempted_req_ids: set[str] | None = None
# Whether the scheduled requests have all the output tokens they
# need to perform grammar bitmask computation.
pending_structured_output_tokens: bool = False
# KV Cache Connector metadata.
kv_connector_metadata: KVConnectorMetadata | None = None
# EC Cache Connector metadata
ec_connector_metadata: ECConnectorMetadata | None = None
@classmethod
def make_empty(cls) -> "SchedulerOutput":
return cls(
scheduled_new_reqs=[],
scheduled_cached_reqs=CachedRequestData.make_empty(),
num_scheduled_tokens={},
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
)
@dataclass
class GrammarOutput:
# ids of structured output requests.
structured_output_request_ids: list[str]
# Bitmask ordered as structured_output_request_ids.
grammar_bitmask: "npt.NDArray[np.int32]"

View File

@@ -0,0 +1,217 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import heapq
from abc import ABC, abstractmethod
from collections import deque
from collections.abc import Iterable, Iterator
from enum import Enum
from vllm.v1.request import Request
class SchedulingPolicy(Enum):
"""Enum for scheduling policies."""
FCFS = "fcfs"
PRIORITY = "priority"
class RequestQueue(ABC):
"""Abstract base class for request queues."""
@abstractmethod
def add_request(self, request: Request) -> None:
"""Add a request to the queue according to the policy."""
pass
@abstractmethod
def pop_request(self) -> Request:
"""Pop a request from the queue according to the policy."""
pass
@abstractmethod
def peek_request(self) -> Request:
"""Peek at the request at the front of the queue without removing it."""
pass
@abstractmethod
def prepend_request(self, request: Request) -> None:
"""Prepend a request to the front of the queue."""
pass
@abstractmethod
def prepend_requests(self, requests: "RequestQueue") -> None:
"""Prepend all requests from another queue to the front of this
queue."""
pass
@abstractmethod
def remove_request(self, request: Request) -> None:
"""Remove a specific request from the queue."""
pass
@abstractmethod
def remove_requests(self, requests: Iterable[Request]) -> None:
"""Remove multiple specific requests from the queue."""
pass
@abstractmethod
def __bool__(self) -> bool:
"""Check if queue has any requests."""
pass
@abstractmethod
def __len__(self) -> int:
"""Get number of requests in queue."""
pass
@abstractmethod
def __iter__(self) -> Iterator[Request]:
"""Iterate over the queue according to the policy."""
pass
@abstractmethod
def __reversed__(self) -> Iterator[Request]:
"""Iterate over the queue in reverse order."""
pass
class FCFSRequestQueue(deque[Request], RequestQueue):
"""A first-come-first-served queue that supports deque operations."""
def add_request(self, request: Request) -> None:
"""Add a request to the queue according to FCFS policy."""
self.append(request)
def pop_request(self) -> Request:
"""Pop a request from the queue according to FCFS policy."""
return self.popleft()
def peek_request(self) -> Request:
"""Peek at the next request in the queue without removing it."""
if not self:
raise IndexError("peek from an empty queue")
return self[0]
def prepend_request(self, request: Request) -> None:
"""Prepend a request to the front of the queue."""
self.appendleft(request)
def prepend_requests(self, requests: RequestQueue) -> None:
"""Prepend all requests from another queue to the front of this
queue."""
self.extendleft(reversed(requests))
def remove_request(self, request: Request) -> None:
"""Remove a specific request from the queue."""
self.remove(request)
def remove_requests(self, requests: Iterable[Request]) -> None:
"""Remove multiple specific requests from the queue."""
requests_to_remove = set(requests)
filtered_requests = [req for req in self if req not in requests_to_remove]
# deque does not support in-place filtering, so we need to clear
# and extend
self.clear()
self.extend(filtered_requests)
def __bool__(self) -> bool:
"""Check if queue has any requests."""
return len(self) > 0
def __len__(self) -> int:
"""Get number of requests in queue."""
return super().__len__()
def __iter__(self) -> Iterator[Request]:
"""Iterate over the queue according to FCFS policy."""
return super().__iter__()
def __reversed__(self) -> Iterator[Request]:
"""Iterate over the queue in reverse order."""
return super().__reversed__()
class PriorityRequestQueue(RequestQueue):
"""
A priority queue that supports heap operations.
Respects the ordering defined in the Request class, where
requests with a smaller value of `priority` are processed first.
If multiple requests have the same priority, the one with the earlier
`arrival_time` is processed first.
"""
def __init__(self) -> None:
self._heap: list[Request] = []
def add_request(self, request: Request) -> None:
"""Add a request to the queue according to priority policy."""
heapq.heappush(self._heap, request)
def pop_request(self) -> Request:
"""Pop a request from the queue according to priority policy."""
if not self._heap:
raise IndexError("pop from empty heap")
return heapq.heappop(self._heap)
def peek_request(self) -> Request:
"""Peek at the next request in the queue without removing it."""
if not self._heap:
raise IndexError("peek from empty heap")
return self._heap[0]
def prepend_request(self, request: Request) -> None:
"""Add a request to the queue according to priority policy.
Note: In a priority queue, there is no concept of prepending to the
front. Requests are ordered by (priority, arrival_time)."""
self.add_request(request)
def prepend_requests(self, requests: RequestQueue) -> None:
"""Add all requests from another queue according to priority policy.
Note: In a priority queue, there is no concept of prepending to the
front. Requests are ordered by (priority, arrival_time)."""
for request in requests:
self.add_request(request)
def remove_request(self, request: Request) -> None:
"""Remove a specific request from the queue."""
self._heap.remove(request)
heapq.heapify(self._heap)
def remove_requests(self, requests: Iterable[Request]) -> None:
"""Remove multiple specific requests from the queue."""
requests_to_remove = requests if isinstance(requests, set) else set(requests)
self._heap = [r for r in self._heap if r not in requests_to_remove]
heapq.heapify(self._heap)
def __bool__(self) -> bool:
"""Check if queue has any requests."""
return bool(self._heap)
def __len__(self) -> int:
"""Get number of requests in queue."""
return len(self._heap)
def __iter__(self) -> Iterator[Request]:
"""Iterate over the queue according to priority policy."""
heap_copy = self._heap[:]
while heap_copy:
yield heapq.heappop(heap_copy)
def __reversed__(self) -> Iterator[Request]:
"""Iterate over the queue in reverse priority order."""
return reversed(list(self))
def create_request_queue(policy: SchedulingPolicy) -> RequestQueue:
"""Create request queue based on scheduling policy."""
if policy == SchedulingPolicy.PRIORITY:
return PriorityRequestQueue()
elif policy == SchedulingPolicy.FCFS:
return FCFSRequestQueue()
else:
raise ValueError(f"Unknown scheduling policy: {policy}")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
from vllm.v1.request import Request, RequestStatus
def remove_all(lst: list, items_to_remove: set) -> list:
"""Remove all items from a list that are in the items_to_remove set.
This method optimizes for the common case of removing a single item,
falling back to list comprehension for multiple items.
Args:
lst: The list to remove items from
items_to_remove: Set of items to remove
Returns:
Either the modified original list (for single item removal) or
a new list (for multiple item removal). Callers should use the
returned value.
Note:
For single item removal, this modifies the original list in-place
and returns it. For multiple items, it creates and returns a new list.
"""
if not items_to_remove:
return lst
if len(items_to_remove) == 1:
# Fast path for single item removal (most common case)
item = next(iter(items_to_remove))
with contextlib.suppress(ValueError):
lst.remove(item)
return lst
# For multiple items, use list comprehension
return [item for item in lst if item not in items_to_remove]
def check_stop(request: Request, max_model_len: int) -> bool:
assert not request.pooling_params
sampling_params = request.sampling_params
assert sampling_params is not None
if request.num_output_tokens < sampling_params.min_tokens:
return False
last_token_id = request.output_token_ids[-1]
if not sampling_params.ignore_eos and last_token_id == request.eos_token_id:
request.status = RequestStatus.FINISHED_STOPPED
return True
if last_token_id in (sampling_params.stop_token_ids or ()):
request.status = RequestStatus.FINISHED_STOPPED
request.stop_reason = last_token_id
return True
if (
request.num_tokens >= max_model_len
or request.num_output_tokens >= request.max_tokens
):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True
return False

View File

@@ -0,0 +1,801 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Sequence
from vllm.utils.math_utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock
from vllm.v1.kv_cache_interface import (
ChunkedLocalAttentionSpec,
CrossAttentionSpec,
FullAttentionSpec,
KVCacheSpec,
MambaSpec,
MLAAttentionSpec,
SlidingWindowSpec,
)
from vllm.v1.request import Request
class SingleTypeKVCacheManager(ABC):
"""
An abstract base class for a manager that handle the kv cache management
logic of one specific type of attention layer.
"""
def __init__(
self,
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
kv_cache_group_id: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> None:
"""
Initializes the SingleTypeKVCacheManager.
Args:
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
kv_cache_group_id: The id of the kv cache group of this manager.
"""
self.block_size = kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size
self.pcp_world_size = pcp_world_size
if dcp_world_size * pcp_world_size > 1:
self.block_size *= dcp_world_size * pcp_world_size
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
# is finished.
self.req_to_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list)
# {req_id: The number of cached blocks for this given request}
# This is used to track the number of cached blocks for each request.
# This is only used to track the RUNNING requests, we do not track the
# data for preempted ones.
self.num_cached_block: dict[str, int] = {}
self.kv_cache_group_id = kv_cache_group_id
self._null_block = block_pool.null_block
def get_num_blocks_to_allocate(
self,
request_id: str,
num_tokens: int,
new_computed_blocks: Sequence[KVCacheBlock],
) -> int:
"""
Get the number of blocks needed to be allocated for the request.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
Returns:
The number of blocks.
"""
num_required_blocks = cdiv(num_tokens, self.block_size)
num_new_blocks = (
num_required_blocks
- len(new_computed_blocks)
- len(self.req_to_blocks[request_id])
)
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it will be changed from a free block
# to a computed block when the request is allocated, so we also count
# it as needed to be allocated.
num_evictable_computed_blocks = sum(
blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks
)
return num_new_blocks + num_evictable_computed_blocks
def save_new_computed_blocks(
self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
) -> None:
"""
Add the new computed blocks to the request.
Args:
request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the
prefix cache.
"""
if request_id not in self.num_cached_block:
# A new request.
req_blocks = self.req_to_blocks[request_id]
assert len(req_blocks) == 0
req_blocks.extend(new_computed_blocks)
self.num_cached_block[request_id] = len(new_computed_blocks)
else:
# A running request. Should not have new computed blocks.
assert len(new_computed_blocks) == 0
def allocate_new_blocks(
self, request_id: str, num_tokens: int
) -> list[KVCacheBlock]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
Returns:
The new allocated blocks.
"""
req_blocks = self.req_to_blocks[request_id]
num_required_blocks = cdiv(num_tokens, self.block_size)
num_new_blocks = num_required_blocks - len(req_blocks)
if num_new_blocks <= 0:
return []
else:
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
return new_blocks
def cache_blocks(self, request: Request, num_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
num_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
"""
num_cached_blocks = self.num_cached_block.get(request.request_id, 0)
num_full_blocks = num_tokens // self.block_size
if num_cached_blocks >= num_full_blocks:
return
self.block_pool.cache_full_blocks(
request=request,
blocks=self.req_to_blocks[request.request_id],
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks,
block_size=self.block_size,
kv_cache_group_id=self.kv_cache_group_id,
)
self.num_cached_block[request.request_id] = num_full_blocks
def free(self, request_id: str) -> None:
"""
Free the blocks for the request.
Args:
request_id: The request ID.
"""
# Default to [] in case a request is freed (aborted) before alloc.
req_blocks = self.req_to_blocks.pop(request_id, [])
# Free blocks in reverse order so that the tail blocks are
# freed first.
ordered_blocks = reversed(req_blocks)
self.block_pool.free_blocks(ordered_blocks)
self.num_cached_block.pop(request_id, None)
@abstractmethod
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
Get the number of common prefix blocks for all requests with allocated
KV cache.
Args:
running_request_id: The request ID.
Returns:
The number of common prefix blocks for all requests with allocated
KV cache.
"""
raise NotImplementedError
@classmethod
@abstractmethod
def find_longest_cache_hit(
cls,
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
"""
Get the longest cache hit prefix of the blocks that is not longer than
`max_length`. The prefix should be a common prefix hit for all the
kv cache groups in `kv_cache_group_ids`. If no cache hit is found,
return an empty list.
If eagle is enabled, drop the last matched block to force recompute the
last block to get the required hidden states for eagle drafting head.
Need to be customized for each attention type.
Args:
block_hashes: The block hashes of the request.
max_length: The maximum length of the cache hit prefix.
kv_cache_group_ids: The ids of the kv cache groups.
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
alignment_tokens: The returned cache hit length (in tokens) should
be a multiple of this value (in tokens). By default, it should
be set to the block_size.
dcp_world_size: The world size of decode context parallelism.
pcp_world_size: The world size of prefill context parallelism.
Returns:
A list of cached blocks with skipped blocks replaced by null block
for each kv cache group in `kv_cache_group_ids`.
Return a list of length `len(kv_cache_group_ids)`, where the i-th
element is a list of cached blocks for the i-th kv cache group
in `kv_cache_group_ids`.
For example, sliding window manager should return a list like
([NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]) for block size 4
and sliding window 8 and len(kv_cache_group_ids) = 1.
"""
raise NotImplementedError
def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
"""
Remove and free the blocks that are no longer needed for attention computation.
The removed blocks should be replaced by null_block.
This function depends on `get_num_skipped_tokens`, which need to be implemented
differently for each attention type.
Args:
request_id: The request ID.
num_computed_tokens: The number of tokens that have been computed.
"""
# Remove the blocks that will be skipped during attention computation.
num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
if num_skipped_tokens <= 0:
# This indicates that ALL tokens are inside attention window.
# Thus we do not need to free any blocks outside attention window.
# A typical case is full attention that we never free any token
# before the request is finished.
return
num_skipped_blocks = num_skipped_tokens // self.block_size
blocks = self.req_to_blocks[request_id]
removed_blocks: list[KVCacheBlock] = []
# Because the block starts from index 0, the num_skipped_block-th block
# corresponds to index num_skipped_blocks - 1.
for i in range(num_skipped_blocks - 1, -1, -1):
if blocks[i] == self._null_block:
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls
# to this function.
break
removed_blocks.append(blocks[i])
blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks)
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
"""
Get the number of tokens that will be skipped for attention computation.
Args:
num_computed_tokens: The number of tokens that have been computed.
Returns:
The number of tokens that will be skipped for attention computation.
"""
# The default behavior is to not skip any tokens.
return 0
class FullAttentionManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(
kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec
), (
"FullAttentionManager can only be used for full attention "
"and chunked local attention groups"
)
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids))
)
block_size = kv_cache_spec.block_size
if dcp_world_size * pcp_world_size > 1:
block_size *= dcp_world_size * pcp_world_size
max_num_blocks = max_length // block_size
for block_hash in itertools.islice(block_hashes, max_num_blocks):
# block_hashes is a chain of block hashes. If a block hash is not
# in the cached_block_hash_to_id, the following block hashes are
# not computed yet for sure.
if cached_block := block_pool.get_cached_block(
block_hash, kv_cache_group_ids
):
for computed, cached in zip(computed_blocks, cached_block):
computed.append(cached)
else:
break
if use_eagle and computed_blocks[0]:
# Need to drop the last matched block if eagle is enabled.
for computed in computed_blocks:
computed.pop()
while (
block_size != alignment_tokens # Faster for common case.
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
):
for computed in computed_blocks:
computed.pop()
return computed_blocks
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
blocks = self.req_to_blocks[running_request_id]
num_common_blocks = 0
for block in blocks:
if block.ref_cnt == len(self.req_to_blocks):
num_common_blocks += 1
else:
break
return num_common_blocks
class SlidingWindowManager(SingleTypeKVCacheManager):
def __init__(
self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, **kwargs
) -> None:
super().__init__(kv_cache_spec, block_pool, **kwargs)
self.sliding_window = kv_cache_spec.sliding_window
self._null_block = block_pool.null_block
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
"SlidingWindowManager can only be used for sliding window groups"
)
assert dcp_world_size == 1, "DCP not support sliding window attn now."
assert pcp_world_size == 1, "PCP not support sliding window attn now."
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
sliding_window_contiguous_blocks = cdiv(
kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size
)
if use_eagle:
# Need to drop the last matched block if eagle is enabled. For
# sliding window layer, we achieve this by increasing the number of
# contiguous blocks needed for prefix cache hit by one and dropping
# the last matched block.
sliding_window_contiguous_blocks += 1
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
# optimize the time complexity from O(max_num_blocks) to
# O(max_num_blocks / sliding_window_contiguous_blocks +
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
max_num_blocks = max_length // kv_cache_spec.block_size
computed_blocks = tuple(
[block_pool.null_block] * max_num_blocks
for _ in range(len(kv_cache_group_ids))
)
block_size = kv_cache_spec.block_size
num_contiguous_blocks = 0
match_found = False
# Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1):
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids
):
# Skip prefix matching check if the block is not aligned with
# `alignment_tokens`.
if (
num_contiguous_blocks == 0
and block_size != alignment_tokens # Faster for common case.
and (i + 1) * block_size % alignment_tokens != 0
):
continue
# Add the cached block to the computed blocks.
for computed, cached in zip(computed_blocks, cached_block):
computed[i] = cached
num_contiguous_blocks += 1
if num_contiguous_blocks >= sliding_window_contiguous_blocks:
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
for computed in computed_blocks:
del computed[i + num_contiguous_blocks :]
match_found = True
break
else:
num_contiguous_blocks = 0
if not match_found:
# The first `num_contiguous_blocks` is a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
for computed in computed_blocks:
del computed[num_contiguous_blocks:]
while (
block_size != alignment_tokens # Faster for common case.
and len(computed_blocks[0]) * block_size % alignment_tokens != 0
):
for computed in computed_blocks:
computed.pop()
if use_eagle and computed_blocks[0]:
assert kv_cache_spec.block_size == alignment_tokens, (
"aligned_length is not compatible with eagle now"
)
for computed in computed_blocks:
computed.pop()
return computed_blocks
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
"""
Get the number of tokens that will be skipped for attention computation.
For sliding window, this corresponds to the tokens that are prior to
the current sliding window.
Example:
sliding_window=4, num_computed_tokens=7
Tokens: [ 0 1 2 3 4 5 6 7 ]
| ---- computed -----|
^ next token to be computed
|-----------| sliding window for next token
|--skipped---|
The current window contains tokens 4~7. Tokens 0~3 will be skipped for
attention computation since they are outside the sliding window.
Thus, get_num_skipped_tokens(7) == 4.
Args:
num_computed_tokens: The number of tokens that have been computed.
Returns:
The number of tokens that will be skipped for attention computation.
"""
return num_computed_tokens - self.sliding_window + 1
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
So it's not correct to count ref_cnt like FullAttentionManager. Return
0 here for correctness. Need to support cascade attention + sliding
window in the future.
"""
return 0
class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
def __init__(
self, kv_cache_spec: ChunkedLocalAttentionSpec, block_pool: BlockPool, **kwargs
) -> None:
super().__init__(kv_cache_spec, block_pool, **kwargs)
self.attention_chunk_size = kv_cache_spec.attention_chunk_size
self._null_block = block_pool.null_block
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
"""
For chunked local attention, we need to find the longest cache hit
prefix of the blocks that is not longer than `max_length`. The prefix
should be a common prefix hit for all the kv cache groups in
`kv_cache_group_ids`. If no cache hit is found, return an empty list.
note we mark as computed if the whole block is outside of the local
window, and set the block as null. Examples:
1. Attention chunk size of 8, block size of 4, max length of 15
for next token at 15th (zero-indexed), 8th - 14th tokens are in
the window(needs lookup), 0th - 7th are not in the window,
so they are already marked as computed. We check the complete
block3 (8th - 11th tokens), Assume block 3 is hit, we will return
[null, null, block 3], otherwise, we return [null, null]
2. Attention chunk size of 8, block size of 4, max length of 16
for next token at 16th (zero-indexed), 0th - 15th tokens are not
in the window, so they are already marked as computed.
we return 4 blocks[null, null, null, null]
Args:
block_hashes: The block hashes of the request.
max_length: The maximum length of the cache hit prefix.
kv_cache_group_ids: The ids of the kv cache groups.
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
dcp_world_size: The world size of decode context parallelism.
pcp_world_size: The world size of prefill context parallelism.
alignment_tokens: The returned cache hit length (in tokens) should
be a multiple of this value (in tokens).
Returns:
A list of cached blocks
"""
assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), (
"ChunkedLocalAttentionManager can only be used for "
+ "chunked local attention groups"
)
assert use_eagle is False, (
"Hybrid KV cache is not supported for " + "eagle + chunked local attention."
)
assert dcp_world_size == 1, "DCP not support chunked local attn now."
assert pcp_world_size == 1, "PCP not support chunked local attn now."
assert kv_cache_spec.block_size == alignment_tokens, (
"KV cache groups with different block sizes are not compatible with "
"chunked local attention now"
)
max_num_blocks = max_length // kv_cache_spec.block_size
if max_length > 0:
local_attention_start_idx = (
max_length
// kv_cache_spec.attention_chunk_size
* kv_cache_spec.attention_chunk_size
)
else:
local_attention_start_idx = 0
# we marked blocks out of window as computed
# with null blocks, and blocks inside window based on cache lookup
# result [null] [null] ... [null] [hit block 1 (1st block contain
# last window)] [hit block 2] ... [hit block x]
local_attention_start_block_idx = (
local_attention_start_idx // kv_cache_spec.block_size
)
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[block_pool.null_block] * local_attention_start_block_idx
for _ in range(len(kv_cache_group_ids))
)
for i in range(local_attention_start_block_idx, max_num_blocks):
block_hash = block_hashes[i]
if cached_block := block_pool.get_cached_block(
block_hash, kv_cache_group_ids
):
for computed, cached in zip(computed_blocks, cached_block):
computed.append(cached)
else:
break
return computed_blocks
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
"""
Get the number of tokens that will be skipped for attention computation.
For chunked local attention, this corresponds to the tokens that are on
the left side of the current chunk.
Example 1:
chunk size = 8, num_computed_tokens = 13
Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
| ----- computed ---------------|
^^ next token to be computed
|----------------| <-- attention window for
next token
|--- skipped -----|
Output: get_num_skipped_tokens(13) == 8
Example 2:
chunk size = 8, num_computed_tokens = 8
Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
| --- computed ---|
^ next token to be computed
|--| <-- attention window for next token
| --- skipped ----|
Output: get_num_skipped_tokens(8) == 8
Example 3:
chunk size = 8, num_computed_tokens = 7
Tokens: [ 0 1 2 3 4 5 6 7 | 8 9 10 11 12 13 14 15 ] ...
|---computed---|
^ next token to be computed
|-----------------| <-- attention window for next token
no token should be skipped.
Output: get_num_skipped_tokens(7) == 0
Args:
num_computed_tokens: The number of tokens that have been computed.
Returns:
The number of tokens that will be skipped for attention computation.
"""
num_skipped_tokens = (
num_computed_tokens // self.attention_chunk_size
) * self.attention_chunk_size
return num_skipped_tokens
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
cascade attention is not supported by chunked local attention.
"""
return 0
class MambaManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, MambaSpec), (
"MambaManager can only be used for mamba groups"
)
assert dcp_world_size == 1, "DCP not support mamba now."
assert pcp_world_size == 1, "PCP not support mamba now."
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids))
)
block_size = kv_cache_spec.block_size
max_num_blocks = max_length // block_size
# Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1):
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids
):
# When enable Mamba prefix caching, `block_size` will be aligned
# across full attention layers and Mamba layers to ensure the
# prefix hit length aligned at block
if (
block_size != alignment_tokens # Faster for common case.
and (i + 1) * block_size % alignment_tokens != 0
):
continue
for computed, cached in zip(computed_blocks, cached_block):
# the hit length logic later assumes:
# hit_length = len(hit_blocks_other_attn[0])
# * self.other_block_size
# so we insert dummy blocks at the beginning:
computed.extend([block_pool.null_block] * i)
computed.append(cached)
break # we just need the last match - early stopping
return computed_blocks
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
"""
cascade attention is not supported by mamba
"""
return 0
def get_num_blocks_to_allocate(
self,
request_id: str,
num_tokens: int,
new_computed_blocks: Sequence[KVCacheBlock],
) -> int:
# Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention.
assert isinstance(self.kv_cache_spec, MambaSpec)
if self.kv_cache_spec.num_speculative_blocks > 0:
num_tokens += (
self.kv_cache_spec.block_size
* self.kv_cache_spec.num_speculative_blocks
)
return super().get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks
)
def allocate_new_blocks(
self, request_id: str, num_tokens: int
) -> list[KVCacheBlock]:
# Allocate extra `num_speculative_blocks` blocks for
# speculative decoding (MTP/EAGLE) with linear attention.
assert isinstance(self.kv_cache_spec, MambaSpec)
if self.kv_cache_spec.num_speculative_blocks > 0:
num_tokens += (
self.kv_cache_spec.block_size
* self.kv_cache_spec.num_speculative_blocks
)
return super().allocate_new_blocks(request_id, num_tokens)
class CrossAttentionManager(SingleTypeKVCacheManager):
"""Manager for cross-attention KV cache in encoder-decoder models."""
def save_new_computed_blocks(
self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
) -> None:
# We do not cache blocks for cross-attention to be shared between
# requests, so `new_computed_blocks` should always be empty.
assert len(new_computed_blocks) == 0
def cache_blocks(self, request: Request, num_tokens: int) -> None:
# We do not cache blocks for cross-attention to be shared between
# requests, so this method is not relevant.
raise ValueError("Should not be called as prefix caching is disabled.")
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
# Cross-attention blocks contain request-specific encoder states
# and are not shared between different requests
return 0
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: BlockHashList,
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
alignment_tokens: int,
dcp_world_size: int = 1,
pcp_world_size: int = 1,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, CrossAttentionSpec), (
"CrossAttentionManager can only be used for cross-attention groups"
)
# Cross-attention does not benefit from prefix caching since:
# 1. Encoder states are unique per request (different audio/image
# inputs)
# 2. Encoder states are computed once per request, not incrementally
# 3. No reusable prefix exists between different multimodal inputs
# Return empty blocks to indicate no cache hits
raise NotImplementedError("CrossAttentionManager does not support caching")
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
MLAAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
MambaSpec: MambaManager,
CrossAttentionSpec: CrossAttentionManager,
}
def get_manager_for_kv_cache_spec(
kv_cache_spec: KVCacheSpec, **kwargs
) -> SingleTypeKVCacheManager:
manager_class = spec_manager_map[type(kv_cache_spec)]
manager = manager_class(kv_cache_spec, **kwargs)
return manager