Sync from v0.13
This commit is contained in:
0
vllm/v1/core/__init__.py
Normal file
0
vllm/v1/core/__init__.py
Normal file
485
vllm/v1/core/block_pool.py
Normal file
485
vllm/v1/core/block_pool.py
Normal file
@@ -0,0 +1,485 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable, Sequence
|
||||
from typing import Any
|
||||
|
||||
from vllm.distributed.kv_events import (
|
||||
MEDIUM_GPU,
|
||||
AllBlocksCleared,
|
||||
BlockRemoved,
|
||||
BlockStored,
|
||||
KVCacheEvent,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
|
||||
from vllm.v1.core.kv_cache_utils import (
|
||||
BlockHash,
|
||||
BlockHashList,
|
||||
BlockHashListWithBlockSize,
|
||||
BlockHashWithGroupId,
|
||||
ExternalBlockHash,
|
||||
FreeKVCacheBlockQueue,
|
||||
KVCacheBlock,
|
||||
get_block_hash,
|
||||
make_block_hash_with_group_id,
|
||||
maybe_convert_block_hash,
|
||||
)
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BlockHashToBlockMap:
    """
    Prefix-cache lookup table mapping a block hash directly to the cached
    block(s) carrying that hash (i.e. {block_hash: KVCacheBlocks}).

    - In the common case a hash maps to a single KVCacheBlock, which is
      stored as-is.
    - When several blocks share a hash, the value is widened to a
      {block_id: KVCacheBlock} dict.

    A cached block is a full block with a block hash that can be used for
    prefix caching. It may be referenced by running requests or sit in the
    free_block_queue as an eviction candidate.

    NOTE #1: Blocks are not de-duplicated: when a block becomes full and is
    cached, we do not check whether an identical block already exists. This
    keeps allocated block IDs stable so that block tables are append-only.
    NOTE #2: The union value type avoids allocating an inner dict in the
    common single-block case, reducing GC costs.
    """

    def __init__(self):
        # Value is either a lone KVCacheBlock or {block_id: KVCacheBlock}.
        self._cache: dict[
            BlockHashWithGroupId, KVCacheBlock | dict[int, KVCacheBlock]
        ] = {}

    def get_one_block(self, key: BlockHashWithGroupId) -> KVCacheBlock | None:
        """Return an arbitrary cached block for `key`, or None on miss."""
        entry = self._cache.get(key)
        if entry is None:
            return None
        if isinstance(entry, KVCacheBlock):
            return entry
        if isinstance(entry, dict):
            # Any block with this hash will do; take the first one.
            return next(iter(entry.values()))
        self._unexpected_blocks_type(entry)
        return None

    def insert(self, key: BlockHashWithGroupId, block: KVCacheBlock) -> None:
        """Add `block` under `key`, widening to a dict on hash collision."""
        entry = self._cache.get(key)
        if entry is None:
            # First block for this hash: attach it directly.
            self._cache[key] = block
        elif isinstance(entry, KVCacheBlock):
            # Second block with the same hash: merge both into a dict keyed
            # by block_id.
            self._cache[key] = {entry.block_id: entry, block.block_id: block}
        elif isinstance(entry, dict):
            # Already widened: just insert.
            entry[block.block_id] = block
        else:
            self._unexpected_blocks_type(entry)

    def pop(self, key: BlockHashWithGroupId, block_id: int) -> KVCacheBlock | None:
        """Remove and return the block with `block_id` under `key`, if any."""
        entry = self._cache.pop(key, None)
        if entry is None:
            # Hash not cached at all.
            return None
        # TODO(Jialin): If key is found, block_id should always be present
        # in the entry. We currently keep the original defensive behaviour
        # for safety.
        #
        # Will add block_id == entry.block_id assertion and use
        # del entry[block_id] instead as followup.
        if isinstance(entry, KVCacheBlock):
            if entry.block_id == block_id:
                return entry
            # ID mismatch (should be rare): restore the untouched entry.
            self._cache[key] = entry
            return None
        if isinstance(entry, dict):
            popped = entry.pop(block_id, None)
            if entry:
                # Other blocks still share this hash; keep them cached.
                self._cache[key] = entry
            return popped
        self._unexpected_blocks_type(entry)
        return None

    def __len__(self) -> int:
        return len(self._cache)

    def _unexpected_blocks_type(self, blocks: Any) -> None:
        raise AssertionError(f"Invalid KV cache block type {type(blocks)}")
|
||||
|
||||
|
||||
class BlockPool:
    """BlockPool that manages KVCacheBlocks.
    It provides methods to allocate, free and cache the kv cache blocks. The
    free_block_queue stores the free blocks in eviction order to enable
    allocation, free, and cache eviction. The cached_block_hash_to_block
    maps between block hash and cached block to support finding cached blocks
    by their block hash.

    Args:
        num_gpu_blocks: The number of blocks in the pool.
        enable_caching: Whether to enable prefix caching.
        hash_block_size: The block size of which the block hashes are computed.
            The actual block size usually equals hash_block_size, but in cases
            where different KV cache groups have different block sizes, the
            actual block size can be a multiple of hash_block_size.
        enable_kv_cache_events: Whether to enable kv cache events.
        metrics_collector: Optional metrics collector for tracking block residency.
    """

    def __init__(
        self,
        num_gpu_blocks: int,
        enable_caching: bool,
        hash_block_size: int,
        enable_kv_cache_events: bool = False,
        metrics_collector: KVCacheMetricsCollector | None = None,
    ):
        assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
        self.num_gpu_blocks = num_gpu_blocks
        self.enable_caching = enable_caching
        self.hash_block_size = hash_block_size
        # All kv-cache blocks.
        self.blocks: list[KVCacheBlock] = [
            KVCacheBlock(idx) for idx in range(num_gpu_blocks)
        ]
        # Free block queue that constructs and manipulates a doubly linked
        # list of free blocks (including eviction candidates when caching is
        # enabled).
        self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)

        # Cache for block lookup
        self.cached_block_hash_to_block: BlockHashToBlockMap = BlockHashToBlockMap()

        # To represent a placeholder block with block_id=0.
        # The ref_cnt of null_block is not maintained, needs special care to
        # avoid freeing it.
        self.null_block = self.free_block_queue.popleft()
        self.null_block.is_null = True

        self.enable_kv_cache_events = enable_kv_cache_events
        # Pending events, drained by take_events().
        self.kv_event_queue: list[KVCacheEvent] = []

        self.metrics_collector = metrics_collector

    def get_cached_block(
        self, block_hash: BlockHash, kv_cache_group_ids: list[int]
    ) -> list[KVCacheBlock] | None:
        """Get the cached block by the block hash for each group in
        `kv_cache_group_ids`, or None if cache miss for any group.
        If there are duplicated blocks, we return the first block in the cache.

        Args:
            block_hash: The hash value of the block.
            kv_cache_group_ids: The ids of the KV cache groups.

        Returns:
            The cached blocks if exists, or None.
        """
        cached_blocks: list[KVCacheBlock] = []
        for group_id in kv_cache_group_ids:
            block_hash_with_group_id = make_block_hash_with_group_id(
                block_hash, group_id
            )
            block = self.cached_block_hash_to_block.get_one_block(
                block_hash_with_group_id
            )
            if not block:
                # Cache miss for this group -> overall miss.
                return None
            cached_blocks.append(block)
        return cached_blocks

    def cache_full_blocks(
        self,
        request: Request,
        blocks: list[KVCacheBlock],
        num_cached_blocks: int,
        num_full_blocks: int,
        block_size: int,
        kv_cache_group_id: int,
    ) -> None:
        """Cache a list of full blocks for prefix caching.
        This function takes a list of blocks that will have their block hash
        metadata to be updated and cached. Given a request, it updates the
        metadata for each block and caching it in the
        `cached_block_hash_to_block`.
        The block hashes values are computed by the Request object immediately
        when it is created and when new tokens are appended.

        Args:
            request: The request to cache the blocks.
            blocks: All blocks in the request.
            num_cached_blocks: The number of blocks that are already cached.
            num_full_blocks: The number of blocks that are full and should
                be cached after this function.
            block_size: Number of tokens in each block.
            kv_cache_group_id: The id of the KV cache group.
        """
        if num_cached_blocks >= num_full_blocks:
            # Nothing new became full since the last call.
            return
        new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
        assert len(request.block_hashes) >= num_full_blocks
        if block_size == self.hash_block_size:
            # Common case.
            block_hashes: BlockHashList = request.block_hashes
        else:
            # block_size is a multiple of hash_block_size. This happens when
            # different KV cache groups have different block sizes.
            assert block_size % self.hash_block_size == 0
            # Recalculate block_hashes at the granularity of block_size, using
            # the original block_hashes (at the granularity of hash_block_size).
            block_hashes = BlockHashListWithBlockSize(
                request.block_hashes, self.hash_block_size, block_size
            )

        new_block_hashes = block_hashes[num_cached_blocks:]
        # Only collect external hashes when events are enabled.
        new_hashes: list[ExternalBlockHash] | None = (
            [] if self.enable_kv_cache_events else None
        )
        for i, blk in enumerate(new_full_blocks):
            assert blk.block_hash is None
            block_hash = new_block_hashes[i]

            # Attach the hash to the block and add it to the cache.
            block_hash_with_group_id = make_block_hash_with_group_id(
                block_hash, kv_cache_group_id
            )
            blk.block_hash = block_hash_with_group_id
            self.cached_block_hash_to_block.insert(block_hash_with_group_id, blk)
            if new_hashes is not None:
                new_hashes.append(maybe_convert_block_hash(block_hash))

        if self.enable_kv_cache_events:
            if num_cached_blocks == 0:
                # First full block of the request: no parent.
                parent_block_hash: ExternalBlockHash | None = None
            else:
                parent_block = blocks[num_cached_blocks - 1]
                assert parent_block.block_hash is not None
                parent_block_hash = maybe_convert_block_hash(
                    get_block_hash(parent_block.block_hash)
                )

            self.kv_event_queue.append(
                BlockStored(
                    block_hashes=new_hashes,
                    parent_block_hash=parent_block_hash,
                    token_ids=request.all_token_ids[
                        num_cached_blocks * block_size : num_full_blocks * block_size
                    ],
                    block_size=block_size,
                    lora_id=request.lora_request.adapter_id
                    if request.lora_request
                    else None,
                    medium=MEDIUM_GPU,
                )
            )

    def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
        """Get new blocks from the free block pool.

        Note that we do not check block cache in this function.

        Args:
            num_blocks: The number of blocks to allocate.

        Returns:
            A list of new block.

        Raises:
            ValueError: If fewer than `num_blocks` blocks are free.
        """
        if num_blocks > self.get_num_free_blocks():
            raise ValueError(f"Cannot get {num_blocks} free blocks from the pool")

        ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)

        # In order to only iterate the list once, we duplicated code a bit
        if self.enable_caching:
            for block in ret:
                # Popped blocks may still be in the prefix cache; evict them.
                self._maybe_evict_cached_block(block)
                assert block.ref_cnt == 0
                block.ref_cnt += 1
                if self.metrics_collector:
                    self.metrics_collector.on_block_allocated(block)
        else:
            for block in ret:
                assert block.ref_cnt == 0
                block.ref_cnt += 1
                if self.metrics_collector:
                    self.metrics_collector.on_block_allocated(block)
        return ret

    def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
        """
        If a block is cached in `cached_block_hash_to_block`, we reset its hash
        metadata and evict it from the cache.

        Args:
            block: The block to evict.

        Returns:
            True if the block is evicted, False otherwise.
        """
        # Clean up metrics tracking first to prevent leaks
        if self.metrics_collector:
            self.metrics_collector.on_block_evicted(block)

        block_hash = block.block_hash
        if block_hash is None:
            # The block doesn't have hash, eviction is not needed
            return False

        if self.cached_block_hash_to_block.pop(block_hash, block.block_id) is None:
            # block not found in cached_block_hash_to_block,
            # eviction is not needed
            return False

        block.reset_hash()

        if self.enable_kv_cache_events:
            # FIXME (Chen): Not sure whether we should return `hash_value`
            # or `(hash_value, group_id)` here. But it's fine now because
            # we disable hybrid kv cache manager when kv cache event is
            # enabled, so there is only one group.
            self.kv_event_queue.append(
                BlockRemoved(
                    block_hashes=[maybe_convert_block_hash(get_block_hash(block_hash))],
                    medium=MEDIUM_GPU,
                )
            )
        return True

    def touch(self, blocks: tuple[Sequence[KVCacheBlock], ...]) -> None:
        """Touch a block increases its reference count by 1, and may remove
        the block from the free queue. This is used when a block is hit by
        another request with the same prefix.

        Args:
            blocks: A list of blocks to touch.
        """
        for blocks_per_group in blocks:
            for block in blocks_per_group:
                # ref_cnt=0 means this block is in the free list (i.e. eviction
                # candidate), so remove it.
                if block.ref_cnt == 0 and not block.is_null:
                    self.free_block_queue.remove(block)
                block.ref_cnt += 1
                if self.metrics_collector:
                    self.metrics_collector.on_block_accessed(block)

    def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
        """Free a list of blocks. The blocks should be ordered by their
        eviction priority, where the first block will be evicted first.

        Args:
            ordered_blocks: A list of blocks to free ordered by their eviction
                priority.
        """
        # Materialize the iterable to allow multiple passes.
        blocks_list = list(ordered_blocks)
        for block in blocks_list:
            block.ref_cnt -= 1
        # Only blocks that dropped to ref_cnt == 0 become eviction
        # candidates; the null block is never enqueued.
        self.free_block_queue.append_n(
            [block for block in blocks_list if block.ref_cnt == 0 and not block.is_null]
        )

    def evict_blocks(self, block_ids: set[int]) -> None:
        """evict blocks from the prefix cache by their block IDs.

        only evicts blocks that are currently cached (have a hash). blocks
        with ref_cnt > 0 are not freed from the block pool, only evicted
        from the prefix cache hash table.

        Args:
            block_ids: Set of block IDs to evict from cache.
        """
        for block_id in block_ids:
            assert block_id < len(self.blocks), (
                f"Invalid block_id {block_id} >= {len(self.blocks)}. "
                f"This indicates a bug in the KV connector - workers should "
                f"only report block IDs that were allocated by the scheduler."
            )
            block = self.blocks[block_id]
            self._maybe_evict_cached_block(block)

    def reset_prefix_cache(self) -> bool:
        """Reset prefix cache. This function may be used in RLHF
        flows to invalid prefix caching after the weights are updated,
        or used for resetting prefix caching status for benchmarking.

        Returns:
            bool: True if the prefix cache is successfully reset,
            False otherwise.
        """
        num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
        if num_used_blocks != 1:  # The null block is always marked as used
            logger.warning(
                "Failed to reset prefix cache because some "
                "blocks (%d) are not freed yet",
                num_used_blocks - 1,
            )
            return False

        # Remove all hashes so that no new blocks will hit.
        self.cached_block_hash_to_block = BlockHashToBlockMap()

        # Remove all hashes from all blocks.
        for block in self.blocks:
            block.reset_hash()

        if self.metrics_collector:
            self.metrics_collector.reset()

        logger.info("Successfully reset prefix cache")

        if self.enable_kv_cache_events:
            self.kv_event_queue.append(AllBlocksCleared())

        return True

    def get_num_free_blocks(self) -> int:
        """Get the number of free blocks in the pool.

        Returns:
            The number of free blocks.
        """
        return self.free_block_queue.num_free_blocks

    def get_usage(self) -> float:
        """Get the KV cache usage.

        Returns:
            The KV cache usage (between 0.0 and 1.0).
        """

        # Subtract 1 to account for null block.
        total_gpu_blocks = self.num_gpu_blocks - 1
        if not total_gpu_blocks:
            return 0
        return 1.0 - (self.get_num_free_blocks() / total_gpu_blocks)

    def take_events(self) -> list[KVCacheEvent]:
        """Atomically takes all events and clears the queue.

        Returns:
            A list of KV cache events.
        """
        if not self.enable_kv_cache_events:
            return []
        # Swap out the whole list so callers get a stable snapshot.
        events = self.kv_event_queue
        self.kv_event_queue = []
        return events
|
||||
402
vllm/v1/core/encoder_cache_manager.py
Normal file
402
vllm/v1/core/encoder_cache_manager.py
Normal file
@@ -0,0 +1,402 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MultiModalRegistry
|
||||
from vllm.v1.request import Request
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig, SchedulerConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class EncoderCacheManager:
    """Manages caching of encoder outputs for multimodal models in vLLM V1.

    The EncoderCacheManager handles the lifecycle of multimodal encoder outputs
    (such as vision embeddings from images) during request processing. It
    provides memory-aware caching to avoid recomputing encoder outputs when the
    same multimodal inputs appear in different stages of request processing.

    This manager is particularly important for:
    - Vision-language models (e.g., LLaVA) where image encoder outputs are
      cached
    - Any multimodal model where encoder computation is expensive and
      cacheable

    The cache operates at the granularity of individual multimodal input items
    within requests, allowing for fine-grained memory management and enabling
    chunked processing of multimodal inputs.

    Cache is enabled to share embeddings of same multimodal data
    item (identified by their hash value) between different requests,
    and eviction takes place at allocation time when there's no free
    space for new embeddings.
    Oldest cached embeddings with no request referenced will be first evicted.

    NOTE: The EncoderCacheManager operates on the level of multimodal embeddings
    instead of encoder tokens (i.e. all tokens that represent the multimodal data
    in the input sequence). This means all break/text tokens in-between multimodal
    embeddings are not considered with respect to the cache size and the number
    of free slots.

    Args:
        cache_size: Limit the size of the cache, measured by the number of
            encoder embeddings from the input sequence.

    Attributes:
        cache_size: Total cache capacity in encoder embeddings.
        num_free_slots: Current available cache capacity in encoder embeddings.
        num_freeable_slots: Capacity that can be immediately reclaimed by
            evicting entries with zero references (in encoder embeddings).
        cached: Mapping from mm_hash to a set of request IDs that currently
            reference the cached entry. If the set is empty, the entry exists
            but is not referenced by any request and is eligible for
            reclamation.
        freeable: Ordered mapping from mm_hash to num_encoder_embeds for
            entries that no current running request needs and that can be
            freed to make space when needed (oldest first).
        freed: List of mm_hash strings that were actually evicted since the
            last call to get_freed_mm_hashes(). This list is cleared on return.
    """

    def __init__(self, cache_size: int):
        self.cache_size = cache_size
        self.num_free_slots = cache_size
        self.num_freeable_slots = cache_size

        # mm_hash of mm_data => ids of requests that reference the mm_data
        self.cached: dict[str, set[str]] = {}

        # mm_hash of mm_data => num_encoder_embeds of the mm_data
        # OrderedDict preserves insertion order, so the oldest unreferenced
        # entry is evicted first via popitem(last=False).
        self.freeable: OrderedDict[str, int] = OrderedDict()
        self.freed: list[str] = []

    def check_and_update_cache(self, request: Request, input_id: int) -> bool:
        """Check if encoder output for a specific multimodal input is cached.

        If the encoder output is cached, update `cached` to add the request id
        to the set of request ids that reference the cached encoder output.
        If the encoder output was previously not referenced by any request,
        update `freeable` and `num_freeable_slots` accordingly.

        Args:
            request: The request containing the multimodal input
            input_id: Index of the multimodal input within the request

        Returns:
            True if the encoder output for this input is already cached
        """
        mm_hash = request.mm_features[input_id].identifier
        # Not cached at all
        if mm_hash not in self.cached:
            return False

        # Cached but currently not referenced by any request
        if not self.cached[mm_hash]:
            # The entry becomes referenced again, so it is no longer
            # reclaimable: remove it from `freeable`.
            num_encoder_embeds = self.freeable.pop(mm_hash)
            self.num_freeable_slots -= num_encoder_embeds

        self.cached[mm_hash].add(request.request_id)
        return True

    def can_allocate(
        self,
        request: Request,
        input_id: int,
        encoder_compute_budget: int,
        num_embeds_to_schedule: int,
    ) -> bool:
        """Check if there's sufficient cache space for a multimodal input.
        If there is, return True and update EncoderCacheManager state.

        If there is not enough free space in `num_free_slots` but there is
        enough reclaimable space in `num_freeable_slots`, entries will be
        evicted from `freeable` (their mm_hash appended to `freed`) until
        enough space is available, and then this method returns True.
        Older entries are evicted first.

        Returns False only if the requested number of tokens exceeds both
        the free and reclaimable capacities combined.

        Args:
            request: The request containing the multimodal input.
            input_id: Index of the multimodal input within the request.
            encoder_compute_budget: Number of encoder embeddings allowed to be
                computed when this method is invoked.
            num_embeds_to_schedule: Number of encoder embeddings already scheduled to be
                allocated with cache space when this method is invoked.

        Returns:
            True if there's enough capacity to hold the encoder output for this
            input (possibly after reclaiming `freeable` entries); otherwise
            False.

        Note: This method does not allocate physical memory for the encoder
        output but only the state of EncoderCacheManager.
        """
        num_embeds = request.get_num_encoder_embeds(input_id)

        # Not enough compute budget
        if num_embeds > encoder_compute_budget:
            return False

        # Account for embeddings already scheduled in this step.
        num_embeds += num_embeds_to_schedule

        # Enough free slots
        if num_embeds <= self.num_free_slots:
            return True

        # Not enough reclaimable slots
        if num_embeds > self.num_freeable_slots:
            return False

        # Not enough free slots but enough reclaimable slots
        # NOTE: Eviction takes place here, but physical memory is not freed
        # until model runner is notified by the scheduler output.
        while num_embeds > self.num_free_slots:
            # Oldest unreferenced entry goes first.
            mm_hash, num_free_embeds = self.freeable.popitem(last=False)
            del self.cached[mm_hash]
            self.freed.append(mm_hash)
            self.num_free_slots += num_free_embeds
        return True

    def allocate(self, request: Request, input_id: int) -> None:
        """Allocate cache space for a multimodal input's encoder output.

        This reserves cache space for storing the encoder output of the
        specified multimodal input. The actual encoder output storage happens in
        the model runner; this method updates the manager's bookkeeping.

        Note:
            This method assumes can_allocate() returned True for the same input.
        """

        mm_hash = request.mm_features[input_id].identifier
        request_id = request.request_id
        if mm_hash not in self.cached:
            self.cached[mm_hash] = set()

        num_encoder_embeds = request.get_num_encoder_embeds(input_id)

        # NOTE: Encoder cache should always have enough space for encoder inputs
        # that are scheduled since eviction takes place at can_allocate().
        assert self.num_free_slots >= num_encoder_embeds
        assert self.num_freeable_slots >= num_encoder_embeds

        self.cached[mm_hash].add(request_id)
        self.num_free_slots -= num_encoder_embeds
        self.num_freeable_slots -= num_encoder_embeds

    def get_cached_input_ids(self, request: Request) -> set[int]:
        """Get all cached multimodal input IDs for a request.

        Returns the set of input IDs whose `mm_hash` exists in the cache map.
        This includes entries that are currently unreferenced (and thus present
        in `freeable`); for such entries, freeing for this request will be a
        no-op.
        """
        return {
            input_id
            for input_id in range(len(request.mm_features))
            if request.mm_features[input_id].identifier in self.cached
        }

    def free_encoder_input(self, request: Request, input_id: int) -> None:
        """Free the request's reference to the encoder input (`mm_data`)

        When the reference set for the corresponding `mm_hash` becomes empty,
        the entry is appended to `freeable` and `num_freeable_slots` is
        increased by the number of encoder embeddings for that input.

        The entry is NOT physically freed until capacity is needed (e.g., by
        `can_allocate`).
        """
        req_id = request.request_id
        mm_hash = request.mm_features[input_id].identifier
        # The mm_hash not in cache or the req_id set is empty
        if not self.cached.get(mm_hash, None):
            return
        self.cached[mm_hash].discard(req_id)
        if not self.cached[mm_hash]:
            # Last reference dropped: entry becomes reclaimable.
            num_encoder_embeds = request.get_num_encoder_embeds(input_id)
            self.freeable[mm_hash] = num_encoder_embeds
            self.num_freeable_slots += num_encoder_embeds

    def free(self, request: Request) -> None:
        """Free all encoder input cache reference held by *request*.

        For each cached input ID, `free_encoder_input` is invoked.
        The data stays in memory until eviction is triggered by a future
        attempt allocation called by 'can_allocate'.

        Typically called when a request is finished, cancelled, or aborted.
        """
        input_ids = self.get_cached_input_ids(request).copy()
        for input_id in input_ids:
            self.free_encoder_input(request, input_id)

    def get_freed_mm_hashes(self) -> list[str]:
        """Get and clear the list of recently freed encoder cache entries.

        Returns:
            List of mm_hash strings that were actually evicted since the last
            call to be used by the scheduler to notify workers about which
            encoder outputs can be removed from their caches. The internal
            list is cleared after this call.
        """
        freed = self.freed
        self.freed = []
        return freed
|
||||
|
||||
|
||||
def compute_encoder_budget(
    model_config: "ModelConfig",
    scheduler_config: "SchedulerConfig",
    mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
    """Compute the encoder budgets from the model and scheduler configurations.

    Returns:
        - Compute budget for encoder execution, measured in number of tokens
          from the input sequence.
        - Space budget for encoder cache size, measured in number of tokens
          from the input sequence.
    """
    if not mm_registry.supports_multimodal_inputs(model_config):
        # Text-only model: delegate to the text path.
        return compute_text_encoder_budget(scheduler_config)

    # Multimodal model: budget is driven by the largest per-item token count.
    per_item_max_tokens = mm_registry.get_max_tokens_per_item_by_modality(model_config)
    return compute_mm_encoder_budget(scheduler_config, per_item_max_tokens)
|
||||
|
||||
|
||||
def compute_text_encoder_budget(scheduler_config: "SchedulerConfig") -> tuple[int, int]:
    """Compute the encoder cache budget for a text-only model.

    Args:
        scheduler_config: Scheduler configuration.

    Returns:
        - Compute budget for encoder execution, in unit of number of tokens
          in the input sequence.
        - Space budget for encoder cache size, in unit of number of tokens
          in the input sequence.
    """
    # Text-only encoder-decoder models are currently unsupported, so both
    # budgets are zero.
    return 0, 0
|
||||
|
||||
|
||||
def compute_mm_encoder_budget(
    scheduler_config: "SchedulerConfig",
    max_tokens_by_modality: Mapping[str, int],
) -> tuple[int, int]:
    """Compute the encoder cache budget for a multimodal model.

    Args:
        scheduler_config: Scheduler configuration.
        max_tokens_by_modality: The maximum number of tokens for each
            non-text modality.

    Returns:
        - Compute budget for encoder execution, measured in number of tokens
          from the input sequence.
        - Space budget for encoder cache size, measured in number of tokens
          from the input sequence.
    """
    if not max_tokens_by_modality:
        logger.warning(
            "All non-text modalities supported by the model have been "
            "explicitly disabled via limit_mm_per_prompt. Encoder cache will "
            "not be initialized."
        )
        return 0, 0

    # The single largest multimodal item dictates the minimum budgets.
    largest_item = max(max_tokens_by_modality.values())

    chunking_disabled = scheduler_config.disable_chunked_mm_input
    if chunking_disabled and largest_item > scheduler_config.max_num_batched_tokens:
        # Without chunking, one item must fit into a single batch.
        raise ValueError(
            "Chunked MM input disabled but max_tokens_per_mm_item "
            f"({largest_item}) is larger than max_num_batched_tokens"
            f" ({scheduler_config.max_num_batched_tokens}). Please increase "
            "max_num_batched_tokens."
        )

    # Both budgets must at least fit the largest single item.
    compute_budget = max(scheduler_config.max_num_encoder_input_tokens, largest_item)
    cache_budget = max(scheduler_config.encoder_cache_size, largest_item)
    return compute_budget, cache_budget
|
||||
|
||||
|
||||
# NOTE (NickLucche): Temporary implementation for encoder-decoder models that only
|
||||
# use the manager for scheduling purposes. Encoder-decoder models will eventually
|
||||
# utilize the cache and this class will fold into EncoderCacheManager, as
|
||||
# differences with MM models shrink.
|
||||
class EncoderDecoderCacheManager(EncoderCacheManager):
    """Scheduling-only cache manager for encoder-decoder models.

    Nothing is ever reported as cached across steps: every allocated entry
    is immediately queued for freeing, so the manager only tracks free-slot
    accounting for the scheduler.
    """

    def __init__(self, cache_size: int):
        # Total slots and currently-available slots, both measured in
        # encoder embeddings.
        self.cache_size = cache_size
        self.num_free_slots = cache_size
        # mm hashes queued to be reported via get_freed_mm_hashes().
        self.freed: list[str] = []

    def check_and_update_cache(self, request: Request, input_id: int) -> bool:
        """Always a cache miss: nothing is cached across steps."""
        return False

    def can_allocate(
        self,
        request: Request,
        input_id: int,
        encoder_compute_budget: int,
        num_embeds_to_schedule: int,
    ) -> bool:
        """Check compute budget and free slots for this encoder input."""
        needed = request.get_num_encoder_embeds(input_id)
        if needed > encoder_compute_budget:
            # Not enough compute budget for this input alone.
            return False
        # Account for embeddings already scheduled in this step as well.
        return needed + num_embeds_to_schedule <= self.num_free_slots

    def allocate(self, request: Request, input_id: int) -> None:
        """Reserve slots for the input and queue it for freeing right away."""
        self.num_free_slots -= request.get_num_encoder_embeds(input_id)

        # The entry is not reused later, so mark it freed immediately.
        self.freed.append(request.mm_features[input_id].identifier)

    def free(self, request: Request) -> None:
        """Release the slots of every encoder input of the request."""
        for idx in range(len(request.mm_features)):
            self.free_encoder_input(request, idx)

    def get_cached_input_ids(self, request: Request) -> set[int]:
        """Return every input id of the request."""
        return {idx for idx in range(len(request.mm_features))}

    def get_freed_mm_hashes(self) -> list[str]:
        """Return and reset the list of freed mm hashes."""
        drained, self.freed = self.freed, []
        return drained

    def free_encoder_input(self, request: Request, input_id: int) -> None:
        """Return the slots held by a single encoder input."""
        self.num_free_slots += request.get_num_encoder_embeds(input_id)
|
||||
570
vllm/v1/core/kv_cache_coordinator.py
Normal file
570
vllm/v1/core/kv_cache_coordinator.py
Normal file
@@ -0,0 +1,570 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from math import lcm
|
||||
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
|
||||
from vllm.v1.core.kv_cache_utils import (
|
||||
BlockHash,
|
||||
BlockHashList,
|
||||
BlockHashListWithBlockSize,
|
||||
KVCacheBlock,
|
||||
)
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (
|
||||
CrossAttentionManager,
|
||||
FullAttentionManager,
|
||||
get_manager_for_kv_cache_spec,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
FullAttentionSpec,
|
||||
KVCacheConfig,
|
||||
KVCacheSpec,
|
||||
)
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class KVCacheCoordinator(ABC):
    """
    Coordinate the KV cache of different KV cache groups.

    Owns the shared BlockPool and one single-type manager per KV cache
    group; most operations fan out to those managers.
    """

    def __init__(
        self,
        kv_cache_config: KVCacheConfig,
        max_model_len: int,
        use_eagle: bool,
        enable_caching: bool,
        enable_kv_cache_events: bool,
        dcp_world_size: int,
        pcp_world_size: int,
        hash_block_size: int,
        metrics_collector: KVCacheMetricsCollector | None = None,
    ):
        self.kv_cache_config = kv_cache_config
        self.max_model_len = max_model_len
        self.enable_caching = enable_caching

        # Single block pool shared by all single-type managers.
        self.block_pool = BlockPool(
            kv_cache_config.num_blocks,
            enable_caching,
            hash_block_size,
            enable_kv_cache_events,
            metrics_collector,
        )

        # Needs special handling for find_longest_cache_hit if eagle is enabled
        self.use_eagle = use_eagle
        # One manager per KV cache group, indexed by group id.
        self.single_type_managers = tuple(
            get_manager_for_kv_cache_spec(
                kv_cache_spec=kv_cache_group.kv_cache_spec,
                block_pool=self.block_pool,
                kv_cache_group_id=i,
                dcp_world_size=dcp_world_size,
                pcp_world_size=pcp_world_size,
            )
            for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups)
        )

    def get_num_blocks_to_allocate(
        self,
        request_id: str,
        num_tokens: int,
        new_computed_blocks: tuple[Sequence[KVCacheBlock], ...],
        num_encoder_tokens: int,
    ) -> int:
        """
        Get the number of blocks needed to be allocated for the request.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).
            new_computed_blocks: The new computed blocks just hitting the
                prefix caching.
            num_encoder_tokens: The number of encoder tokens for allocating
                blocks for cross-attention.

        Returns:
            The number of blocks.
        """
        num_blocks_to_allocate = 0
        for i, manager in enumerate(self.single_type_managers):
            if isinstance(manager, CrossAttentionManager):
                # For cross-attention, we issue a single static allocation
                # of blocks based on the number of encoder input tokens.
                # No prefix-cache hits are credited here (empty list).
                num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
                    request_id, num_encoder_tokens, []
                )
            else:
                num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
                    request_id, num_tokens, new_computed_blocks[i]
                )
        return num_blocks_to_allocate

    def save_new_computed_blocks(
        self, request_id: str, new_computed_blocks: tuple[Sequence[KVCacheBlock], ...]
    ) -> None:
        """
        Add the new computed blocks to the request.

        Args:
            request_id: The request ID.
            new_computed_blocks: The new computed blocks just hitting the
                prefix cache, one sequence per kv cache group.
        """
        for i, manager in enumerate(self.single_type_managers):
            manager.save_new_computed_blocks(request_id, new_computed_blocks[i])

    def allocate_new_blocks(
        self, request_id: str, num_tokens: int, num_encoder_tokens: int = 0
    ) -> tuple[list[KVCacheBlock], ...]:
        """
        Allocate new blocks for the request to give it at least `num_tokens`
        token slots.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).
            num_encoder_tokens: The number of encoder tokens for allocating
                blocks for cross-attention.

        Returns:
            The new allocated blocks, one list per kv cache group.
        """
        # Cross-attention managers size their allocation from the encoder
        # length; all other managers use the decoder token count.
        return tuple(
            manager.allocate_new_blocks(
                request_id,
                num_encoder_tokens
                if isinstance(manager, CrossAttentionManager)
                else num_tokens,
            )
            for manager in self.single_type_managers
        )

    def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
        """
        Cache the blocks for the request.

        Args:
            request: The request.
            num_computed_tokens: The total number of tokens
                that need to be cached
                (including tokens that are already cached).
        """
        for manager in self.single_type_managers:
            manager.cache_blocks(request, num_computed_tokens)

    def free(self, request_id: str) -> None:
        """
        Free the blocks for the request.

        Args:
            request_id: The request ID.
        """
        for manager in self.single_type_managers:
            manager.free(request_id)

    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
        """
        Get the number of common prefix blocks for all requests with allocated
        KV cache for each kv cache group.

        Args:
            running_request_id: The request ID of any running request, used to
                identify the common prefix blocks.

        Returns:
            list[int]: The number of common prefix blocks for each kv cache group.
        """
        return [
            manager.get_num_common_prefix_blocks(running_request_id)
            for manager in self.single_type_managers
        ]

    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
        """
        Remove the blocks that are no longer needed from `blocks` and replace
        the removed blocks with null_block.

        Args:
            request_id: The request ID.
            num_computed_tokens: The number of tokens that have been computed.
        """
        for manager in self.single_type_managers:
            manager.remove_skipped_blocks(request_id, num_computed_tokens)

    def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
        """
        Get the blocks for the request, one list per kv cache group.
        An unknown request id yields empty lists.
        """
        # `or []` also normalizes a tracked-but-empty entry to a fresh list.
        return tuple(
            manager.req_to_blocks.get(request_id) or []
            for manager in self.single_type_managers
        )

    @abstractmethod
    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        """Find the longest prefix-cache hit; implemented per coordinator."""
        pass
|
||||
|
||||
|
||||
class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
    """
    Coordinator used when prefix caching is disabled or unsupported.

    In contrast to UnitaryKVCacheCoordinator and HybridKVCacheCoordinator,
    any number of KV cache groups (including zero) is supported; every
    prefix-caching query returns an empty result.
    """

    def __init__(
        self,
        kv_cache_config: KVCacheConfig,
        max_model_len: int,
        use_eagle: bool,
        enable_kv_cache_events: bool,
        dcp_world_size: int,
        pcp_world_size: int,
        hash_block_size: int,
        metrics_collector: KVCacheMetricsCollector | None = None,
    ):
        # enable_caching is hard-wired to False for this coordinator.
        super().__init__(
            kv_cache_config,
            max_model_len,
            use_eagle,
            False,
            enable_kv_cache_events,
            dcp_world_size=dcp_world_size,
            pcp_world_size=pcp_world_size,
            hash_block_size=hash_block_size,
            metrics_collector=metrics_collector,
        )
        self.num_single_type_manager = len(self.single_type_managers)

    def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
        """Without prefix caching there are no common prefix blocks."""
        return [0 for _ in range(self.num_single_type_manager)]

    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        """Always report a zero-length cache hit."""
        empty_hit: tuple[list[KVCacheBlock], ...] = tuple(
            [] for _ in range(self.num_single_type_manager)
        )
        return empty_hit, 0
|
||||
|
||||
|
||||
class UnitaryKVCacheCoordinator(KVCacheCoordinator):
    """
    Coordinator for models with exactly one KV cache group, i.e. models where
    every attention layer shares a single KV cache type (e.g. all layers use
    full attention, or all layers use sliding window attention).
    """

    def __init__(
        self,
        kv_cache_config: KVCacheConfig,
        max_model_len: int,
        use_eagle: bool,
        enable_caching: bool,
        enable_kv_cache_events: bool,
        dcp_world_size: int,
        pcp_world_size: int,
        hash_block_size: int,
        metrics_collector: KVCacheMetricsCollector | None = None,
    ):
        super().__init__(
            kv_cache_config,
            max_model_len,
            use_eagle,
            enable_caching,
            enable_kv_cache_events,
            dcp_world_size=dcp_world_size,
            pcp_world_size=pcp_world_size,
            hash_block_size=hash_block_size,
            metrics_collector=metrics_collector,
        )
        self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
        self.dcp_world_size = dcp_world_size
        self.pcp_world_size = pcp_world_size
        # Scale the logical block size by the context-parallel world sizes
        # (only when they are actually > 1).
        self.block_size = self.kv_cache_spec.block_size
        for cp_size in (dcp_world_size, pcp_world_size):
            if cp_size > 1:
                self.block_size *= cp_size
        # For models using only Mamba, block_size is set to max_model_len when
        # prefix caching is disabled, and hash_block_size validation is skipped.
        assert not enable_caching or (hash_block_size == self.block_size), (
            "UnitaryKVCacheCoordinator assumes hash_block_size == block_size"
        )
        assert len(self.kv_cache_config.kv_cache_groups) == 1, (
            "UnitaryKVCacheCoordinator assumes only one kv cache group"
        )

    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        """Delegate the prefix-cache lookup to the single manager."""
        manager = self.single_type_managers[0]
        hit_blocks = manager.find_longest_cache_hit(
            block_hashes=block_hashes,
            max_length=max_cache_hit_length,
            kv_cache_group_ids=[0],
            block_pool=self.block_pool,
            kv_cache_spec=self.kv_cache_spec,
            use_eagle=self.use_eagle,
            alignment_tokens=self.block_size,
            dcp_world_size=self.dcp_world_size,
            pcp_world_size=self.pcp_world_size,
        )
        # With a single group, hit length is simply #blocks * block_size.
        num_hit_tokens = len(hit_blocks[0]) * self.block_size
        return hit_blocks, num_hit_tokens
|
||||
|
||||
|
||||
class HybridKVCacheCoordinator(KVCacheCoordinator):
    """
    KV cache coordinator for hybrid models with multiple KV cache types, and
    thus multiple kv cache groups.
    To simplify `find_longest_cache_hit`, it only supports the combination of
    two types of KV cache groups, and one of them must be full attention.
    May extend to more general cases in the future.
    """

    def __init__(
        self,
        kv_cache_config: KVCacheConfig,
        max_model_len: int,
        use_eagle: bool,
        enable_caching: bool,
        enable_kv_cache_events: bool,
        dcp_world_size: int,
        pcp_world_size: int,
        hash_block_size: int,
        metrics_collector: KVCacheMetricsCollector | None = None,
    ):
        super().__init__(
            kv_cache_config,
            max_model_len,
            use_eagle,
            enable_caching,
            enable_kv_cache_events,
            dcp_world_size=dcp_world_size,
            pcp_world_size=pcp_world_size,
            hash_block_size=hash_block_size,
            metrics_collector=metrics_collector,
        )
        # hash_block_size: the block size used to compute block hashes.
        # The actual block size usually equals hash_block_size, but in cases where
        # different KV cache groups have different block sizes, the actual block size
        # can be a multiple of hash_block_size.
        self.hash_block_size = hash_block_size
        assert all(
            g.kv_cache_spec.block_size % hash_block_size == 0
            for g in kv_cache_config.kv_cache_groups
        ), "block_size must be divisible by hash_block_size"
        # Context-parallel setups are not supported with hybrid attention.
        assert dcp_world_size == 1, "DCP not support hybrid attn now."
        assert pcp_world_size == 1, "PCP not support hybrid attn now."
        self.verify_and_split_kv_cache_groups()

    def verify_and_split_kv_cache_groups(self) -> None:
        """
        Verifies that the model has exactly two types of KV cache groups, and
        one of them is full attention. Then, split the kv cache groups into full
        attention groups and other groups.
        """
        full_attention_spec: FullAttentionSpec | None = None
        other_spec: KVCacheSpec | None = None
        # Group ids partitioned by whether the group uses full attention.
        self.full_attention_group_ids: list[int] = []
        self.other_group_ids: list[int] = []
        for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
            if isinstance(g.kv_cache_spec, FullAttentionSpec):
                if full_attention_spec is None:
                    full_attention_spec = g.kv_cache_spec
                else:
                    assert full_attention_spec == g.kv_cache_spec, (
                        "HybridKVCacheCoordinator assumes exactly one type of "
                        "full attention groups now."
                    )
                self.full_attention_group_ids.append(i)
            else:
                if other_spec is None:
                    other_spec = g.kv_cache_spec
                else:
                    assert other_spec == g.kv_cache_spec, (
                        "HybridKVCacheCoordinator assumes "
                        "exactly one other type of groups now."
                    )
                self.other_group_ids.append(i)

        assert full_attention_spec is not None, (
            "HybridKVCacheCoordinator assumes exactly one type of full "
            "attention groups now."
        )
        assert other_spec is not None, (
            "HybridKVCacheCoordinator assumes exactly one type of other groups now."
        )

        self.full_attention_manager_cls = FullAttentionManager
        # Manager class of the non-full-attention type, taken from the first
        # manager of that type.
        self.other_attention_cls = self.single_type_managers[
            self.other_group_ids[0]
        ].__class__
        self.full_attention_spec = full_attention_spec
        self.other_spec = other_spec
        self.full_attention_block_size = self.full_attention_spec.block_size
        self.other_block_size = self.other_spec.block_size
        # The LCM of the block sizes of full attention and other attention.
        # The cache hit length must be a multiple of the LCM of the block sizes
        # to make sure the cache hit length is a multiple of the block size of
        # each attention type. Requiring this because we don't support partial
        # block cache hit yet.
        self.lcm_block_size = lcm(self.full_attention_block_size, self.other_block_size)

        # full_attn_first records whether all full-attention groups come
        # before all other groups; interleaving is rejected.
        if max(self.full_attention_group_ids) < min(self.other_group_ids):
            self.full_attn_first = True
        elif max(self.other_group_ids) < min(self.full_attention_group_ids):
            self.full_attn_first = False
        else:
            raise ValueError(
                "HybridKVCacheCoordinator assumes the full "
                "attention group ids and other attention group ids "
                "do not interleave, either full attention group ids "
                "are before other attention group ids or vice versa."
                "This is for simplifying merging hit_blocks_full_attn and "
                "hit_blocks_other_attn to hit_blocks."
            )

    def find_longest_cache_hit(
        self,
        block_hashes: list[BlockHash],
        max_cache_hit_length: int,
    ) -> tuple[tuple[list[KVCacheBlock], ...], int]:
        """
        Find the longest cache hit for the request.

        Args:
            block_hashes: The block hashes of the request.
            max_cache_hit_length: The maximum length of the cache hit.

        Returns:
            A tuple containing:
            - A list of the cache hit blocks for each single type manager.
            - The number of tokens of the longest cache hit.
        """
        # First, find the longest cache hit for full attention.
        if self.full_attention_spec.block_size == self.hash_block_size:
            # Common case.
            full_attention_block_hashes: BlockHashList = block_hashes
        else:
            # block_size is a multiple of hash_block_size. This happens when different
            # KV cache groups have different block sizes. In this case, we need to
            # recalculate block_hashes at the granularity of block_size, using the
            # original block_hashes (at the granularity of hash_block_size).
            full_attention_block_hashes = BlockHashListWithBlockSize(
                block_hashes, self.hash_block_size, self.full_attention_spec.block_size
            )
        hit_blocks_full_attn = self.full_attention_manager_cls.find_longest_cache_hit(
            block_hashes=full_attention_block_hashes,
            max_length=max_cache_hit_length,
            kv_cache_group_ids=self.full_attention_group_ids,
            block_pool=self.block_pool,
            kv_cache_spec=self.full_attention_spec,
            use_eagle=self.use_eagle,
            alignment_tokens=self.lcm_block_size,
        )
        hit_length = len(hit_blocks_full_attn[0]) * self.full_attention_block_size

        # Next, find the cache hit for the other attention WITHIN
        # the cache hit of full attention.
        if self.other_spec.block_size == self.hash_block_size:
            # Common case.
            other_block_hashes: BlockHashList = block_hashes
        else:
            # Similar to the full attention case, here we need to recalculate
            # block_hashes at the granularity of block_size, using the original
            # block_hashes (at the granularity of hash_block_size).
            other_block_hashes = BlockHashListWithBlockSize(
                block_hashes, self.hash_block_size, self.other_spec.block_size
            )
        hit_blocks_other_attn = self.other_attention_cls.find_longest_cache_hit(
            block_hashes=other_block_hashes,
            max_length=hit_length,
            kv_cache_group_ids=self.other_group_ids,
            block_pool=self.block_pool,
            kv_cache_spec=self.other_spec,
            use_eagle=self.use_eagle,
            alignment_tokens=self.lcm_block_size,
        )
        hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size

        # NOTE: the prefix cache hit length must be a multiple of block_size as
        # we don't support partial block cache hit yet. The cache hit length
        # of other attention is ensured to be a multiple of the block size of
        # full attention layers in current implementation, because hit_length is
        # a multiple of other attention's block size, and other attention's
        # block size is a multiple of full attention's block size (verified in
        # `verify_and_split_kv_cache_groups`).
        assert hit_length % self.full_attention_block_size == 0

        # Truncate the full attention cache hit to the length of the
        # cache hit of the other attention.
        for group_hit_blocks in hit_blocks_full_attn:
            del group_hit_blocks[hit_length // self.full_attention_block_size :]

        # Merge the hit blocks of full attention and other attention.
        # Order must follow the group-id order established in
        # `verify_and_split_kv_cache_groups`.
        if self.full_attn_first:
            hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
        else:
            hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn
        return hit_blocks, hit_length
|
||||
|
||||
|
||||
def get_kv_cache_coordinator(
    kv_cache_config: KVCacheConfig,
    max_model_len: int,
    use_eagle: bool,
    enable_caching: bool,
    enable_kv_cache_events: bool,
    dcp_world_size: int,
    pcp_world_size: int,
    hash_block_size: int,
    metrics_collector: KVCacheMetricsCollector | None = None,
) -> KVCacheCoordinator:
    """Select and construct the coordinator implementation for the config.

    Returns:
        - KVCacheCoordinatorNoPrefixCache when prefix caching is disabled;
        - UnitaryKVCacheCoordinator for a single KV cache group;
        - HybridKVCacheCoordinator otherwise.
    """
    common_kwargs = dict(
        dcp_world_size=dcp_world_size,
        pcp_world_size=pcp_world_size,
        hash_block_size=hash_block_size,
        metrics_collector=metrics_collector,
    )
    if not enable_caching:
        # This variant hard-wires enable_caching=False internally.
        return KVCacheCoordinatorNoPrefixCache(
            kv_cache_config,
            max_model_len,
            use_eagle,
            enable_kv_cache_events,
            **common_kwargs,
        )
    coordinator_cls = (
        UnitaryKVCacheCoordinator
        if len(kv_cache_config.kv_cache_groups) == 1
        else HybridKVCacheCoordinator
    )
    return coordinator_cls(
        kv_cache_config,
        max_model_len,
        use_eagle,
        enable_caching,
        enable_kv_cache_events,
        **common_kwargs,
    )
|
||||
419
vllm/v1/core/kv_cache_manager.py
Normal file
419
vllm/v1/core/kv_cache_manager.py
Normal file
@@ -0,0 +1,419 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import itertools
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, overload
|
||||
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
|
||||
from vllm.v1.core.kv_cache_metrics import KVCacheMetricsCollector
|
||||
from vllm.v1.core.kv_cache_utils import KVCacheBlock
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class KVCacheBlocks:
    """
    Allocation result produced by KVCacheManager.

    Serves as the interface between the Scheduler and the KVCacheManager so
    the manager's internal data structures stay hidden from the Scheduler.
    """

    blocks: tuple[Sequence[KVCacheBlock], ...]
    """
    ``blocks[i][j]`` is the j-th block of the i-th kv_cache_group. The group
    is the outer dimension (rather than the block index) because different
    groups may hold different numbers of blocks once per-group block sizes
    are supported.

    Each per-group entry is either:
    - a list[KVCacheBlock] when the group holds one or more blocks, or
    - an empty tuple for requests without KVCacheBlock
      (a precomputed KVCacheBlocks is in KVCacheManager to avoid GC overhead)
    """

    def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
        """Concatenate two KVCacheBlocks instances group by group."""
        merged = tuple(
            [*mine, *theirs] for mine, theirs in zip(self.blocks, other.blocks)
        )
        return KVCacheBlocks(merged)

    @overload
    def get_block_ids(
        self,
        allow_none: Literal[False] = False,
    ) -> tuple[list[int], ...]: ...

    @overload
    def get_block_ids(
        self,
        allow_none: Literal[True] = True,
    ) -> tuple[list[int], ...] | None: ...

    def get_block_ids(
        self,
        allow_none: bool = False,
    ) -> tuple[list[int], ...] | None:
        """Convert this instance to per-group block id lists.

        Returns:
            A tuple with one list per KV cache group, each containing the
            block_ids of that group's blocks; or None when `allow_none` is
            set and every group is empty.
        """
        if allow_none and not any(self.blocks):
            return None
        return tuple([blk.block_id for blk in group] for group in self.blocks)

    def get_unhashed_block_ids(self) -> list[int]:
        """Return the block_ids of blocks that have no hash yet."""
        assert len(self.blocks) == 1, "Only one group is supported"
        (group,) = self.blocks
        return [blk.block_id for blk in group if blk.block_hash is None]

    def new_empty(self) -> "KVCacheBlocks":
        """Create an instance with the same number of groups, all empty."""
        return KVCacheBlocks(tuple(() for _ in self.blocks))
|
||||
|
||||
|
||||
class KVCacheManager:
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
max_model_len: int,
|
||||
hash_block_size: int,
|
||||
enable_caching: bool = True,
|
||||
use_eagle: bool = False,
|
||||
log_stats: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
metrics_collector: KVCacheMetricsCollector | None = None,
|
||||
) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
|
||||
self.enable_caching = enable_caching
|
||||
self.use_eagle = use_eagle
|
||||
self.log_stats = log_stats
|
||||
self.metrics_collector = metrics_collector
|
||||
# FIXME: make prefix cache stats conditional on log_stats. We still need
|
||||
# this comment because when the log stats is enabled there are still
|
||||
# potential configs we could expose in the future.
|
||||
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
||||
|
||||
self.coordinator = get_kv_cache_coordinator(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=self.max_model_len,
|
||||
use_eagle=self.use_eagle,
|
||||
enable_caching=self.enable_caching,
|
||||
enable_kv_cache_events=enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
hash_block_size=hash_block_size,
|
||||
metrics_collector=self.metrics_collector,
|
||||
)
|
||||
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
||||
self.block_pool = self.coordinator.block_pool
|
||||
self.kv_cache_config = kv_cache_config
|
||||
|
||||
# Pre-constructed KVCacheBlocks with no blocks, callers should use this
|
||||
# via create_kv_cache_blocks instead of creating new ones to avoid GC
|
||||
# overhead.
|
||||
#
|
||||
# We use nested tuples to ensure the empty KVCacheBlocks is immutable.
|
||||
self.empty_kv_cache_blocks = KVCacheBlocks(
|
||||
tuple(() for _ in range(self.num_kv_cache_groups))
|
||||
)
|
||||
|
||||
    @property
    def usage(self) -> float:
        """Get the KV cache usage.

        Returns:
            The KV cache usage (between 0.0 and 1.0), as reported by the
            underlying block pool.
        """
        return self.block_pool.get_usage()
|
||||
|
||||
def make_prefix_cache_stats(self) -> PrefixCacheStats | None:
|
||||
"""Get (and reset) the prefix cache stats.
|
||||
|
||||
Returns:
|
||||
The current prefix caching stats, or None if logging is disabled.
|
||||
"""
|
||||
if not self.log_stats:
|
||||
return None
|
||||
stats = self.prefix_cache_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats()
|
||||
return stats
|
||||
|
||||
def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]:
|
||||
"""Get the computed (cached) blocks for the request.
|
||||
Note that the computed blocks must be full.
|
||||
|
||||
Args:
|
||||
request: The request to get the computed blocks.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A list of blocks that are computed for the request.
|
||||
- The number of computed tokens.
|
||||
"""
|
||||
# We skip finding the prefix cache hit when prefix caching is
|
||||
# disabled or the request is marked as skipping kv cache read
|
||||
# (which happens when the request requires prompt logprobs
|
||||
# or calls a pooling model with all pooling).
|
||||
if not self.enable_caching or request.skip_reading_prefix_cache:
|
||||
return self.empty_kv_cache_blocks, 0
|
||||
|
||||
# NOTE: When all tokens hit the cache, we must recompute the last token
|
||||
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
|
||||
# This can trigger recomputation of an entire block, rather than just
|
||||
# the single last token, because allocate_slots() requires
|
||||
# num_computed_tokens to be block-size aligned. Removing this limitation
|
||||
# could slightly improve performance in the future.
|
||||
max_cache_hit_length = request.num_tokens - 1
|
||||
computed_blocks, num_new_computed_tokens = (
|
||||
self.coordinator.find_longest_cache_hit(
|
||||
request.block_hashes, max_cache_hit_length
|
||||
)
|
||||
)
|
||||
|
||||
if self.log_stats:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.record(
|
||||
num_tokens=request.num_tokens,
|
||||
num_hits=num_new_computed_tokens,
|
||||
preempted=request.num_preemptions > 0,
|
||||
)
|
||||
|
||||
return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens
|
||||
|
||||
def allocate_slots(
    self,
    request: Request,
    num_new_tokens: int,
    num_new_computed_tokens: int = 0,
    new_computed_blocks: KVCacheBlocks | None = None,
    num_lookahead_tokens: int = 0,
    delay_cache_blocks: bool = False,
    num_encoder_tokens: int = 0,
) -> KVCacheBlocks | None:
    """Add slots for a request with new tokens to append.

    Args:
        request: The request to allocate slots.
        num_new_tokens: The number of tokens to allocate, including external
            tokens. Note that this does not include tokens that have
            already been computed locally (i.e. new_computed_blocks).
        num_new_computed_tokens: The number of new computed tokens just
            hitting the prefix caching, excluding external tokens.
        new_computed_blocks: The cached blocks for the above new computed
            tokens.
        num_lookahead_tokens: The number of speculative tokens to allocate.
            This is used by spec decode proposers with kv-cache such
            as eagle.
        delay_cache_blocks: Whether to skip caching the blocks. This is
            used by P/D when allocating blocks used in a KV transfer
            which will complete in a future step.
        num_encoder_tokens: The number of encoder tokens to allocate for
            cross-attention in encoder-decoder models(e.g., Whisper).
            For decoder-only models, this should be 0.

    Blocks layout:
    ```
    -----------------------------------------------------------------------
    | < computed > | < new computed > |    < new >    | < pre-allocated > |
    -----------------------------------------------------------------------
    |                  < required >                   |
    --------------------------------------------------
    |                    < full >                  |
    ------------------------------------------------
                                      | <new full> |
                                      --------------
    ```
    The following *_blocks are illustrated in this layout.

    Returns:
        The newly allocated blocks, or None if there are not enough free
        blocks in the pool to satisfy the allocation.

    Raises:
        ValueError: If num_new_tokens is 0.
    """
    if num_new_tokens == 0:
        raise ValueError("num_new_tokens must be greater than 0")

    if new_computed_blocks is not None:
        new_computed_block_list = new_computed_blocks.blocks
    else:
        new_computed_block_list = self.empty_kv_cache_blocks.blocks

    # Free the blocks that are skipped during the attention computation
    # (e.g., tokens outside the sliding window).
    # We can do this even if we cannot schedule this request due to
    # insufficient free blocks.
    # Should call this function before allocating new blocks to reduce
    # the number of evicted blocks.
    self.coordinator.remove_skipped_blocks(
        request.request_id, request.num_computed_tokens
    )

    # The number of computed tokens is the number of computed tokens plus
    # the new prefix caching hits
    num_computed_tokens = request.num_computed_tokens + num_new_computed_tokens
    # Never allocate beyond the model's maximum context length.
    num_tokens_need_slot = min(
        num_computed_tokens + num_new_tokens + num_lookahead_tokens,
        self.max_model_len,
    )

    num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
        request_id=request.request_id,
        num_tokens=num_tokens_need_slot,
        new_computed_blocks=new_computed_block_list,
        num_encoder_tokens=num_encoder_tokens,
    )

    if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
        # Cannot allocate new blocks
        return None

    # Touch the computed blocks to make sure they won't be evicted.
    if self.enable_caching:
        self.block_pool.touch(new_computed_block_list)
    else:
        assert not any(new_computed_block_list), (
            "Computed blocks should be empty when prefix caching is disabled"
        )

    if new_computed_block_list is not self.empty_kv_cache_blocks.blocks:
        # Append the new computed blocks to the request blocks until now to
        # avoid the case where the new blocks cannot be allocated.
        self.coordinator.save_new_computed_blocks(
            request.request_id, new_computed_block_list
        )

    new_blocks = self.coordinator.allocate_new_blocks(
        request.request_id, num_tokens_need_slot, num_encoder_tokens
    )

    # P/D: delay caching blocks if we have to recv from
    # remote. Update state for locally cached blocks.
    if not self.enable_caching or delay_cache_blocks:
        return self.create_kv_cache_blocks(new_blocks)

    # NOTE(woosuk): We want to commit (cache) up to num_computed_tokens +
    # num_new_tokens, but must exclude "non-committable" tokens (e.g.,
    # draft tokens that could be rejected). Therefore, we cap the number
    # at `request.num_tokens`, ensuring only "finalized" tokens are cached.
    num_tokens_to_cache = min(
        num_computed_tokens + num_new_tokens, request.num_tokens
    )
    self.coordinator.cache_blocks(request, num_tokens_to_cache)

    return self.create_kv_cache_blocks(new_blocks)
|
||||
|
||||
def free(self, request: Request) -> None:
    """Release every block held by *request*.

    Blocks are freed in reverse order so that, with caching enabled, the
    tail blocks become eviction candidates first.

    Args:
        request: The request whose blocks should be released.
    """
    self.coordinator.free(request.request_id)
|
||||
|
||||
def evict_blocks(self, block_ids: set[int]) -> None:
    """Drop the given blocks from the prefix cache.

    Args:
        block_ids: IDs of the blocks to evict.
    """
    self.block_pool.evict_blocks(block_ids)
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
    """Invalidate the entire prefix cache.

    Useful after live weight updates (e.g. RLHF flows) or to reset caching
    state for benchmarking.

    Returns:
        True on success, False if the block pool refused to reset.
    """
    succeeded = self.block_pool.reset_prefix_cache()
    if succeeded and self.log_stats:
        stats = self.prefix_cache_stats
        assert stats is not None
        stats.reset = True
    return succeeded
|
||||
|
||||
def get_num_common_prefix_blocks(self, running_request_id: str) -> list[int]:
    """Calculate the number of common prefix blocks for each kv cache group.

    The function selects a running request and iterates through its blocks.
    A block is considered a common prefix block if ALL requests with
    allocated KV cache share it (i.e., ref_cnt equals the number of entries
    in req_to_blocks).

    NOTE(woosuk): The number of requests with allocated KV cache is **greater
    than or equal to** the number of requests scheduled in the current step.
    This is because having allocated KV cache only indicates that:
    1. The request has not yet finished, and
    2. The request holds its blocks unfreed.

    While all scheduled requests must have allocated KV cache, the inverse
    is not necessarily true. There may be requests with allocated KV cache
    that are not scheduled in the current step.

    This can result in an edge case where the number of common prefix blocks
    is 0, even though all scheduled requests share a common prefix. This
    occurs because there may be unscheduled requests that do not share the
    common prefix. Currently, this case cannot be easily detected, so the
    function returns 0 in such cases.

    Args:
        running_request_id: The request ID of any running request, used to
            identify the common prefix blocks.

    Returns:
        list[int]: The number of common prefix blocks for each kv cache
            group, one entry per group.
    """
    # Pure delegation: the per-group computation lives in the coordinator.
    return self.coordinator.get_num_common_prefix_blocks(running_request_id)
|
||||
|
||||
def take_events(self) -> list[KVCacheEvent]:
    """Drain and return the pending KV cache events from the block pool.

    Returns:
        The list of KV cache events accumulated since the last call.
    """
    return self.block_pool.take_events()
|
||||
|
||||
def get_blocks(self, request_id: str) -> KVCacheBlocks:
    """Return the KV cache blocks currently assigned to *request_id*."""
    raw_blocks = self.coordinator.get_blocks(request_id)
    return self.create_kv_cache_blocks(raw_blocks)
|
||||
|
||||
def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
    """Return the per-group block IDs assigned to *request_id*."""
    blocks = self.get_blocks(request_id)
    return blocks.get_block_ids()
|
||||
|
||||
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
    """Cache the request's blocks for prefix caching.

    No-op when prefix caching is disabled.
    """
    if not self.enable_caching:
        return
    self.coordinator.cache_blocks(request, num_computed_tokens)
|
||||
|
||||
def create_kv_cache_blocks(
    self, blocks: tuple[list[KVCacheBlock], ...]
) -> KVCacheBlocks:
    """Wrap raw per-group block lists in a KVCacheBlocks container.

    Reuses the shared empty singleton when every group is empty, avoiding
    a fresh allocation.
    """
    if any(blocks):
        return KVCacheBlocks(blocks)
    return self.empty_kv_cache_blocks
|
||||
96
vllm/v1/core/kv_cache_metrics.py
Normal file
96
vllm/v1/core/kv_cache_metrics.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""KV cache metrics tracking."""
|
||||
|
||||
import random
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.kv_cache_utils import KVCacheBlock
|
||||
|
||||
from vllm.v1.metrics.stats import KVCacheEvictionEvent
|
||||
|
||||
|
||||
class BlockMetricsState:
    """Per-block lifecycle tracker.

    Records the block's creation time, its most recent access time, and a
    short, bounded history of recent access timestamps (all in
    monotonic nanoseconds).
    """

    def __init__(self):
        created = time.monotonic_ns()
        self.birth_time_ns = created
        self.last_access_ns = created
        # maxlen bounds memory even if a block is accessed many times.
        self.access_history: deque[int] = deque(maxlen=4)

    def record_access(self) -> None:
        """Note that the block was just accessed."""
        stamp = time.monotonic_ns()
        self.last_access_ns = stamp
        self.access_history.append(stamp)

    def get_lifetime_seconds(self) -> float:
        """Seconds elapsed since this block was created."""
        return (time.monotonic_ns() - self.birth_time_ns) / 1e9

    def get_idle_time_seconds(self) -> float:
        """Seconds elapsed since this block was last accessed."""
        return (time.monotonic_ns() - self.last_access_ns) / 1e9

    def get_reuse_gaps_seconds(self) -> list[float]:
        """Gaps, in seconds, between consecutive recorded accesses."""
        stamps = list(self.access_history)
        if len(stamps) < 2:
            return []
        return [
            (later - earlier) / 1e9
            for earlier, later in zip(stamps, stamps[1:])
        ]
|
||||
|
||||
|
||||
class KVCacheMetricsCollector:
|
||||
"""Collects KV cache residency metrics with sampling."""
|
||||
|
||||
def __init__(self, sample_rate: float = 0.01):
|
||||
assert 0 < sample_rate <= 1.0, (
|
||||
f"sample_rate must be in (0, 1.0], got {sample_rate}"
|
||||
)
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
self.block_metrics: dict[int, BlockMetricsState] = {}
|
||||
|
||||
self._eviction_events: list[KVCacheEvictionEvent] = []
|
||||
|
||||
def should_sample_block(self) -> bool:
|
||||
return random.random() < self.sample_rate
|
||||
|
||||
def on_block_allocated(self, block: "KVCacheBlock") -> None:
|
||||
if self.should_sample_block():
|
||||
self.block_metrics[block.block_id] = BlockMetricsState()
|
||||
|
||||
def on_block_accessed(self, block: "KVCacheBlock") -> None:
|
||||
metrics = self.block_metrics.get(block.block_id)
|
||||
if metrics:
|
||||
metrics.record_access()
|
||||
|
||||
def on_block_evicted(self, block: "KVCacheBlock") -> None:
|
||||
metrics = self.block_metrics.pop(block.block_id, None)
|
||||
if not metrics:
|
||||
return
|
||||
|
||||
lifetime = metrics.get_lifetime_seconds()
|
||||
idle_time = metrics.get_idle_time_seconds()
|
||||
reuse_gaps = tuple(metrics.get_reuse_gaps_seconds())
|
||||
|
||||
self._eviction_events.append(
|
||||
KVCacheEvictionEvent(
|
||||
lifetime_seconds=lifetime,
|
||||
idle_seconds=idle_time,
|
||||
reuse_gaps_seconds=reuse_gaps,
|
||||
)
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all state on cache reset."""
|
||||
self.block_metrics.clear()
|
||||
self._eviction_events.clear()
|
||||
|
||||
def drain_events(self) -> list[KVCacheEvictionEvent]:
|
||||
events = self._eviction_events
|
||||
self._eviction_events = []
|
||||
return events
|
||||
1476
vllm/v1/core/kv_cache_utils.py
Normal file
1476
vllm/v1/core/kv_cache_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
0
vllm/v1/core/sched/__init__.py
Normal file
0
vllm/v1/core/sched/__init__.py
Normal file
68
vllm/v1/core/sched/async_scheduler.py
Normal file
68
vllm/v1/core/sched/async_scheduler.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class AsyncScheduler(Scheduler):
    """Scheduler variant that schedules ahead of token arrival.

    Tokens that have been scheduled but whose actual values are not yet
    known are tracked as "output placeholders" on each request
    (request.num_output_placeholders); placeholder spec token ids are
    filled with -1 and replaced in the worker process.
    """

    def _update_after_schedule(
        self,
        scheduler_output: SchedulerOutput,
    ) -> None:
        """Post-schedule hook: add output placeholders for requests that
        will produce new tokens this step, and flag whether any
        structured-output request still has unresolved placeholders."""
        super()._update_after_schedule(scheduler_output)
        pending_structured_output_tokens = False
        spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
        for req_id in scheduler_output.num_scheduled_tokens:
            request = self.requests[req_id]
            # True if any structured-output request has tokens whose values
            # are not yet known (grammar advance must wait for them).
            pending_structured_output_tokens |= (
                request.use_structured_output and request.num_output_placeholders > 0
            )
            cur_num_spec_tokens = len(spec_decode_tokens.get(req_id, ()))
            # All known tokens (real + placeholder + spec) are computed after
            # this step, so the step yields new output token(s).
            if (
                request.num_computed_tokens
                == request.num_tokens
                + request.num_output_placeholders
                + cur_num_spec_tokens
            ):
                # The request will generate a new token plus num_spec_tokens
                # in this scheduling step.
                request.num_output_placeholders += 1 + cur_num_spec_tokens
                # Add placeholders for the new tokens in spec_token_ids.
                # We will update the actual spec token ids in the worker process.
                request.spec_token_ids = [-1] * self.num_spec_tokens

        scheduler_output.pending_structured_output_tokens = (
            pending_structured_output_tokens
        )

    def _update_request_with_output(
        self,
        request: Request,
        new_token_ids: list[int],
    ) -> tuple[list[int], bool]:
        """Resolve placeholders with actual output tokens.

        Returns:
            Tuple of (accepted new token ids, whether the request stopped).
        """
        if request.discard_latest_async_tokens:
            # If the request is force preempted in reset_prefix_cache, we
            # should discard the latest async token.
            request.discard_latest_async_tokens = False
            return [], False

        # Capture status before the superclass may transition it
        # (e.g. to FINISHED/PREEMPTED).
        status_before_update = request.status
        new_token_ids, stopped = super()._update_request_with_output(
            request, new_token_ids
        )

        # Update the number of output placeholders.
        request.num_output_placeholders -= len(new_token_ids)
        assert request.num_output_placeholders >= 0

        # Cache the new tokens. Preempted requests should be skipped.
        if status_before_update == RequestStatus.RUNNING:
            self.kv_cache_manager.cache_blocks(
                request, request.num_computed_tokens - request.num_output_placeholders
            )
        return new_token_ids, stopped
|
||||
189
vllm/v1/core/sched/interface.py
Normal file
189
vllm/v1/core/sched/interface.py
Normal file
@@ -0,0 +1,189 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
|
||||
from vllm.v1.engine import EngineCoreOutputs
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import SchedulerStats
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
|
||||
class SchedulerInterface(ABC):
    """Abstract interface for the engine-level request scheduler.

    Implementations decide, every engine step, which requests to run and
    how many tokens to process for each, then reconcile their state with
    the model runner's output.
    """

    @abstractmethod
    def __init__(
        self,
        vllm_config: "VllmConfig",
        kv_cache_config: "KVCacheConfig",
        structured_output_manager: "StructuredOutputManager",
        block_size: int,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        include_finished_set: bool = False,
        log_stats: bool = False,
    ) -> None:
        raise NotImplementedError

    @abstractmethod
    def schedule(self) -> "SchedulerOutput":
        """Schedule the requests to process in this scheduling step.

        The scheduling decision is made at the iteration level. Each scheduling
        step corresponds to a single forward pass of the model. Therefore, this
        method is called repeatedly by a busy loop in the engine.

        Essentially, the scheduler produces a dictionary of {req_id: num_tokens}
        that specifies how many tokens to process for each request in this
        scheduling step. For example, num_tokens can be as large as the number
        of prompt tokens for new requests, or it can be 1 for the requests that
        are auto-regressively generating new tokens one by one. Otherwise, it
        can be somewhere in between in case of chunked prefills, prefix caching,
        speculative decoding, etc.

        Additionally, the scheduler also returns useful data about each request
        or the batch as a whole. The model runner will use this information in
        preparing inputs to the model.

        Returns:
            A SchedulerOutput object containing information about the scheduled
            requests.
        """
        raise NotImplementedError

    @abstractmethod
    def get_grammar_bitmask(
        self, scheduler_output: "SchedulerOutput"
    ) -> "GrammarOutput | None":
        """Build the grammar bitmask for the scheduled requests.

        May return None (presumably when there is no structured-output work
        for this step — confirm against concrete implementations).
        """
        raise NotImplementedError

    @abstractmethod
    def update_from_output(
        self,
        scheduler_output: "SchedulerOutput",
        model_runner_output: "ModelRunnerOutput",
    ) -> dict[int, "EngineCoreOutputs"]:
        """Update the scheduler state based on the model runner output.

        This method is called after the model runner has processed the scheduled
        requests. The model runner output includes generated token ids, draft
        token ids for next step, etc. The scheduler uses this information to
        update its states, checks the finished requests, and returns the output
        for each request.

        Returns:
            A dict of client index to EngineCoreOutputs object containing the
            outputs for each request originating from that client.
        """
        raise NotImplementedError

    @abstractmethod
    def update_draft_token_ids(
        self,
        draft_token_ids: "DraftTokenIds",
    ) -> None:
        """Update the draft token ids for the scheduled requests."""
        raise NotImplementedError

    @abstractmethod
    def add_request(self, request: "Request") -> None:
        """Add a new request to the scheduler's internal queue.

        Args:
            request: The new request being added.
        """
        raise NotImplementedError

    @abstractmethod
    def finish_requests(
        self,
        request_ids: str | Iterable[str],
        finished_status: "RequestStatus",
    ) -> None:
        """Finish the requests in the scheduler's internal queue. If the request
        is not in the queue, this method will do nothing.

        This method is called in two cases:
        1. When the request is aborted by the client.
        2. When the frontend process detects a stop string of the request after
           de-tokenizing its generated tokens.

        Args:
            request_ids: A single or a list of request IDs.
            finished_status: The finished status of the given requests.
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_unfinished_requests(self) -> int:
        """Number of unfinished requests in the scheduler's internal queue."""
        raise NotImplementedError

    def has_unfinished_requests(self) -> bool:
        """Returns True if there are unfinished requests in the scheduler's
        internal queue."""
        return self.get_num_unfinished_requests() > 0

    @abstractmethod
    def has_finished_requests(self) -> bool:
        """Returns True if there are finished requests that need to be cleared.
        NOTE: This is different from `not self.has_unfinished_requests()`.

        The scheduler maintains an internal list of the requests finished in the
        previous step. This list is returned from the next call to schedule(),
        to be sent to the model runner in the next step to clear cached states
        for these finished requests.

        This method checks if this internal list of finished requests is
        non-empty. This information is useful for DP attention.
        """
        raise NotImplementedError

    def has_requests(self) -> bool:
        """Returns True if there are unfinished requests, or finished requests
        not yet returned in SchedulerOutputs."""
        return self.has_unfinished_requests() or self.has_finished_requests()

    @abstractmethod
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the prefix cache for KV cache.

        This is particularly required when the model weights are live-updated.

        Args:
            reset_running_requests: If True, all the running requests will be
                preempted and moved to the waiting queue. Otherwise, this method
                will only reset the KV prefix cache when there is no running request
                taking KV cache.
            reset_connector: NOTE(review): presumably also resets the KV
                connector's cached state — confirm against implementations.

        Returns:
            True if the prefix cache was reset, False otherwise.
        """
        raise NotImplementedError

    @abstractmethod
    def get_request_counts(self) -> tuple[int, int]:
        """Returns (num_running_reqs, num_waiting_reqs)."""
        raise NotImplementedError

    @abstractmethod
    def make_stats(self) -> Optional["SchedulerStats"]:
        """Make a SchedulerStats object for logging.

        The SchedulerStats object is created for every scheduling step.
        """
        raise NotImplementedError

    @abstractmethod
    def shutdown(self) -> None:
        """Shutdown the scheduler."""
        raise NotImplementedError

    def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
        """Return the KV connector, if any. The base implementation has none."""
        return None
|
||||
230
vllm/v1/core/sched/output.py
Normal file
230
vllm/v1/core/sched/output.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from vllm._bc_linter import bc_linter_include
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
|
||||
from vllm.distributed.ec_transfer.ec_connector.base import ECConnectorMetadata
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalFeatureSpec
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.request import Request
|
||||
else:
|
||||
ECConnectorMetadata = object
|
||||
KVConnectorMetadata = object
|
||||
LoRARequest = object
|
||||
MultiModalFeatureSpec = object
|
||||
PoolingParams = object
|
||||
SamplingParams = object
|
||||
Request = object
|
||||
|
||||
|
||||
@bc_linter_include
@dataclass
class NewRequestData:
    """Payload describing a request that is scheduled for the first time.

    Sent to the worker processes once; workers cache it so later steps
    only need a diff (see CachedRequestData).
    """

    # Unique identifier of the request.
    req_id: str
    prompt_token_ids: list[int] | None
    mm_features: list[MultiModalFeatureSpec]
    sampling_params: SamplingParams | None
    pooling_params: PoolingParams | None
    # Per-KV-cache-group lists of allocated block IDs.
    block_ids: tuple[list[int], ...]
    num_computed_tokens: int
    lora_request: LoRARequest | None
    prompt_embeds: "torch.Tensor | None" = None

    # Only used for v2 model runner.
    prefill_token_ids: list[int] | None = None

    @classmethod
    def from_request(
        cls,
        request: Request,
        block_ids: tuple[list[int], ...],
        prefill_token_ids: list[int] | None = None,
    ) -> "NewRequestData":
        """Build a NewRequestData from a scheduler-side Request."""
        return cls(
            req_id=request.request_id,
            prompt_token_ids=request.prompt_token_ids,
            mm_features=request.mm_features,
            sampling_params=request.sampling_params,
            pooling_params=request.pooling_params,
            block_ids=block_ids,
            num_computed_tokens=request.num_computed_tokens,
            lora_request=request.lora_request,
            prompt_embeds=request.prompt_embeds,
            prefill_token_ids=prefill_token_ids,
        )

    # NOTE(review): pooling_params is not included in either repr —
    # confirm whether that omission is intentional.
    def __repr__(self) -> str:
        # Only the shape of prompt_embeds is shown, never its contents.
        prompt_embeds_shape = (
            self.prompt_embeds.shape if self.prompt_embeds is not None else None
        )
        return (
            f"NewRequestData("
            f"req_id={self.req_id},"
            f"prompt_token_ids={self.prompt_token_ids},"
            f"prefill_token_ids={self.prefill_token_ids},"
            f"mm_features={self.mm_features},"
            f"sampling_params={self.sampling_params},"
            f"block_ids={self.block_ids},"
            f"num_computed_tokens={self.num_computed_tokens},"
            f"lora_request={self.lora_request},"
            f"prompt_embeds_shape={prompt_embeds_shape}"
            ")"
        )

    # Version of __repr__ with the prompt data obfuscated
    def anon_repr(self) -> str:
        prompt_token_ids_len = (
            len(self.prompt_token_ids) if self.prompt_token_ids is not None else None
        )
        prompt_embeds_shape = (
            self.prompt_embeds.shape if self.prompt_embeds is not None else None
        )
        return (
            f"NewRequestData("
            f"req_id={self.req_id},"
            f"prompt_token_ids_len={prompt_token_ids_len},"
            f"mm_features={self.mm_features},"
            f"sampling_params={self.sampling_params},"
            f"block_ids={self.block_ids},"
            f"num_computed_tokens={self.num_computed_tokens},"
            f"lora_request={self.lora_request},"
            f"prompt_embeds_shape={prompt_embeds_shape}"
            ")"
        )
|
||||
|
||||
|
||||
@bc_linter_include
@dataclass
class CachedRequestData:
    """Diff payload for requests whose data is already cached in the
    worker processes (i.e. requests scheduled in a previous step)."""

    req_ids: list[str]
    # For request ids not in resumed_req_ids, new_block_ids will be appended to
    # the request's block IDs. For those in the set, new_block_ids will be used as the
    # request's block IDs instead of appending to the existing block IDs.
    resumed_req_ids: set[str]
    # NOTE(woosuk): new_token_ids is only used for pipeline parallelism.
    # When PP is not used, new_token_ids will be empty.
    new_token_ids: list[list[int]]
    # For requests not scheduled in the last step, propagate the token ids to the
    # connector. Won't contain requests that were scheduled in the prior step.
    all_token_ids: dict[str, list[int]]
    # Aligned with req_ids; an entry may be None (per the type annotation).
    new_block_ids: list[tuple[list[int], ...] | None]
    # Per-request counters, each aligned with req_ids.
    num_computed_tokens: list[int]
    num_output_tokens: list[int]

    @property
    def num_reqs(self) -> int:
        """Number of requests described by this payload."""
        return len(self.req_ids)

    @cached_property
    @deprecated("This will be removed in v0.14, use `resumed_req_ids` instead.")
    def resumed_from_preemption(self) -> list[bool]:
        # Backward-compat view: per-request flags aligned with req_ids.
        return [req_id in self.resumed_req_ids for req_id in self.req_ids]

    @cached_property
    @deprecated("This will be removed in v0.14, use `all_token_ids` instead.")
    def resumed_req_token_ids(self) -> list[list[int] | None]:
        # Backward-compat view: token ids for resumed requests, else None.
        return [
            self.all_token_ids[req_id] if req_id in self.resumed_req_ids else None
            for req_id in self.req_ids
        ]

    @classmethod
    def make_empty(cls) -> "CachedRequestData":
        """Return an instance describing zero requests."""
        return cls(
            req_ids=[],
            resumed_req_ids=set(),
            new_token_ids=[],
            all_token_ids={},
            new_block_ids=[],
            num_computed_tokens=[],
            num_output_tokens=[],
        )
|
||||
|
||||
|
||||
@bc_linter_include
@dataclass
class SchedulerOutput:
    """Per-step scheduling decision passed from the scheduler to the
    model runner."""

    # list of the requests that are scheduled for the first time.
    # We cache the request's data in each worker process, so that we don't
    # need to re-send it every scheduling step.
    scheduled_new_reqs: list[NewRequestData]
    # list of the requests that have been scheduled before.
    # Since the request's data is already cached in the worker processes,
    # we only send the diff to minimize the communication cost.
    scheduled_cached_reqs: CachedRequestData

    # req_id -> num_scheduled_tokens
    # Number of tokens scheduled for each request.
    num_scheduled_tokens: dict[str, int]
    # Total number of tokens scheduled for all requests.
    # Equal to sum(num_scheduled_tokens.values())
    total_num_scheduled_tokens: int
    # req_id -> spec_token_ids
    # If a request does not have any spec decode tokens, it will not be
    # included in the dictionary.
    scheduled_spec_decode_tokens: dict[str, list[int]]
    # req_id -> encoder input indices that need processing.
    # E.g., if a request has [0, 1], it could mean the vision encoder needs
    # to process that the request's 0-th and 1-th images in the current step.
    scheduled_encoder_inputs: dict[str, list[int]]
    # Number of common prefix blocks for all requests in each KV cache group.
    # This can be used for cascade attention.
    num_common_prefix_blocks: list[int]

    # Request IDs that are finished in between the previous and the current
    # steps. This is used to notify the workers about the finished requests
    # so that they can free the cached states for those requests.
    finished_req_ids: set[str]
    # list of mm_hash strings associated with the encoder outputs to be
    # freed from the encoder cache.
    free_encoder_mm_hashes: list[str]

    # Request IDs that are preempted in this step.
    # Only used for v2 model runner.
    preempted_req_ids: set[str] | None = None

    # Whether the scheduled requests have all the output tokens they
    # need to perform grammar bitmask computation.
    pending_structured_output_tokens: bool = False

    # KV Cache Connector metadata.
    kv_connector_metadata: KVConnectorMetadata | None = None

    # EC Cache Connector metadata
    ec_connector_metadata: ECConnectorMetadata | None = None

    @classmethod
    def make_empty(cls) -> "SchedulerOutput":
        """Return an output that schedules no requests."""
        return cls(
            scheduled_new_reqs=[],
            scheduled_cached_reqs=CachedRequestData.make_empty(),
            num_scheduled_tokens={},
            total_num_scheduled_tokens=0,
            scheduled_spec_decode_tokens={},
            scheduled_encoder_inputs={},
            num_common_prefix_blocks=[],
            finished_req_ids=set(),
            free_encoder_mm_hashes=[],
        )
|
||||
|
||||
|
||||
@dataclass
class GrammarOutput:
    """Grammar bitmask data for the structured-output requests of one
    scheduling step."""

    # ids of structured output requests.
    structured_output_request_ids: list[str]
    # Bitmask ordered as structured_output_request_ids.
    grammar_bitmask: "npt.NDArray[np.int32]"
|
||||
217
vllm/v1/core/sched/request_queue.py
Normal file
217
vllm/v1/core/sched/request_queue.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import heapq
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from collections.abc import Iterable, Iterator
|
||||
from enum import Enum
|
||||
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class SchedulingPolicy(Enum):
    """Enum for scheduling policies."""

    # First-come-first-served: requests are served in arrival order.
    FCFS = "fcfs"
    # Priority: smaller `Request.priority` first; ties broken by earlier
    # arrival time (see PriorityRequestQueue).
    PRIORITY = "priority"
|
||||
|
||||
|
||||
class RequestQueue(ABC):
    """Interface for policy-specific request queues.

    Concrete subclasses decide the order in which queued requests are
    popped (e.g. first-come-first-served or priority based).
    """

    @abstractmethod
    def add_request(self, request: Request) -> None:
        """Insert a request into the queue per the queue's policy."""

    @abstractmethod
    def pop_request(self) -> Request:
        """Remove and return the next request per the queue's policy."""

    @abstractmethod
    def peek_request(self) -> Request:
        """Return the next request without removing it."""

    @abstractmethod
    def prepend_request(self, request: Request) -> None:
        """Put a request at the front of the queue."""

    @abstractmethod
    def prepend_requests(self, requests: "RequestQueue") -> None:
        """Prepend all requests from another queue to the front of this
        queue."""

    @abstractmethod
    def remove_request(self, request: Request) -> None:
        """Remove one specific request from the queue."""

    @abstractmethod
    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Remove several specific requests from the queue."""

    @abstractmethod
    def __bool__(self) -> bool:
        """Return True if the queue holds any requests."""

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of queued requests."""

    @abstractmethod
    def __iter__(self) -> Iterator[Request]:
        """Iterate over queued requests in policy order."""

    @abstractmethod
    def __reversed__(self) -> Iterator[Request]:
        """Iterate over queued requests in reverse policy order."""
|
||||
|
||||
|
||||
class FCFSRequestQueue(deque[Request], RequestQueue):
    """First-come-first-served queue backed directly by a deque."""

    def add_request(self, request: Request) -> None:
        """Enqueue a request at the tail (FCFS order)."""
        self.append(request)

    def pop_request(self) -> Request:
        """Dequeue and return the oldest request (from the head)."""
        return self.popleft()

    def peek_request(self) -> Request:
        """Return the head request without removing it."""
        if not self:
            raise IndexError("peek from an empty queue")
        return self[0]

    def prepend_request(self, request: Request) -> None:
        """Push a request onto the head of the queue."""
        self.appendleft(request)

    def prepend_requests(self, requests: RequestQueue) -> None:
        """Prepend all requests from another queue to the front of this
        queue."""
        # extendleft reverses its argument, so feed it the reversed view
        # to preserve the original ordering at the front.
        self.extendleft(reversed(requests))

    def remove_request(self, request: Request) -> None:
        """Remove one specific request from the queue."""
        self.remove(request)

    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Remove every request in `requests` from the queue."""
        to_drop = set(requests)
        kept = [req for req in self if req not in to_drop]
        # deque has no in-place filtering: rebuild the contents.
        self.clear()
        self.extend(kept)

    def __bool__(self) -> bool:
        """Return True if any request is queued."""
        return len(self) > 0

    def __len__(self) -> int:
        """Return the number of queued requests."""
        return super().__len__()

    def __iter__(self) -> Iterator[Request]:
        """Iterate head-to-tail (FCFS order)."""
        return super().__iter__()

    def __reversed__(self) -> Iterator[Request]:
        """Iterate tail-to-head."""
        return super().__reversed__()
|
||||
|
||||
|
||||
class PriorityRequestQueue(RequestQueue):
    """Heap-backed queue ordered by (priority, arrival_time).

    Relies on the ordering defined on the Request class: requests with a
    smaller `priority` value are processed first, and ties are broken by
    the earlier `arrival_time`.
    """

    def __init__(self) -> None:
        self._heap: list[Request] = []

    def add_request(self, request: Request) -> None:
        """Push a request onto the heap."""
        heapq.heappush(self._heap, request)

    def pop_request(self) -> Request:
        """Pop and return the highest-priority request."""
        if not self._heap:
            raise IndexError("pop from empty heap")
        return heapq.heappop(self._heap)

    def peek_request(self) -> Request:
        """Return the highest-priority request without removing it."""
        if not self._heap:
            raise IndexError("peek from empty heap")
        return self._heap[0]

    def prepend_request(self, request: Request) -> None:
        """Insert a request; ordering is always (priority, arrival_time).

        Note: a priority queue has no notion of a "front", so this is
        identical to add_request."""
        self.add_request(request)

    def prepend_requests(self, requests: RequestQueue) -> None:
        """Insert all requests from another queue by priority.

        Note: a priority queue has no notion of a "front", so each request
        is simply added and ordered by (priority, arrival_time)."""
        for req in requests:
            self.add_request(req)

    def remove_request(self, request: Request) -> None:
        """Remove one specific request and restore the heap invariant."""
        self._heap.remove(request)
        heapq.heapify(self._heap)

    def remove_requests(self, requests: Iterable[Request]) -> None:
        """Remove every request in `requests`, re-heapifying once."""
        to_drop = requests if isinstance(requests, set) else set(requests)
        self._heap = [req for req in self._heap if req not in to_drop]
        heapq.heapify(self._heap)

    def __bool__(self) -> bool:
        """Return True if any request is queued."""
        return bool(self._heap)

    def __len__(self) -> int:
        """Return the number of queued requests."""
        return len(self._heap)

    def __iter__(self) -> Iterator[Request]:
        """Yield requests in priority order without mutating the queue."""
        snapshot = self._heap[:]
        while snapshot:
            yield heapq.heappop(snapshot)

    def __reversed__(self) -> Iterator[Request]:
        """Yield requests in reverse priority order."""
        return reversed(list(self))
|
||||
|
||||
|
||||
def create_request_queue(policy: SchedulingPolicy) -> RequestQueue:
    """Instantiate the RequestQueue implementation matching `policy`.

    Raises:
        ValueError: if `policy` is not a known scheduling policy.
    """
    if policy == SchedulingPolicy.FCFS:
        return FCFSRequestQueue()
    if policy == SchedulingPolicy.PRIORITY:
        return PriorityRequestQueue()
    raise ValueError(f"Unknown scheduling policy: {policy}")
|
||||
1826
vllm/v1/core/sched/scheduler.py
Normal file
1826
vllm/v1/core/sched/scheduler.py
Normal file
File diff suppressed because it is too large
Load Diff
64
vllm/v1/core/sched/utils.py
Normal file
64
vllm/v1/core/sched/utils.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
|
||||
def remove_all(lst: list, items_to_remove: set) -> list:
    """Remove all occurrences of the given items from a list.

    This method optimizes for the common case of removing a single item,
    falling back to a list comprehension for multiple items.

    Args:
        lst: The list to remove items from
        items_to_remove: Set of items to remove

    Returns:
        Either the modified original list (for single item removal) or
        a new list (for multiple item removal). Callers should use the
        returned value.

    Note:
        For single item removal, this modifies the original list in-place
        and returns it. For multiple items, it creates and returns a new list.
    """
    if not items_to_remove:
        return lst

    if len(items_to_remove) == 1:
        # Fast path for single item removal (most common case).
        # Bug fix: a single bare list.remove() only deletes the FIRST
        # occurrence, which diverged from both the docstring and the
        # multi-item path below when the list contained duplicates.
        # Loop until the item is fully gone.
        item = next(iter(items_to_remove))
        while True:
            try:
                lst.remove(item)
            except ValueError:
                break
        return lst
    # For multiple items, use list comprehension
    return [item for item in lst if item not in items_to_remove]
|
||||
|
||||
|
||||
def check_stop(request: Request, max_model_len: int) -> bool:
    """Decide whether `request` must stop generating, marking it if so.

    On a stop condition, sets `request.status` (and `request.stop_reason`
    for user-provided stop tokens).

    Returns:
        True if the request is finished.
    """
    assert not request.pooling_params

    sampling_params = request.sampling_params
    assert sampling_params is not None

    # Never stop before the requested minimum number of output tokens.
    if request.num_output_tokens < sampling_params.min_tokens:
        return False

    last_token_id = request.output_token_ids[-1]

    # EOS token, unless the request opted out via ignore_eos.
    if not sampling_params.ignore_eos and last_token_id == request.eos_token_id:
        request.status = RequestStatus.FINISHED_STOPPED
        return True

    # User-provided stop token ids.
    if last_token_id in (sampling_params.stop_token_ids or ()):
        request.status = RequestStatus.FINISHED_STOPPED
        request.stop_reason = last_token_id
        return True

    # Length caps: model context length or per-request max_tokens.
    hit_model_len = request.num_tokens >= max_model_len
    hit_max_tokens = request.num_output_tokens >= request.max_tokens
    if hit_model_len or hit_max_tokens:
        request.status = RequestStatus.FINISHED_LENGTH_CAPPED
        return True
    return False
|
||||
801
vllm/v1/core/single_type_kv_cache_manager.py
Normal file
801
vllm/v1/core/single_type_kv_cache_manager.py
Normal file
@@ -0,0 +1,801 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHashList, KVCacheBlock
|
||||
from vllm.v1.kv_cache_interface import (
|
||||
ChunkedLocalAttentionSpec,
|
||||
CrossAttentionSpec,
|
||||
FullAttentionSpec,
|
||||
KVCacheSpec,
|
||||
MambaSpec,
|
||||
MLAAttentionSpec,
|
||||
SlidingWindowSpec,
|
||||
)
|
||||
from vllm.v1.request import Request
|
||||
|
||||
|
||||
class SingleTypeKVCacheManager(ABC):
    """
    An abstract base class for a manager that handle the kv cache management
    logic of one specific type of attention layer.
    """

    def __init__(
        self,
        kv_cache_spec: KVCacheSpec,
        block_pool: BlockPool,
        kv_cache_group_id: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ) -> None:
        """
        Initializes the SingleTypeKVCacheManager.
        Args:
            kv_cache_spec: The kv_cache_spec for this manager.
            block_pool: The block pool.
            kv_cache_group_id: The id of the kv cache group of this manager.
            dcp_world_size: The world size of decode context parallelism.
            pcp_world_size: The world size of prefill context parallelism.
        """
        self.block_size = kv_cache_spec.block_size
        self.dcp_world_size = dcp_world_size
        self.pcp_world_size = pcp_world_size
        # With decode/prefill context parallelism, one logical block spans
        # dcp_world_size * pcp_world_size times as many tokens (presumably
        # because each rank holds a shard of the block — confirm).
        if dcp_world_size * pcp_world_size > 1:
            self.block_size *= dcp_world_size * pcp_world_size
        self.kv_cache_spec = kv_cache_spec
        self.block_pool = block_pool

        # Mapping from request ID to blocks to track the blocks allocated
        # for each request, so that we can free the blocks when the request
        # is finished.
        self.req_to_blocks: defaultdict[str, list[KVCacheBlock]] = defaultdict(list)

        # {req_id: The number of cached blocks for this given request}
        # This is used to track the number of cached blocks for each request.
        # This is only used to track the RUNNING requests, we do not track the
        # data for preempted ones.
        self.num_cached_block: dict[str, int] = {}

        self.kv_cache_group_id = kv_cache_group_id
        self._null_block = block_pool.null_block

    def get_num_blocks_to_allocate(
        self,
        request_id: str,
        num_tokens: int,
        new_computed_blocks: Sequence[KVCacheBlock],
    ) -> int:
        """
        Get the number of blocks needed to be allocated for the request.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).
            new_computed_blocks: The new computed blocks just hitting the
                prefix caching.

        Returns:
            The number of blocks.
        """

        num_required_blocks = cdiv(num_tokens, self.block_size)
        num_new_blocks = (
            num_required_blocks
            - len(new_computed_blocks)
            - len(self.req_to_blocks[request_id])
        )
        # If a computed block of a request is an eviction candidate (in the
        # free queue and ref_cnt == 0), it will be changed from a free block
        # to a computed block when the request is allocated, so we also count
        # it as needed to be allocated.
        num_evictable_computed_blocks = sum(
            blk.ref_cnt == 0 and not blk.is_null for blk in new_computed_blocks
        )
        return num_new_blocks + num_evictable_computed_blocks

    def save_new_computed_blocks(
        self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
    ) -> None:
        """
        Add the new computed blocks to the request.

        Args:
            request_id: The request ID.
            new_computed_blocks: The new computed blocks just hitting the
                prefix cache.
        """
        if request_id not in self.num_cached_block:
            # A new request.
            req_blocks = self.req_to_blocks[request_id]
            assert len(req_blocks) == 0
            req_blocks.extend(new_computed_blocks)
            self.num_cached_block[request_id] = len(new_computed_blocks)
        else:
            # A running request. Should not have new computed blocks.
            assert len(new_computed_blocks) == 0

    def allocate_new_blocks(
        self, request_id: str, num_tokens: int
    ) -> list[KVCacheBlock]:
        """
        Allocate new blocks for the request to give it at least `num_tokens`
        token slots.

        Args:
            request_id: The request ID.
            num_tokens: The total number of tokens that need a slot (including
                tokens that are already allocated).

        Returns:
            The new allocated blocks.
        """
        req_blocks = self.req_to_blocks[request_id]
        num_required_blocks = cdiv(num_tokens, self.block_size)
        num_new_blocks = num_required_blocks - len(req_blocks)
        if num_new_blocks <= 0:
            # Already enough slots allocated; nothing to do.
            return []
        else:
            new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
            req_blocks.extend(new_blocks)
            return new_blocks

    def cache_blocks(self, request: Request, num_tokens: int) -> None:
        """
        Cache the blocks for the request.

        Args:
            request: The request.
            num_tokens: The total number of tokens that need to be cached
                (including tokens that are already cached).
        """
        num_cached_blocks = self.num_cached_block.get(request.request_id, 0)
        num_full_blocks = num_tokens // self.block_size

        # Only blocks that became full since the last call need caching.
        if num_cached_blocks >= num_full_blocks:
            return

        self.block_pool.cache_full_blocks(
            request=request,
            blocks=self.req_to_blocks[request.request_id],
            num_cached_blocks=num_cached_blocks,
            num_full_blocks=num_full_blocks,
            block_size=self.block_size,
            kv_cache_group_id=self.kv_cache_group_id,
        )

        self.num_cached_block[request.request_id] = num_full_blocks

    def free(self, request_id: str) -> None:
        """
        Free the blocks for the request.

        Args:
            request_id: The request ID.
        """
        # Default to [] in case a request is freed (aborted) before alloc.
        req_blocks = self.req_to_blocks.pop(request_id, [])

        # Free blocks in reverse order so that the tail blocks are
        # freed first.
        ordered_blocks = reversed(req_blocks)

        self.block_pool.free_blocks(ordered_blocks)
        self.num_cached_block.pop(request_id, None)

    @abstractmethod
    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """
        Get the number of common prefix blocks for all requests with allocated
        KV cache.

        Args:
            running_request_id: The request ID.

        Returns:
            The number of common prefix blocks for all requests with allocated
            KV cache.
        """

        raise NotImplementedError

    @classmethod
    @abstractmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: BlockHashList,
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
        alignment_tokens: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        """
        Get the longest cache hit prefix of the blocks that is not longer than
        `max_length`. The prefix should be a common prefix hit for all the
        kv cache groups in `kv_cache_group_ids`. If no cache hit is found,
        return an empty list.
        If eagle is enabled, drop the last matched block to force recompute the
        last block to get the required hidden states for eagle drafting head.
        Need to be customized for each attention type.

        Args:
            block_hashes: The block hashes of the request.
            max_length: The maximum length of the cache hit prefix.
            kv_cache_group_ids: The ids of the kv cache groups.
            block_pool: The block pool.
            kv_cache_spec: The kv cache spec.
            use_eagle: Whether to use eagle.
            alignment_tokens: The returned cache hit length (in tokens) should
                be a multiple of this value (in tokens). By default, it should
                be set to the block_size.
            dcp_world_size: The world size of decode context parallelism.
            pcp_world_size: The world size of prefill context parallelism.

        Returns:
            A list of cached blocks with skipped blocks replaced by null block
            for each kv cache group in `kv_cache_group_ids`.
            Return a list of length `len(kv_cache_group_ids)`, where the i-th
            element is a list of cached blocks for the i-th kv cache group
            in `kv_cache_group_ids`.
            For example, sliding window manager should return a list like
            ([NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]) for block size 4
            and sliding window 8 and len(kv_cache_group_ids) = 1.
        """

        raise NotImplementedError

    def remove_skipped_blocks(self, request_id: str, num_computed_tokens: int) -> None:
        """
        Remove and free the blocks that are no longer needed for attention computation.
        The removed blocks should be replaced by null_block.

        This function depends on `get_num_skipped_tokens`, which need to be implemented
        differently for each attention type.

        Args:
            request_id: The request ID.
            num_computed_tokens: The number of tokens that have been computed.
        """
        # Remove the blocks that will be skipped during attention computation.
        num_skipped_tokens = self.get_num_skipped_tokens(num_computed_tokens)
        if num_skipped_tokens <= 0:
            # This indicates that ALL tokens are inside attention window.
            # Thus we do not need to free any blocks outside attention window.
            # A typical case is full attention that we never free any token
            # before the request is finished.
            return
        num_skipped_blocks = num_skipped_tokens // self.block_size
        blocks = self.req_to_blocks[request_id]
        removed_blocks: list[KVCacheBlock] = []
        # Because the block starts from index 0, the num_skipped_block-th block
        # corresponds to index num_skipped_blocks - 1.
        for i in range(num_skipped_blocks - 1, -1, -1):
            if blocks[i] == self._null_block:
                # If the block is already a null block, the blocks before it
                # should also have been set to null blocks by the previous calls
                # to this function.
                break
            removed_blocks.append(blocks[i])
            blocks[i] = self._null_block
        self.block_pool.free_blocks(removed_blocks)

    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
        """
        Get the number of tokens that will be skipped for attention computation.

        Args:
            num_computed_tokens: The number of tokens that have been computed.

        Returns:
            The number of tokens that will be skipped for attention computation.
        """
        # The default behavior is to not skip any tokens.
        return 0
|
||||
|
||||
|
||||
class FullAttentionManager(SingleTypeKVCacheManager):
    """Manager for full attention (and chunked local attention) KV cache
    groups: blocks are never skipped, so the cache-hit prefix is a plain
    left-to-right scan."""

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: BlockHashList,
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
        alignment_tokens: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        assert isinstance(
            kv_cache_spec, FullAttentionSpec | ChunkedLocalAttentionSpec
        ), (
            "FullAttentionManager can only be used for full attention "
            "and chunked local attention groups"
        )
        computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
            [] for _ in range(len(kv_cache_group_ids))
        )
        block_size = kv_cache_spec.block_size
        # Mirror __init__: with context parallelism a logical block covers
        # dcp_world_size * pcp_world_size times as many tokens.
        if dcp_world_size * pcp_world_size > 1:
            block_size *= dcp_world_size * pcp_world_size
        max_num_blocks = max_length // block_size
        for block_hash in itertools.islice(block_hashes, max_num_blocks):
            # block_hashes is a chain of block hashes. If a block hash is not
            # in the cached_block_hash_to_id, the following block hashes are
            # not computed yet for sure.
            if cached_block := block_pool.get_cached_block(
                block_hash, kv_cache_group_ids
            ):
                for computed, cached in zip(computed_blocks, cached_block):
                    computed.append(cached)
            else:
                break
        if use_eagle and computed_blocks[0]:
            # Need to drop the last matched block if eagle is enabled.
            for computed in computed_blocks:
                computed.pop()
        # Trim the hit until its token length is a multiple of
        # `alignment_tokens`.
        while (
            block_size != alignment_tokens  # Faster for common case.
            and len(computed_blocks[0]) * block_size % alignment_tokens != 0
        ):
            for computed in computed_blocks:
                computed.pop()
        return computed_blocks

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """Count the leading blocks shared by all tracked requests.

        A block is common iff its ref_cnt equals the number of tracked
        requests; the scan stops at the first non-common block.
        """
        blocks = self.req_to_blocks[running_request_id]
        num_common_blocks = 0
        for block in blocks:
            if block.ref_cnt == len(self.req_to_blocks):
                num_common_blocks += 1
            else:
                break
        return num_common_blocks
|
||||
|
||||
|
||||
class SlidingWindowManager(SingleTypeKVCacheManager):
    """Manager for sliding-window attention KV cache groups: blocks that
    fall behind the window are freed and replaced by the null block."""

    def __init__(
        self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool, **kwargs
    ) -> None:
        super().__init__(kv_cache_spec, block_pool, **kwargs)
        self.sliding_window = kv_cache_spec.sliding_window
        # Cached for fast comparison when replacing out-of-window blocks.
        self._null_block = block_pool.null_block

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: BlockHashList,
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
        alignment_tokens: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        assert isinstance(kv_cache_spec, SlidingWindowSpec), (
            "SlidingWindowManager can only be used for sliding window groups"
        )
        assert dcp_world_size == 1, "DCP not support sliding window attn now."
        assert pcp_world_size == 1, "PCP not support sliding window attn now."

        # The number of contiguous blocks needed for prefix cache hit.
        # -1 since the input token itself is also included in the window
        sliding_window_contiguous_blocks = cdiv(
            kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size
        )
        if use_eagle:
            # Need to drop the last matched block if eagle is enabled. For
            # sliding window layer, we achieve this by increasing the number of
            # contiguous blocks needed for prefix cache hit by one and dropping
            # the last matched block.
            sliding_window_contiguous_blocks += 1

        # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
        # optimize the time complexity from O(max_num_blocks) to
        # O(max_num_blocks / sliding_window_contiguous_blocks +
        # sliding_window_contiguous_blocks),
        # which is good for low cache hit rate scenarios.
        max_num_blocks = max_length // kv_cache_spec.block_size
        # Start from all-null and fill in cached blocks as they are found.
        computed_blocks = tuple(
            [block_pool.null_block] * max_num_blocks
            for _ in range(len(kv_cache_group_ids))
        )
        block_size = kv_cache_spec.block_size
        num_contiguous_blocks = 0
        match_found = False
        # Search from right to left and early stop when a match is found.
        for i in range(max_num_blocks - 1, -1, -1):
            if cached_block := block_pool.get_cached_block(
                block_hashes[i], kv_cache_group_ids
            ):
                # Skip prefix matching check if the block is not aligned with
                # `alignment_tokens`.
                if (
                    num_contiguous_blocks == 0
                    and block_size != alignment_tokens  # Faster for common case.
                    and (i + 1) * block_size % alignment_tokens != 0
                ):
                    continue
                # Add the cached block to the computed blocks.
                for computed, cached in zip(computed_blocks, cached_block):
                    computed[i] = cached
                num_contiguous_blocks += 1
                if num_contiguous_blocks >= sliding_window_contiguous_blocks:
                    # Trim the trailing blocks.
                    # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
                    # when sliding_window_contiguous_blocks=2.
                    for computed in computed_blocks:
                        del computed[i + num_contiguous_blocks :]
                    match_found = True
                    break
            else:
                # A miss breaks contiguity; restart the contiguous run.
                num_contiguous_blocks = 0
        if not match_found:
            # The first `num_contiguous_blocks` is a cache hit even if
            # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
            for computed in computed_blocks:
                del computed[num_contiguous_blocks:]
            # Trim until the hit length is a multiple of `alignment_tokens`.
            while (
                block_size != alignment_tokens  # Faster for common case.
                and len(computed_blocks[0]) * block_size % alignment_tokens != 0
            ):
                for computed in computed_blocks:
                    computed.pop()
        if use_eagle and computed_blocks[0]:
            assert kv_cache_spec.block_size == alignment_tokens, (
                "aligned_length is not compatible with eagle now"
            )
            for computed in computed_blocks:
                computed.pop()
        return computed_blocks

    def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
        """
        Get the number of tokens that will be skipped for attention computation.

        For sliding window, this corresponds to the tokens that are prior to
        the current sliding window.

        Example:
        sliding_window=4, num_computed_tokens=7

        Tokens: [ 0  1  2  3  4  5  6  7 ]
                | ---- computed -----|
                                      ^ next token to be computed
                             |-----------| sliding window for next token
        |--skipped---|

        The current window contains tokens 4~7. Tokens 0~3 will be skipped for
        attention computation since they are outside the sliding window.
        Thus, get_num_skipped_tokens(7) == 4.

        Args:
            num_computed_tokens: The number of tokens that have been computed.

        Returns:
            The number of tokens that will be skipped for attention computation.
        """
        return num_computed_tokens - self.sliding_window + 1

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """
        NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
        So it's not correct to count ref_cnt like FullAttentionManager. Return
        0 here for correctness. Need to support cascade attention + sliding
        window in the future.
        """
        return 0
|
||||
|
||||
|
||||
class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
def __init__(
|
||||
self, kv_cache_spec: ChunkedLocalAttentionSpec, block_pool: BlockPool, **kwargs
|
||||
) -> None:
|
||||
super().__init__(kv_cache_spec, block_pool, **kwargs)
|
||||
self.attention_chunk_size = kv_cache_spec.attention_chunk_size
|
||||
self._null_block = block_pool.null_block
|
||||
|
||||
@classmethod
|
||||
def find_longest_cache_hit(
|
||||
cls,
|
||||
block_hashes: BlockHashList,
|
||||
max_length: int,
|
||||
kv_cache_group_ids: list[int],
|
||||
block_pool: BlockPool,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
alignment_tokens: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
"""
|
||||
For chunked local attention, we need to find the longest cache hit
|
||||
prefix of the blocks that is not longer than `max_length`. The prefix
|
||||
should be a common prefix hit for all the kv cache groups in
|
||||
`kv_cache_group_ids`. If no cache hit is found, return an empty list.
|
||||
note we mark as computed if the whole block is outside of the local
|
||||
window, and set the block as null. Examples:
|
||||
|
||||
1. Attention chunk size of 8, block size of 4, max length of 15
|
||||
for next token at 15th (zero-indexed), 8th - 14th tokens are in
|
||||
the window(needs lookup), 0th - 7th are not in the window,
|
||||
so they are already marked as computed. We check the complete
|
||||
block3 (8th - 11th tokens), Assume block 3 is hit, we will return
|
||||
[null, null, block 3], otherwise, we return [null, null]
|
||||
|
||||
2. Attention chunk size of 8, block size of 4, max length of 16
|
||||
for next token at 16th (zero-indexed), 0th - 15th tokens are not
|
||||
in the window, so they are already marked as computed.
|
||||
we return 4 blocks[null, null, null, null]
|
||||
|
||||
Args:
|
||||
block_hashes: The block hashes of the request.
|
||||
max_length: The maximum length of the cache hit prefix.
|
||||
kv_cache_group_ids: The ids of the kv cache groups.
|
||||
block_pool: The block pool.
|
||||
kv_cache_spec: The kv cache spec.
|
||||
use_eagle: Whether to use eagle.
|
||||
dcp_world_size: The world size of decode context parallelism.
|
||||
pcp_world_size: The world size of prefill context parallelism.
|
||||
alignment_tokens: The returned cache hit length (in tokens) should
|
||||
be a multiple of this value (in tokens).
|
||||
|
||||
Returns:
|
||||
A list of cached blocks
|
||||
"""
|
||||
assert isinstance(kv_cache_spec, ChunkedLocalAttentionSpec), (
|
||||
"ChunkedLocalAttentionManager can only be used for "
|
||||
+ "chunked local attention groups"
|
||||
)
|
||||
assert use_eagle is False, (
|
||||
"Hybrid KV cache is not supported for " + "eagle + chunked local attention."
|
||||
)
|
||||
assert dcp_world_size == 1, "DCP not support chunked local attn now."
|
||||
assert pcp_world_size == 1, "PCP not support chunked local attn now."
|
||||
assert kv_cache_spec.block_size == alignment_tokens, (
|
||||
"KV cache groups with different block sizes are not compatible with "
|
||||
"chunked local attention now"
|
||||
)
|
||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||
if max_length > 0:
|
||||
local_attention_start_idx = (
|
||||
max_length
|
||||
// kv_cache_spec.attention_chunk_size
|
||||
* kv_cache_spec.attention_chunk_size
|
||||
)
|
||||
else:
|
||||
local_attention_start_idx = 0
|
||||
# we marked blocks out of window as computed
|
||||
# with null blocks, and blocks inside window based on cache lookup
|
||||
# result [null] [null] ... [null] [hit block 1 (1st block contain
|
||||
# last window)] [hit block 2] ... [hit block x]
|
||||
local_attention_start_block_idx = (
|
||||
local_attention_start_idx // kv_cache_spec.block_size
|
||||
)
|
||||
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
|
||||
[block_pool.null_block] * local_attention_start_block_idx
|
||||
for _ in range(len(kv_cache_group_ids))
|
||||
)
|
||||
for i in range(local_attention_start_block_idx, max_num_blocks):
|
||||
block_hash = block_hashes[i]
|
||||
if cached_block := block_pool.get_cached_block(
|
||||
block_hash, kv_cache_group_ids
|
||||
):
|
||||
for computed, cached in zip(computed_blocks, cached_block):
|
||||
computed.append(cached)
|
||||
else:
|
||||
break
|
||||
return computed_blocks
|
||||
|
||||
def get_num_skipped_tokens(self, num_computed_tokens: int) -> int:
    """Return how many leading tokens are skipped for attention computation.

    With chunked local attention, the next token only attends within its own
    chunk, so every token to the left of that chunk is skipped. That count is
    simply ``num_computed_tokens`` rounded down to a multiple of
    ``self.attention_chunk_size``.

    Examples (chunk size = 8):
        get_num_skipped_tokens(13) == 8   # tokens 0-7 lie left of the
                                          # chunk holding token 13
        get_num_skipped_tokens(8) == 8    # next token starts a fresh chunk,
                                          # so the whole first chunk is skipped
        get_num_skipped_tokens(7) == 0    # next token is still in the first
                                          # chunk; nothing can be skipped

    Args:
        num_computed_tokens: The number of tokens that have been computed.

    Returns:
        The number of tokens that will be skipped for attention computation.
    """
    chunk_size = self.attention_chunk_size
    return (num_computed_tokens // chunk_size) * chunk_size
def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
    """Always zero: cascade attention is not supported by chunked
    local attention, so no common-prefix blocks are ever reported."""
    return 0
class MambaManager(SingleTypeKVCacheManager):
    # Manager for Mamba (linear-attention / state-space) KV cache groups.
    # Unlike full attention, only the *last* cached block matters for a hit,
    # so the cache-hit search runs right-to-left and pads with null blocks.

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: BlockHashList,
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
        alignment_tokens: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        """Find the rightmost cached block within ``max_length`` tokens.

        Scans block hashes from right to left and stops at the first
        (i.e. longest) hit whose end offset is a multiple of
        ``alignment_tokens``. The returned lists are padded with
        ``block_pool.null_block`` entries so that their length encodes the
        hit position, because callers derive the hit length from
        ``len(hit_blocks)`` (see the inline note below).

        Args:
            block_hashes: The block hashes of the request.
            max_length: The maximum length of the cache hit prefix (tokens).
            kv_cache_group_ids: The ids of the kv cache groups.
            block_pool: The block pool.
            kv_cache_spec: The kv cache spec; must be a MambaSpec.
            use_eagle: Whether to use eagle (unused for mamba).
            alignment_tokens: The returned hit length must be a multiple of
                this value (in tokens).
            dcp_world_size: Decode context parallel world size; must be 1.
            pcp_world_size: Prefill context parallel world size; must be 1.

        Returns:
            One list of blocks per kv cache group: ``i`` null blocks followed
            by the single cached block at index ``i``, or empty lists when no
            aligned hit exists.
        """
        assert isinstance(kv_cache_spec, MambaSpec), (
            "MambaManager can only be used for mamba groups"
        )
        assert dcp_world_size == 1, "DCP not support mamba now."
        assert pcp_world_size == 1, "PCP not support mamba now."
        # One (initially empty) hit list per kv cache group.
        computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
            [] for _ in range(len(kv_cache_group_ids))
        )

        block_size = kv_cache_spec.block_size
        max_num_blocks = max_length // block_size
        # Search from right to left and early stop when a match is found.
        for i in range(max_num_blocks - 1, -1, -1):
            if cached_block := block_pool.get_cached_block(
                block_hashes[i], kv_cache_group_ids
            ):
                # When enable Mamba prefix caching, `block_size` will be aligned
                # across full attention layers and Mamba layers to ensure the
                # prefix hit length aligned at block
                if (
                    block_size != alignment_tokens  # Faster for common case.
                    and (i + 1) * block_size % alignment_tokens != 0
                ):
                    # Hit ends at a misaligned token offset; keep scanning left.
                    continue
                for computed, cached in zip(computed_blocks, cached_block):
                    # the hit length logic later assumes:
                    # hit_length = len(hit_blocks_other_attn[0])
                    # * self.other_block_size
                    # so we insert dummy blocks at the beginning:
                    computed.extend([block_pool.null_block] * i)
                    computed.append(cached)
                break  # we just need the last match - early stopping

        return computed_blocks

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """
        cascade attention is not supported by mamba
        """
        return 0

    def get_num_blocks_to_allocate(
        self,
        request_id: str,
        num_tokens: int,
        new_computed_blocks: Sequence[KVCacheBlock],
    ) -> int:
        """Return the block count to allocate, including speculative blocks.

        Args:
            request_id: The request id.
            num_tokens: The number of tokens to allocate blocks for.
            new_computed_blocks: Newly matched computed blocks for the request.

        Returns:
            The number of new blocks to allocate.
        """
        # Allocate extra `num_speculative_blocks` blocks for
        # speculative decoding (MTP/EAGLE) with linear attention.
        assert isinstance(self.kv_cache_spec, MambaSpec)
        if self.kv_cache_spec.num_speculative_blocks > 0:
            num_tokens += (
                self.kv_cache_spec.block_size
                * self.kv_cache_spec.num_speculative_blocks
            )
        return super().get_num_blocks_to_allocate(
            request_id, num_tokens, new_computed_blocks
        )

    def allocate_new_blocks(
        self, request_id: str, num_tokens: int
    ) -> list[KVCacheBlock]:
        """Allocate new blocks for the request, including speculative blocks.

        Args:
            request_id: The request id.
            num_tokens: The number of tokens to allocate blocks for.

        Returns:
            The newly allocated blocks.
        """
        # Allocate extra `num_speculative_blocks` blocks for
        # speculative decoding (MTP/EAGLE) with linear attention.
        assert isinstance(self.kv_cache_spec, MambaSpec)
        if self.kv_cache_spec.num_speculative_blocks > 0:
            num_tokens += (
                self.kv_cache_spec.block_size
                * self.kv_cache_spec.num_speculative_blocks
            )
        return super().allocate_new_blocks(request_id, num_tokens)
class CrossAttentionManager(SingleTypeKVCacheManager):
    """Manager for cross-attention KV cache in encoder-decoder models.

    Cross-attention blocks hold per-request encoder states, so nothing is
    shared between requests and prefix caching is effectively disabled for
    this manager: lookups raise, saves must be empty, and no common-prefix
    blocks are ever reported.
    """

    @classmethod
    def find_longest_cache_hit(
        cls,
        block_hashes: BlockHashList,
        max_length: int,
        kv_cache_group_ids: list[int],
        block_pool: BlockPool,
        kv_cache_spec: KVCacheSpec,
        use_eagle: bool,
        alignment_tokens: int,
        dcp_world_size: int = 1,
        pcp_world_size: int = 1,
    ) -> tuple[list[KVCacheBlock], ...]:
        """Always raises: cross-attention does not benefit from prefix caching.

        Encoder states are unique per request (different audio/image inputs),
        computed once per request rather than incrementally, and share no
        reusable prefix between multimodal inputs — so a cache lookup must
        never be attempted for this manager.
        """
        assert isinstance(kv_cache_spec, CrossAttentionSpec), (
            "CrossAttentionManager can only be used for cross-attention groups"
        )
        raise NotImplementedError("CrossAttentionManager does not support caching")

    def save_new_computed_blocks(
        self, request_id: str, new_computed_blocks: Sequence[KVCacheBlock]
    ) -> None:
        """Accept only an empty list: no blocks are shared between requests."""
        assert len(new_computed_blocks) == 0

    def cache_blocks(self, request: Request, num_tokens: int) -> None:
        """Always raises: blocks are never cached for cross-attention."""
        raise ValueError("Should not be called as prefix caching is disabled.")

    def get_num_common_prefix_blocks(self, running_request_id: str) -> int:
        """Always zero: request-specific encoder states share no prefix."""
        return 0
# Dispatch table from each KV cache spec class to the manager class that
# handles allocation and cache-hit lookup for that attention type.
# NOTE: get_manager_for_kv_cache_spec looks up `type(kv_cache_spec)` exactly,
# so new spec subclasses must be registered here explicitly.
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
    FullAttentionSpec: FullAttentionManager,
    # MLA attention uses the same management strategy as full attention.
    MLAAttentionSpec: FullAttentionManager,
    SlidingWindowSpec: SlidingWindowManager,
    ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager,
    MambaSpec: MambaManager,
    CrossAttentionSpec: CrossAttentionManager,
}
def get_manager_for_kv_cache_spec(
    kv_cache_spec: KVCacheSpec, **kwargs
) -> SingleTypeKVCacheManager:
    """Build the single-type manager registered for this spec's exact type.

    Args:
        kv_cache_spec: The KV cache spec; its type selects the manager class
            via ``spec_manager_map`` (exact type match, not isinstance).
        **kwargs: Forwarded to the manager class constructor.

    Returns:
        A new manager instance for the given spec.

    Raises:
        KeyError: If the spec's type is not registered in ``spec_manager_map``.
    """
    return spec_manager_map[type(kv_cache_spec)](kv_cache_spec, **kwargs)
Reference in New Issue
Block a user