[gpt-oss] Add gpt-oss bf16 support

This commit is contained in:
2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 17ea2ec6aa
1232 changed files with 777 additions and 36 deletions

vllm/v1/core/__init__.py Normal file

vllm/v1/core/block_pool.py Normal file

@@ -0,0 +1,349 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from collections.abc import Iterable
from typing import Callable, Optional
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
BlockStored, KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHash, BlockHashWithGroupId,
FreeKVCacheBlockQueue, KVCacheBlock,
generate_block_hash_extra_keys,
hash_block_tokens)
from vllm.v1.request import Request
logger = init_logger(__name__)
class BlockPool:
"""BlockPool that manages KVCacheBlocks.
It provides methods to allocate, free and cache the kv cache blocks. The
free_block_queue stores the free blocks in eviction order to enable
allocation, free, and cache eviction. The cached_block_hash_to_block
maps between block hash and cached block to support finding cached blocks
by their block hash.
Args:
num_gpu_blocks: The number of blocks in the pool.
enable_caching: Whether to enable prefix caching.
enable_kv_cache_events: Whether to enable kv cache events.
"""
def __init__(
self,
num_gpu_blocks: int,
enable_caching: bool,
enable_kv_cache_events: bool = False,
):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching
# All kv-cache blocks.
self.blocks: list[KVCacheBlock] = [
KVCacheBlock(idx) for idx in range(num_gpu_blocks)
]
# Free block queue that constructs and manipulates a doubly linked
# list of free blocks (including eviction candidates when caching is
# enabled).
self.free_block_queue = FreeKVCacheBlockQueue(self.blocks)
# {block_hash: {block ID: block}}. A cached block is
# a full block with a block hash that can be used for prefix caching.
# The cached block may be used by running requests or in the
# free_block_queue that could potentially be evicted.
# NOTE: We currently don't de-duplicate the blocks in the cache,
# meaning that if a block becomes full and is cached, we don't check
# if there is already an identical block in the cache. This is because
# we want to make sure the allocated block IDs won't change so that
# block tables are append-only.
self.cached_block_hash_to_block: dict[BlockHashWithGroupId, dict[
int, KVCacheBlock]] = defaultdict(dict)
# To represent a placeholder block with block_id=0.
# The ref_cnt of null_block is not maintained, needs special care to
# avoid freeing it.
self.null_block = self.free_block_queue.popleft()
self.null_block.is_null = True
self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue: list[KVCacheEvent] = []
def get_cached_block(
self, block_hash: BlockHash,
kv_cache_group_ids: list[int]) -> Optional[list[KVCacheBlock]]:
"""Get the cached block by the block hash for each group in
`kv_cache_group_ids`, or None if cache miss for any group.
If there are duplicated blocks, we return the first block in the cache.
Args:
block_hash: The hash value of the block.
kv_cache_group_ids: The ids of the KV cache groups.
Returns:
The cached blocks if they exist, or None.
"""
cached_blocks = []
for group_id in kv_cache_group_ids:
cached_blocks_one_group = self.cached_block_hash_to_block.get(
BlockHashWithGroupId(block_hash, group_id))
if not cached_blocks_one_group:
return None
first_block = next(iter(cached_blocks_one_group.values()))
cached_blocks.append(first_block)
return cached_blocks
def cache_full_blocks(
self,
request: Request,
blocks: list[KVCacheBlock],
block_hashes: list[BlockHash],
num_cached_blocks: int,
num_full_blocks: int,
block_size: int,
kv_cache_group_id: int,
hash_fn: Callable,
) -> None:
"""Cache a list of full blocks for prefix caching.
This function takes a list of blocks whose block hash metadata will be
updated and cached. Given a request, it computes the
block hashes for the blocks starting from `num_cached_blocks` to
`num_full_blocks`, updating the metadata for each block
and caching them in the `cached_block_hash_to_block`.
Args:
request: The request to cache the blocks.
blocks: All blocks in the request.
block_hashes: Block hashes of the blocks in the request. Note that
this list may be shorter than the blocks list. In this case the
missing block hashes will be computed in this function.
num_cached_blocks: The number of blocks that are already cached.
num_full_blocks: The number of blocks that are full and should
be cached after this function.
block_size: Number of tokens in each block.
kv_cache_group_id: The id of the KV cache group.
hash_fn: The hash function to use for block hashes.
"""
if num_cached_blocks == num_full_blocks:
return
new_full_blocks = blocks[num_cached_blocks:num_full_blocks]
assert len(block_hashes) >= num_cached_blocks
new_block_hashes = block_hashes[num_cached_blocks:]
# Update the new blocks with the block hashes through the chain.
if num_cached_blocks == 0:
prev_block_hash_value = None
else:
prev_block = blocks[num_cached_blocks - 1]
assert prev_block.block_hash is not None
prev_block_hash_value = prev_block.block_hash.get_hash_value()
parent_block_hash = prev_block_hash_value
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
else None)
for i, blk in enumerate(new_full_blocks):
assert blk.block_hash is None
if i < len(new_block_hashes):
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption), or by other
# single_type_managers with the same block_size.
# In this case we simply reuse the block hash.
block_hash = new_block_hashes[i]
else:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
blk_idx = num_cached_blocks + i
start_token_idx = blk_idx * block_size
end_token_idx = (blk_idx + 1) * block_size
block_tokens = request.all_token_ids[
start_token_idx:end_token_idx]
assert len(block_tokens) == block_size, (
f"Expected {block_size} tokens, got "
f"{len(block_tokens)} at {blk_idx}th block for request "
f"{request.request_id}({request})")
# Generate extra keys for multi-modal inputs. Note that since
# we reach this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys, _ = generate_block_hash_extra_keys(
request, start_token_idx, end_token_idx, -1)
# Compute the hash of the current block.
block_hash = hash_block_tokens(hash_fn, prev_block_hash_value,
block_tokens, extra_keys)
block_hashes.append(block_hash)
# Update and add the full block to the cache.
block_hash_with_group_id = BlockHashWithGroupId(
block_hash, kv_cache_group_id)
blk.block_hash = block_hash_with_group_id
self.cached_block_hash_to_block[block_hash_with_group_id][
blk.block_id] = blk
if new_hashes is not None:
new_hashes.append(block_hash.hash_value)
prev_block_hash_value = block_hash.hash_value
if self.enable_kv_cache_events:
self.kv_event_queue.append(
BlockStored(
block_hashes=new_hashes,
parent_block_hash=parent_block_hash,
token_ids=request.
all_token_ids[num_cached_blocks *
block_size:num_full_blocks * block_size],
block_size=block_size,
lora_id=request.lora_request.id
if request.lora_request else None,
))
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
"""Get new blocks from the free block pool.
Note that we do not check the block cache in this function.
Args:
num_blocks: The number of blocks to allocate.
Returns:
A list of new blocks.
"""
if num_blocks > self.get_num_free_blocks():
raise ValueError(
f"Cannot get {num_blocks} free blocks from the pool")
ret: list[KVCacheBlock] = []
idx = 0
while idx < num_blocks:
# First allocate blocks.
curr_block = self.free_block_queue.popleft()
assert curr_block.ref_cnt == 0
# If the block is cached, evict it.
if self.enable_caching:
self._maybe_evict_cached_block(curr_block)
curr_block.incr_ref()
ret.append(curr_block)
idx += 1
return ret
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
"""
If a block is cached in `cached_block_hash_to_block`, we reset its hash
metadata and evict it from the cache.
Args:
block: The block to evict.
Returns:
True if the block is evicted, False otherwise.
"""
block_hash = block.block_hash
if block_hash and block_hash in self.cached_block_hash_to_block:
block.reset_hash()
del self.cached_block_hash_to_block[block_hash][block.block_id]
if len(self.cached_block_hash_to_block[block_hash]) == 0:
del self.cached_block_hash_to_block[block_hash]
if self.enable_kv_cache_events:
# FIXME (Chen): Not sure whether we should return `hash_value`
# or `(hash_value, group_id)` here. But it's fine now because
# we disable hybrid kv cache manager when kv cache event is
# enabled, so there is only one group.
self.kv_event_queue.append(
BlockRemoved(block_hashes=[block_hash.get_hash_value()]))
return True
return False
def touch(self, blocks: tuple[list[KVCacheBlock], ...]) -> None:
"""Touch a block increases its reference count by 1, and may remove
the block from the free queue. This is used when a block is hit by
another request with the same prefix.
Args:
blocks: A list of blocks to touch.
"""
for blocks_per_group in blocks:
for block in blocks_per_group:
# ref_cnt=0 means this block is in the free list (i.e. eviction
# candidate), so remove it.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.remove(block)
block.incr_ref()
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
"""Free a list of blocks. The blocks should be ordered by their
eviction priority, where the first block will be evicted first.
Args:
ordered_blocks: A list of blocks to free ordered by their eviction
priority.
"""
for block in ordered_blocks:
block.decr_ref()
# null_block should not be added to the free list.
if block.ref_cnt == 0 and not block.is_null:
self.free_block_queue.append(block)
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalidate prefix caching after the weights are updated,
or used for resetting prefix caching status for benchmarking.
Returns:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
num_used_blocks = self.num_gpu_blocks - self.get_num_free_blocks()
if num_used_blocks != 1: # The null block is always marked as used
logger.warning(
"Failed to reset prefix cache because some "
"blocks (%d) are not freed yet", num_used_blocks - 1)
return False
# Remove all hashes so that no new blocks will hit.
self.cached_block_hash_to_block = defaultdict(dict)
# Remove all hashes from all blocks.
for block in self.blocks:
block.reset_hash()
logger.info("Successfully reset prefix cache")
if self.enable_kv_cache_events:
self.kv_event_queue.append(AllBlocksCleared())
return True
def get_num_free_blocks(self) -> int:
"""Get the number of free blocks in the pool.
Returns:
The number of free blocks.
"""
return self.free_block_queue.num_free_blocks
def get_usage(self) -> float:
"""Get the KV cache usage.
Returns:
The KV cache usage (between 0.0 and 1.0).
"""
return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks)
def take_events(self) -> list[KVCacheEvent]:
"""Atomically takes all events and clears the queue.
Returns:
A list of KV cache events.
"""
if not self.enable_kv_cache_events:
return []
events = self.kv_event_queue
self.kv_event_queue = []
return events
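
For orientation, a minimal usage sketch of the pool above (illustrative only; it assumes this module is importable as vllm.v1.core.block_pool and that KVCacheBlock behaves as defined in kv_cache_utils.py below):

from vllm.v1.core.block_pool import BlockPool

# 8 blocks total; block 0 is reserved as the null block during __init__.
pool = BlockPool(num_gpu_blocks=8, enable_caching=True)
blocks = pool.get_new_blocks(3)           # ref_cnt of each block becomes 1
print([b.block_id for b in blocks])       # [1, 2, 3]
print(pool.get_num_free_blocks())         # 4
print(pool.get_usage())                   # 0.5 (the null block counts as used)
# Free in reverse order so that tail blocks become eviction candidates first.
pool.free_blocks(reversed(blocks))
print(pool.get_num_free_blocks())         # 7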


@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
from vllm.logger import init_logger
from vllm.multimodal import MultiModalRegistry
from vllm.v1.request import Request
if TYPE_CHECKING:
from vllm.config import ModelConfig, SchedulerConfig
logger = init_logger(__name__)
class EncoderCacheManager:
def __init__(self, cache_size: int):
self.cache_size = cache_size
self.num_free_slots = cache_size
# req_id -> cached input ids
self.cached: dict[str, set[int]] = {}
# List of (req_id, input_id) pairs that have been freed.
self.freed: list[tuple[str, int]] = []
def has_cache(self, request: Request, input_id: int) -> bool:
req_id = request.request_id
return req_id in self.cached and input_id in self.cached[req_id]
def can_allocate(self, request: Request, input_id: int) -> bool:
num_tokens = request.get_num_encoder_tokens(input_id)
return num_tokens <= self.num_free_slots
def allocate(self, request: Request, input_id: int) -> None:
req_id = request.request_id
if req_id not in self.cached:
self.cached[req_id] = set()
self.cached[req_id].add(input_id)
self.num_free_slots -= request.get_num_encoder_tokens(input_id)
def get_cached_input_ids(self, request: Request) -> set[int]:
return self.cached.get(request.request_id, set())
def free_encoder_input(self, request: Request, input_id: int) -> None:
"""Free a single encoder input id for the request."""
req_id = request.request_id
if req_id not in self.cached:
return
self.cached[req_id].discard(input_id)
if len(self.cached[req_id]) == 0:
del self.cached[req_id]
self.num_free_slots += request.get_num_encoder_tokens(input_id)
self.freed.append((req_id, input_id))
def free(self, request: Request) -> None:
"""Free all cached input ids for the request."""
input_ids = self.get_cached_input_ids(request).copy()
for input_id in input_ids:
self.free_encoder_input(request, input_id)
def get_freed_ids(self) -> list[tuple[str, int]]:
freed = self.freed
self.freed = []
return freed
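# Illustrative usage sketch: the manager only reads request.request_id and
# request.get_num_encoder_tokens(input_id), so a minimal hypothetical
# stand-in object is enough to exercise the bookkeeping.
class _FakeRequest:
    request_id = "req-0"

    def get_num_encoder_tokens(self, input_id: int) -> int:
        return 1024

def _encoder_cache_manager_sketch() -> None:
    manager = EncoderCacheManager(cache_size=4096)
    req = _FakeRequest()
    if manager.can_allocate(req, input_id=0):
        manager.allocate(req, input_id=0)  # consumes 1024 slots
    assert manager.has_cache(req, 0)
    manager.free(req)                      # returns the slots, records (req_id, 0)
    assert manager.get_freed_ids() == [("req-0", 0)]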
def compute_encoder_budget(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations.
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns:
- Compute budget for encoder execution, in units of tokens in the
input sequence.
- Space budget for the encoder cache size, in units of tokens in the
input sequence.
"""
if not model_config.is_multimodal_model:
return 0, 0
# TODO: handle encoder-decoder models once we support them.
(
encoder_compute_budget,
encoder_cache_size,
) = _compute_encoder_budget_multimodal(
model_config,
scheduler_config,
mm_registry,
)
return encoder_compute_budget, encoder_cache_size
def _compute_encoder_budget_multimodal(
model_config: "ModelConfig",
scheduler_config: "SchedulerConfig",
mm_registry: MultiModalRegistry,
) -> tuple[int, int]:
"""Compute the encoder cache budget based on the model and scheduler
configurations for a multimodal model.
Args:
model_config: Model configuration.
scheduler_config: Scheduler configuration.
mm_registry: Provides information about the token cost.
Returns:
- Compute budget for encoder execution, in units of tokens in the
input sequence.
- Space budget for the encoder cache size, in units of tokens in the
input sequence.
"""
max_tokens_by_modality_dict = mm_registry \
.get_max_tokens_per_item_by_nonzero_modality(model_config)
if not max_tokens_by_modality_dict:
logger.warning(
"All non-text modalities supported by the model have been "
"explicitly disabled via limit_mm_per_prompt. Encoder cache will "
"not be initialized.")
return 0, 0
_, max_tokens_per_mm_item = max(max_tokens_by_modality_dict.items(),
key=lambda item: item[1])
if (scheduler_config.disable_chunked_mm_input and max_tokens_per_mm_item
> scheduler_config.max_num_batched_tokens):
raise ValueError(
"Chunked MM input disabled but max_tokens_per_mm_item "
f"({max_tokens_per_mm_item}) is larger than max_num_batched_tokens"
f" ({scheduler_config.max_num_batched_tokens}). Please increase "
"max_num_batched_tokens.")
encoder_compute_budget = max(scheduler_config.max_num_encoder_input_tokens,
max_tokens_per_mm_item)
encoder_cache_size = max(scheduler_config.encoder_cache_size,
max_tokens_per_mm_item)
return encoder_compute_budget, encoder_cache_size


@@ -0,0 +1,363 @@
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import Callable, Optional
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.core.single_type_kv_cache_manager import (
FullAttentionManager, get_manager_for_kv_cache_spec)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheConfig
from vllm.v1.request import Request
class KVCacheCoordinator(ABC):
"""
Coordinate the KV cache of different KV cache groups.
"""
def __init__(
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
use_eagle: bool,
enable_caching: bool,
caching_hash_fn: Callable,
enable_kv_cache_events: bool,
):
self.kv_cache_config = kv_cache_config
self.max_model_len = max_model_len
self.block_pool = BlockPool(kv_cache_config.num_blocks, enable_caching,
enable_kv_cache_events)
# Needs special handling for find_longest_cache_hit if eagle is enabled
self.use_eagle = use_eagle
self.single_type_managers = tuple(
get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_group.kv_cache_spec,
block_pool=self.block_pool,
kv_cache_group_id=i,
caching_hash_fn=caching_hash_fn,
) for i, kv_cache_group in enumerate(
self.kv_cache_config.kv_cache_groups))
def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> int:
"""
Get the number of blocks needed to be allocated for the request.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
Returns:
The number of blocks.
"""
num_blocks_to_allocate = 0
for i, manager in enumerate(self.single_type_managers):
num_blocks_to_allocate += manager.get_num_blocks_to_allocate(
request_id, num_tokens, new_computed_blocks[i])
return num_blocks_to_allocate
def save_new_computed_blocks(
self, request_id: str,
new_computed_blocks: tuple[list[KVCacheBlock], ...]) -> None:
"""
Add the new computed blocks to the request.
Args:
request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the
prefix cache.
"""
for i, manager in enumerate(self.single_type_managers):
manager.save_new_computed_blocks(request_id,
new_computed_blocks[i])
def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> tuple[list[KVCacheBlock], ...]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
Returns:
The new allocated blocks.
"""
return tuple(
manager.allocate_new_blocks(request_id, num_tokens)
for manager in self.single_type_managers)
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_computed_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
block_hashes: The block hashes of the request.
num_computed_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
"""
for manager in self.single_type_managers:
manager.cache_blocks(request, block_hashes, num_computed_tokens)
def free(self, request_id: str) -> None:
"""
Free the blocks for the request.
Args:
request_id: The request ID.
"""
for manager in self.single_type_managers:
manager.free(request_id)
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> list[int]:
"""
Get the number of common prefix blocks for a request.
Args:
request_id: The request ID.
num_running_requests: The total number of requests in the RUNNING state.
Returns:
The number of common prefix blocks.
"""
num_blocks_per_group = [
manager.get_num_common_prefix_blocks(request_id,
num_running_requests)
for manager in self.single_type_managers
]
return num_blocks_per_group
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
"""
Remove the blocks that are no longer needed from `blocks` and replace
the removed blocks with null_block.
Args:
request_id: The request ID.
num_computed_tokens: The number of tokens that have been computed.
"""
for manager in self.single_type_managers:
manager.remove_skipped_blocks(request_id, num_computed_tokens)
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
"""
Get the blocks for the request.
"""
return tuple(
manager.req_to_blocks.get(request_id) or []
for manager in self.single_type_managers)
@abstractmethod
def find_longest_cache_hit(
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
pass
class UnitaryKVCacheCoordinator(KVCacheCoordinator):
"""
KV cache coordinator for models with only one KV cache group. This is the
case for models with only one KV cache type, e.g., all attention layers use
full attention or all attention layers use sliding window attention.
"""
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
caching_hash_fn: Callable, enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[
0].kv_cache_spec
self.block_size = self.kv_cache_spec.block_size
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
"UnitaryKVCacheCoordinator assumes only one kv cache group")
def find_longest_cache_hit(
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
hit_blocks = self.single_type_managers[0].find_longest_cache_hit(
block_hashes=block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=[0],
block_pool=self.block_pool,
kv_cache_spec=self.kv_cache_spec,
use_eagle=self.use_eagle,
)
return hit_blocks, len(hit_blocks[0]) * self.block_size
class HybridKVCacheCoordinator(KVCacheCoordinator):
"""
KV cache coordinator for hybrid models with multiple KV cache types, and
thus multiple kv cache groups.
To simplify `find_longest_cache_hit`, it only supports the combination of
two types of KV cache groups, and one of them must be full attention.
This may be extended to more general cases in the future.
"""
def __init__(self, kv_cache_config: KVCacheConfig, max_model_len: int,
use_eagle: bool, enable_caching: bool,
caching_hash_fn: Callable, enable_kv_cache_events: bool):
super().__init__(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)
self.verify_and_split_kv_cache_groups()
def verify_and_split_kv_cache_groups(self) -> None:
"""
Verifies that the model has exactly two types of KV cache groups, and
one of them is full attention. Then, split the kv cache groups into full
attention groups and other groups.
"""
full_attention_type_id: Optional[str] = None
other_type_id: Optional[str] = None
self.full_attention_group_ids: list[int] = []
self.other_group_ids: list[int] = []
for i, g in enumerate(self.kv_cache_config.kv_cache_groups):
if isinstance(g.kv_cache_spec, FullAttentionSpec):
if full_attention_type_id is None:
full_attention_type_id = g.kv_cache_spec.type_id
else:
assert full_attention_type_id == g.kv_cache_spec.type_id, (
"HybridKVCacheCoordinator assumes exactly one type of "
"full attention groups now.")
self.full_attention_group_ids.append(i)
else:
if other_type_id is None:
other_type_id = g.kv_cache_spec.type_id
else:
assert other_type_id == g.kv_cache_spec.type_id, (
"HybridKVCacheCoordinator assumes "
"exactly one other type of groups now.")
self.other_group_ids.append(i)
assert full_attention_type_id is not None, (
"HybridKVCacheCoordinator assumes exactly one type of full "
"attention groups now.")
assert other_type_id is not None, (
"HybridKVCacheCoordinator assumes exactly one type of other "
"groups now.")
self.full_attention_manager_cls = FullAttentionManager
self.other_attention_cls = self.single_type_managers[
self.other_group_ids[0]].__class__
self.full_attention_spec = self.kv_cache_config.kv_cache_groups[
self.full_attention_group_ids[0]].kv_cache_spec
self.other_spec = self.kv_cache_config.kv_cache_groups[
self.other_group_ids[0]].kv_cache_spec
self.full_attention_block_size = self.full_attention_spec.block_size
self.other_block_size = self.other_spec.block_size
assert self.other_block_size % self.full_attention_block_size == 0, (
"KVCacheCoordinator assumes the block_size of full attention "
"layers is divisible by other layers now.")
if max(self.full_attention_group_ids) < min(self.other_group_ids):
self.full_attn_first = True
elif max(self.other_group_ids) < min(self.full_attention_group_ids):
self.full_attn_first = False
else:
raise ValueError(
"HybridKVCacheCoordinator assumes the full "
"attention group ids and other attention group ids "
"do not interleave, either full attention group ids "
"are before other attention group ids or vice versa."
"This is for simplifying merging hit_blocks_full_attn and "
"hit_blocks_other_attn to hit_blocks.")
def find_longest_cache_hit(
self,
block_hashes: list[BlockHash],
max_cache_hit_length: int,
) -> tuple[tuple[list[KVCacheBlock], ...], int]:
"""
Find the longest cache hit for the request.
Args:
block_hashes: The block hashes of the request.
max_cache_hit_length: The maximum length of the cache hit.
Returns:
A tuple containing:
- A list of the cache hit blocks for each single type manager.
- The number of tokens of the longest cache hit.
"""
# First, find the longest cache hit for full attention.
hit_blocks_full_attn = (
self.full_attention_manager_cls.find_longest_cache_hit(
block_hashes=block_hashes,
max_length=max_cache_hit_length,
kv_cache_group_ids=self.full_attention_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.full_attention_spec,
use_eagle=self.use_eagle,
))
hit_length = len(
hit_blocks_full_attn[0]) * self.full_attention_block_size
# Next, find the cache hit for the other attention WITHIN
# the cache hit of full attention.
hit_blocks_other_attn = (
self.other_attention_cls.find_longest_cache_hit(
block_hashes=block_hashes,
max_length=hit_length,
kv_cache_group_ids=self.other_group_ids,
block_pool=self.block_pool,
kv_cache_spec=self.other_spec,
use_eagle=self.use_eagle,
))
hit_length = len(hit_blocks_other_attn[0]) * self.other_block_size
# NOTE: the prefix cache hit length must be a multiple of block_size as
# we don't support partial block cache hit yet. The cache hit length
# of other attention is ensured to be a multiple of the block size of
full attention layers in the current implementation, because hit_length is
# a multiple of other attention's block size, and other attention's
# block size is a multiple of full attention's block size (verified in
# `verify_and_split_kv_cache_groups`).
assert hit_length % self.full_attention_block_size == 0
# Truncate the full attention cache hit to the length of the
# cache hit of the other attention.
for group_hit_blocks in hit_blocks_full_attn:
del group_hit_blocks[hit_length // self.full_attention_block_size:]
# Merge the hit blocks of full attention and other attention.
if self.full_attn_first:
hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
else:
hit_blocks = hit_blocks_other_attn + hit_blocks_full_attn
return hit_blocks, hit_length
def get_kv_cache_coordinator(
kv_cache_config: KVCacheConfig, max_model_len: int, use_eagle: bool,
enable_caching: bool, caching_hash_fn: Callable,
enable_kv_cache_events: bool) -> KVCacheCoordinator:
if len(kv_cache_config.kv_cache_groups) == 1:
return UnitaryKVCacheCoordinator(kv_cache_config, max_model_len,
use_eagle, enable_caching,
caching_hash_fn,
enable_kv_cache_events)
return HybridKVCacheCoordinator(kv_cache_config, max_model_len, use_eagle,
enable_caching, caching_hash_fn,
enable_kv_cache_events)
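
A simplified standalone illustration of the merge step in HybridKVCacheCoordinator.find_longest_cache_hit, using plain lists of block ids instead of KVCacheBlock objects (all values are made up): the full-attention hit is truncated to the other-attention hit length, then the per-group hit lists are concatenated in group order.

full_attention_block_size = 16
other_block_size = 32
# Say full attention hit 6 blocks (96 tokens) and, within that window, the
# other (e.g. sliding-window) layers hit 2 blocks (64 tokens).
hit_blocks_full_attn = [[0, 1, 2, 3, 4, 5]]   # one full-attention group
hit_blocks_other_attn = [[10, 11]]            # one other-attention group
hit_length = len(hit_blocks_other_attn[0]) * other_block_size   # 64
assert hit_length % full_attention_block_size == 0
# Truncate the full-attention hit to the common prefix length.
for group_hit_blocks in hit_blocks_full_attn:
    del group_hit_blocks[hit_length // full_attention_block_size:]
# Merge while keeping group order (here the full-attention group comes first).
hit_blocks = hit_blocks_full_attn + hit_blocks_other_attn
print(hit_blocks, hit_length)                 # [[0, 1, 2, 3], [10, 11]] 64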


@@ -0,0 +1,392 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
from vllm.utils import sha256
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
hash_request_tokens)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request, RequestStatus
logger = init_logger(__name__)
@dataclass
class KVCacheBlocks:
"""
The allocation result of KVCacheManager, which works as the interface between
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
structure from the Scheduler.
"""
blocks: tuple[list[KVCacheBlock], ...]
"""
blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
We don't use block of tokens as the outer dimension because it assumes all
kv_cache_groups have the same number of blocks, which is true for now but
will be broken if we want to give different block_size to different
kv_cache_groups in the future.
"""
def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
"""Adds two KVCacheBlocks instances."""
return KVCacheBlocks(
tuple(blk1 + blk2
for blk1, blk2 in zip(self.blocks, other.blocks)))
def get_block_ids(self) -> tuple[list[int], ...]:
"""
Converts the KVCacheBlocks instance to block_ids.
Returns:
tuple[list[int], ...]: A tuple of lists where
* the outer tuple corresponds to KV cache groups
* each inner list contains the block_ids of the blocks in that group
"""
return tuple([blk.block_id for blk in group] for group in self.blocks)
def get_unhashed_block_ids(self) -> list[int]:
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
assert len(self.blocks) == 1, "Only one group is supported"
return [
block.block_id for block in self.blocks[0]
if block.block_hash is None
]
def new_empty(self) -> "KVCacheBlocks":
"""Creates a new KVCacheBlocks instance with no blocks."""
return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))))
class KVCacheManager:
def __init__(
self,
kv_cache_config: KVCacheConfig,
max_model_len: int,
enable_caching: bool = True,
caching_hash_algo: str = "builtin",
use_eagle: bool = False,
log_stats: bool = False,
enable_kv_cache_events: bool = False,
) -> None:
self.max_model_len = max_model_len
self.enable_caching = enable_caching
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
self.use_eagle = use_eagle
self.log_stats = log_stats
# FIXME: make prefix cache stats conditional on log_stats
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
assert len(
set(g.kv_cache_spec.block_size
for g in kv_cache_config.kv_cache_groups)
) == 1, "Only one block size is supported for now"
self.block_size = kv_cache_config.kv_cache_groups[
0].kv_cache_spec.block_size
self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len,
use_eagle=self.use_eagle,
enable_caching=enable_caching,
caching_hash_fn=self.caching_hash_fn,
enable_kv_cache_events=enable_kv_cache_events,
)
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
self.block_pool = self.coordinator.block_pool
self.kv_cache_config = kv_cache_config
# Mapping from request ID to kv block hashes.
# This is to avoid recomputing the block hashes for each call of
# `get_computed_blocks` or `allocate_slots`.
self.req_to_block_hashes: defaultdict[
str, list[BlockHash]] = defaultdict(list)
@property
def usage(self) -> float:
"""Get the KV cache usage.
Returns:
The KV cache usage (between 0.0 and 1.0).
"""
return self.block_pool.get_usage()
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
"""Get (and reset) the prefix cache stats.
Returns:
The current prefix caching stats, or None if logging is disabled.
"""
if not self.log_stats:
return None
stats = self.prefix_cache_stats
self.prefix_cache_stats = PrefixCacheStats()
return stats
def get_computed_blocks(self,
request: Request) -> tuple[KVCacheBlocks, int]:
"""Get the computed (cached) blocks for the request.
Note that the computed blocks must be full.
Args:
request: The request to get the computed blocks.
Returns:
A tuple containing:
- A list of blocks that are computed for the request.
- The number of computed tokens.
"""
# Skip prefix caching if it is disabled, or if the request
# requires prompt logprobs.
if (not self.enable_caching
or request.sampling_params.prompt_logprobs is not None):
return self.create_empty_block_list(), 0
# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.
block_hashes = self.req_to_block_hashes[request.request_id]
if not block_hashes:
block_hashes = hash_request_tokens(self.caching_hash_fn,
self.block_size, request)
self.req_to_block_hashes[request.request_id] = block_hashes
if self.log_stats:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.requests += 1
# NOTE: When all tokens hit the cache, we must recompute the last token
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
# This can trigger recomputation of an entire block, rather than just
# the single last token, because allocate_slots() requires
# num_computed_tokens to be block-size aligned. Removing this limitation
# could slightly improve performance in the future.
max_cache_hit_length = request.num_tokens - 1
computed_blocks, num_new_computed_tokens = (
self.coordinator.find_longest_cache_hit(block_hashes,
max_cache_hit_length))
if self.log_stats:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.queries += request.num_tokens
self.prefix_cache_stats.hits += num_new_computed_tokens
return KVCacheBlocks(computed_blocks), num_new_computed_tokens
def allocate_slots(
self,
request: Request,
num_new_tokens: int,
num_new_computed_tokens: int = 0,
new_computed_blocks: Optional[KVCacheBlocks] = None,
num_draft_tokens: int = 0,
num_lookahead_tokens: int = 0,
delay_cache_blocks: bool = False,
) -> Optional[KVCacheBlocks]:
"""Add slots for a request with new tokens to append.
Args:
request: The request to allocate slots.
num_new_tokens: The number of tokens to allocate, including external
tokens. Note that this does not include tokens that have
already been computed locally (i.e. new_computed_blocks).
num_new_computed_tokens: The number of new computed tokens just
hitting the prefix caching, excluding external tokens.
new_computed_blocks: The cached blocks for the above new computed
tokens.
num_lookahead_tokens: The number of speculative tokens to allocate.
This is used by spec decode proposers with kv-cache such
as eagle.
delay_cache_blocks: Whether to skip caching the blocks. This is
used by P/D when allocating blocks used in a KV transfer
which will complete in a future step.
Blocks layout:
```
-----------------------------------------------------------------------
| < computed > | < new computed > | < new > | < pre-allocated > |
-----------------------------------------------------------------------
| < required > |
--------------------------------------------------
| < full > |
------------------------------------------------
| <new full> |
--------------
```
The following *_blocks are illustrated in this layout.
Returns:
A list of new allocated blocks.
"""
if num_new_tokens == 0:
raise ValueError("num_new_tokens must be greater than 0")
if new_computed_blocks is not None:
new_computed_block_list = new_computed_blocks.blocks
else:
new_computed_block_list = tuple(
[] for _ in range(len(self.kv_cache_config.kv_cache_groups)))
# Free the blocks that are skipped during the attention computation
# (e.g., tokens outside the sliding window).
# We can do this even if we cannot schedule this request due to
# insufficient free blocks.
# Should call this function before allocating new blocks to reduce
# the number of evicted blocks.
self.coordinator.remove_skipped_blocks(request.request_id,
request.num_computed_tokens)
# The total number of computed tokens is the locally computed tokens
# plus the new prefix cache hits.
num_computed_tokens = (request.num_computed_tokens +
num_new_computed_tokens)
num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
request_id=request.request_id,
num_tokens=num_tokens_need_slot,
new_computed_blocks=new_computed_block_list,
)
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
# Cannot allocate new blocks
return None
# Touch the computed blocks to make sure they won't be evicted.
if self.enable_caching:
self.block_pool.touch(new_computed_block_list)
else:
assert not any(new_computed_block_list), (
"Computed blocks should be empty when "
"prefix caching is disabled")
# Append the new computed blocks to the request blocks until now to
# avoid the case where the new blocks cannot be allocated.
self.coordinator.save_new_computed_blocks(request.request_id,
new_computed_block_list)
new_blocks = self.coordinator.allocate_new_blocks(
request.request_id, num_tokens_need_slot)
# P/D: delay caching blocks if we have to recv from
# remote. Update state for locally cached blocks.
if not self.enable_caching or delay_cache_blocks:
return KVCacheBlocks(new_blocks)
# Speculated tokens might be rejected in the future, so we do
# not cache any speculated tokens. We only cache blocks with
# generated (accepted) tokens.
self.coordinator.cache_blocks(
request, self.req_to_block_hashes[request.request_id],
num_computed_tokens + num_new_tokens - num_draft_tokens)
return KVCacheBlocks(new_blocks)
def free(self, request: Request) -> None:
"""Free the blocks allocated for the request.
We free the blocks in reverse order so that the tail blocks are evicted
first when caching is enabled.
Args:
request: The request to free the blocks.
"""
self.coordinator.free(request.request_id)
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalidate prefix caching after the weights are updated,
or used for resetting prefix caching status for benchmarking.
Returns:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
if not self.block_pool.reset_prefix_cache():
return False
if self.log_stats:
assert self.prefix_cache_stats is not None
self.prefix_cache_stats.reset = True
return True
def get_num_common_prefix_blocks(
self,
request: Request,
num_running_requests: int,
) -> list[int]:
"""Calculate the number of common prefix blocks shared by all requests
in the RUNNING state for each kv cache group.
The function determines this by selecting any request and iterating
through its blocks. A block is considered a common prefix block if its
`ref_cnt` equals the total number of requests in the RUNNING state.
NOTE(woosuk): The number of requests in the RUNNING state is **greater
than or equal to** the number of requests scheduled in the current step.
This is because the RUNNING state only indicates that:
1. The request has not yet finished, and
2. The request holds its blocks unfreed.
While all scheduled requests must be in the RUNNING state, the inverse
is not necessarily true. There may be RUNNING requests that are not
scheduled in the current step.
This can result in an edge case where the number of common prefix blocks
is 0, even though all scheduled requests share a common prefix. This
occurs because there may be unscheduled RUNNING requests that do not
share the common prefix. Currently, this case cannot be easily detected,
so the function returns 0 in such cases.
Args:
request: Any request in the RUNNING state, used to identify the
common prefix blocks.
num_running_requests: The total number of requests in the RUNNING
state. This can be different from the number of scheduled
requests in the current step.
Returns:
list[int]: The number of common prefix blocks for each kv cache
group.
"""
assert request.status == RequestStatus.RUNNING
return self.coordinator.get_num_common_prefix_blocks(
request.request_id, num_running_requests)
def free_block_hashes(self, request: Request) -> None:
"""Discard the block hashes for the request.
NOTE: Unlike `free`, this method should be called only when the request
is finished, not when it is preempted.
"""
self.req_to_block_hashes.pop(request.request_id, None)
def take_events(self) -> list[KVCacheEvent]:
"""Take the KV cache events from the block pool.
Returns:
A list of KV cache events.
"""
return self.block_pool.take_events()
def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
"""Get the block ids of a request."""
return KVCacheBlocks(
self.coordinator.get_blocks(request_id)).get_block_ids()
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request."""
block_hashes = self.req_to_block_hashes[request.request_id]
self.coordinator.cache_blocks(request, block_hashes,
num_computed_tokens)
def create_empty_block_list(self) -> KVCacheBlocks:
"""Creates a new KVCacheBlocks instance with no blocks."""
return KVCacheBlocks(tuple([]
for _ in range(self.num_kv_cache_groups)))
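
A small sketch of the KVCacheBlocks container above (illustrative only; it assumes KVCacheBlock is importable from vllm.v1.core.kv_cache_utils as in this commit):

from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import KVCacheBlock

# Two KV cache groups: group 0 holds blocks 1 and 2, group 1 holds block 7.
blocks = KVCacheBlocks(blocks=([KVCacheBlock(1), KVCacheBlock(2)],
                               [KVCacheBlock(7)]))
more = KVCacheBlocks(blocks=([KVCacheBlock(3)], []))
print(blocks.get_block_ids())             # ([1, 2], [7])
print((blocks + more).get_block_ids())    # ([1, 2, 3], [7])
print(blocks.new_empty().get_block_ids()) # ([], [])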


@@ -0,0 +1,996 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""KV-Cache Utilities."""
import os
from collections import defaultdict, deque
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
from typing import Any, Callable, NamedTuple, Optional
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, cdiv, sha256
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec, KVCacheSpec,
KVCacheTensor, SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request
logger = init_logger(__name__)
class BlockHash(NamedTuple):
"""Hash value of a block (int), the token IDs in the block, and extra keys.
We keep a tuple of token IDs and extra keys to reduce the likelihood of
hash collisions when the hash value is the same. By using SHA256 however,
hash collisions are practically impossible.
"""
# Hash value of the block in an integer.
hash_value: int
# Token IDs in the block.
token_ids: tuple[int, ...]
# Extra keys for the block.
extra_keys: Optional[Any] = None
class BlockHashWithGroupId(NamedTuple):
# The hash value for the contents (e.g., token_ids) of a block without group
# ID. The value is the same for blocks representing the same tokens but for
# different groups.
block_hash: BlockHash
# The KV cache group ID.
group_id: int
def get_hash_value(self) -> int:
return self.block_hash.hash_value
# The hash seed for the first block of the prefix block sequence.
#
# Even if the hash function is the builtin hash(), we use sha256 to generate
# the initial hash to simplify the code. This is not performance critical
# as it is done once per process.
#
# We use a random value to avoid hash collisions, or the PYTHONHASHSEED
# environment variable if it is set, so that processes can share the seed
# if needed.
# This aligns with the behavior of Python's hash() function, which also uses
# a random seed if PYTHONHASHSEED is not set.
NONE_HASH = int.from_bytes(os.urandom(32), byteorder="big") if os.getenv(
"PYTHONHASHSEED") is None else sha256(os.getenv("PYTHONHASHSEED"))
class PrefixCachingMetrics:
"""Metrics for prefix caching with a hit rate of the max recent N requests.
Args:
max_recent_requests: The maximum number of recent requests to aggregate.
Defaults to 1000.
"""
def __init__(self, max_recent_requests: int = 1000):
self.max_recent_requests = max_recent_requests
# The current aggregated values.
self.aggregated_requests = 0
self.aggregated_query_total = 0
self.aggregated_query_hit = 0
# A deque of (requests, queries, hits) for the most recent requests.
self.query_queue: deque[tuple[int, int, int]] = deque()
def observe(self, stats: PrefixCacheStats):
"""Observe the prefix caching for a set of requests.
This function is called with information gathered when new requests
are being scheduled and are looking for computed blocks.
When there are more than `max_recent_requests` requests, the oldest set of
requests is removed from the metrics.
Args:
stats: The prefix cache stats.
"""
# reset_prefix_cache was invoked before the current update.
# Reset the metrics before aggregating the current stats.
if stats.reset:
self.reset()
# Update the metrics.
self.query_queue.append((stats.requests, stats.queries, stats.hits))
self.aggregated_requests += stats.requests
self.aggregated_query_total += stats.queries
self.aggregated_query_hit += stats.hits
# Remove the oldest stats if the number of requests exceeds the limit.
if self.aggregated_requests > self.max_recent_requests:
old_requests, old_queries, old_hits = self.query_queue.popleft()
self.aggregated_requests -= old_requests
self.aggregated_query_total -= old_queries
self.aggregated_query_hit -= old_hits
def reset(self):
"""Reset the metrics."""
self.aggregated_requests = 0
self.aggregated_query_total = 0
self.aggregated_query_hit = 0
self.query_queue.clear()
@property
def hit_rate(self) -> float:
"""Calculate the hit rate for the past N requests."""
if self.aggregated_query_total == 0:
return 0.0
return self.aggregated_query_hit / self.aggregated_query_total
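# A small worked example (sketch; assumes PrefixCacheStats() starts zeroed
# with reset=False, as the usage in KVCacheManager suggests): with a window of
# max_recent_requests=2, observing a third request evicts the oldest request's
# stats, so the hit rate covers only the most recent two requests.
def _prefix_caching_metrics_example() -> None:
    metrics = PrefixCachingMetrics(max_recent_requests=2)
    for requests, queries, hits in [(1, 16, 0), (1, 16, 8), (1, 16, 16)]:
        stats = PrefixCacheStats()
        stats.requests, stats.queries, stats.hits = requests, queries, hits
        metrics.observe(stats)
    # Only the last two requests remain: (16 + 16) queries, (8 + 16) hits.
    assert metrics.hit_rate == 24 / 32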
@dataclass
class KVCacheBlock:
"""KV-cache block metadata."""
# Block ID, ranging from 0 to num_gpu_blocks - 1.
block_id: int
# Reference count.
ref_cnt: int = 0
# The hash of the block composed of (block hash, tuple of token IDs).
# It is only available when the block is full.
_block_hash: Optional[BlockHashWithGroupId] = None
# Used to construct a doubly linked list for free blocks.
# These two attributes should only be manipulated by FreeKVCacheBlockQueue.
prev_free_block: Optional["KVCacheBlock"] = None
next_free_block: Optional["KVCacheBlock"] = None
# Whether the block is a null block that should never be cached.
is_null: bool = False
def incr_ref(self):
self.ref_cnt += 1
def decr_ref(self):
self.ref_cnt -= 1
@property
def block_hash(self) -> Optional[BlockHashWithGroupId]:
return self._block_hash
@block_hash.setter
def block_hash(self, block_hash: BlockHashWithGroupId):
assert self.block_hash is None, (
"The block already has a hash. This should not happen.")
self._block_hash = block_hash
def reset_hash(self):
"""Reset the block hash when the block is evicted."""
self._block_hash = None
def __repr__(self) -> str:
# Use block_id instead of KVCacheBlock object to avoid calling __repr__
# on KVCacheBlock object recursively.
prev_block_id = (self.prev_free_block.block_id
if self.prev_free_block else None)
next_block_id = (self.next_free_block.block_id
if self.next_free_block else None)
return (f"KVCacheBlock(block_id={self.block_id}, "
f"ref_cnt={self.ref_cnt}, "
f"_block_hash={self._block_hash}, "
f"prev_free_block={prev_block_id}, "
f"next_free_block={next_block_id})")
class FreeKVCacheBlockQueue:
"""This class organizes a list of KVCacheBlock objects to a doubly linked
list of free blocks. We implement this class instead of using Python
builtin deque to support removing a block in the middle of the queue
in O(1) time. To close the performance gap to the builtin deque which is
implemented in C++, this class does not allocate any Python objects when
manipulating the linked list. Instead, this class manipulates the
prev_free_block and next_free_block attributes of the given blocks.
The queue is ordered by block ID in the beginning. When a block is allocated
and then freed, it will be appended back with the eviction order:
1. The least recently used block is at the front (LRU).
2. If two blocks have the same last accessed time (allocated by the
same sequence), the one with more hash tokens (the tail of a block
chain) is at the front.
Note that we maintain this order by reversing the block order when freeing
the blocks of a request. This operation is done outside of this class.
Args:
blocks: A list of KVCacheBlock objects.
"""
def __init__(self, blocks: list[KVCacheBlock]) -> None:
self.num_free_blocks = len(blocks)
# Initialize the doubly linked list of free blocks.
self.free_list_head: Optional[KVCacheBlock] = blocks[0]
self.free_list_tail: Optional[KVCacheBlock] = blocks[-1]
for i in range(self.num_free_blocks):
if i > 0:
blocks[i].prev_free_block = blocks[i - 1]
if i < self.num_free_blocks - 1:
blocks[i].next_free_block = blocks[i + 1]
def popleft(self) -> KVCacheBlock:
"""Pop the first free block and reduce num_free_blocks by 1.
Returns:
The first free block.
"""
if not self.free_list_head:
raise ValueError("No free blocks available")
block = self.free_list_head
self.remove(block)
return block
def remove(self, block: KVCacheBlock) -> None:
"""Remove a block in the free list and reduce num_free_blocks by 1.
Args:
block: The block to remove.
"""
if block.prev_free_block is not None:
# Link the previous block to the next block.
block.prev_free_block.next_free_block = block.next_free_block
if block.next_free_block is not None:
# Link the next block to the previous block.
block.next_free_block.prev_free_block = block.prev_free_block
if block == self.free_list_head:
# Update the head if the block is the head.
self.free_list_head = block.next_free_block
if block == self.free_list_tail:
# Update the tail if the block is the tail.
self.free_list_tail = block.prev_free_block
# Remove the block from the linked list.
block.prev_free_block = block.next_free_block = None
self.num_free_blocks -= 1
def append(self, block: KVCacheBlock) -> None:
"""Put a block back into the free list and increase
num_free_blocks by 1.
Args:
block: The block to append.
"""
if self.free_list_tail is not None:
# Link the last block to the new block.
self.free_list_tail.next_free_block = block
block.prev_free_block = self.free_list_tail
self.free_list_tail = block
else:
# The free list is empty.
assert self.free_list_head is None
self.free_list_head = self.free_list_tail = block
block.next_free_block = None
self.num_free_blocks += 1
def get_all_free_blocks(self) -> list[KVCacheBlock]:
"""Get all free blocks in the free list. Mainly used for testing.
Returns:
A list of free blocks.
"""
ret = []
curr_block = self.free_list_head
while curr_block is not None:
ret.append(curr_block)
curr_block = curr_block.next_free_block
return ret
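# Illustrative sketch: unlike a plain deque, the free queue supports O(1)
# removal of a block from the middle of the eviction order.
def _free_block_queue_example() -> None:
    blocks = [KVCacheBlock(block_id=i) for i in range(4)]
    queue = FreeKVCacheBlockQueue(blocks)
    queue.remove(blocks[2])               # e.g. the block was re-allocated
    assert [b.block_id for b in queue.get_all_free_blocks()] == [0, 1, 3]
    queue.append(blocks[2])               # freed again; appended at the tail
    assert queue.popleft().block_id == 0  # the LRU block comes off the front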
def need_extra_keys(request: Request) -> bool:
"""Check whether the blocks allocated to this request need extra hash keys.
Args:
request (Request): The request.
Returns:
bool: Whether blocks allocated to this request need extra hash keys.
"""
# Multimodal requests need to include the MM hash.
# LoRA requests need to include the LoRA ID.
# Requests with a provided cache salt need to include the salt.
return bool(request.mm_positions) or (request.lora_request
is not None) or (request.cache_salt
is not None)
def _gen_mm_extra_hash_keys(request: Request, start_token_idx: int,
end_token_idx: int,
start_mm_idx: int) -> tuple[list[Any], int]:
"""Generate extra keys related to MultiModal request for block hash
computation. For multi-modal inputs, the extra keys are
(mm_hash, start_offset) that indicate a mm input contained in the
block and its starting offset in the block tokens.
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block.
Returns:
A tuple of extra keys and the next multi-modal index.
"""
extra_keys: list[Any] = []
mm_positions, mm_hashes = request.mm_positions, request.mm_hashes
if not mm_positions:
return extra_keys, start_mm_idx
if mm_positions and len(mm_positions) != len(mm_hashes):
raise ValueError(
"The number of multi-modal positions and hashes must match. This "
"is likely because you do not enable MM preprocessor hashing. "
"Please set disable_mm_preprocessor_cache=False.")
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
if mm_positions[-1].offset + mm_positions[-1].length < start_token_idx:
return extra_keys, start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
if start_mm_idx < 0:
assert -start_mm_idx <= len(mm_positions)
start_mm_idx = len(mm_positions) + start_mm_idx
curr_mm_idx = start_mm_idx
while mm_positions and curr_mm_idx < len(mm_positions):
assert mm_hashes[curr_mm_idx] is not None
offset = mm_positions[curr_mm_idx].offset
length = mm_positions[curr_mm_idx].length
if end_token_idx > offset:
if start_token_idx > offset + length:
# This block has passed the current mm input.
curr_mm_idx += 1
continue
# The block contains the current mm input.
extra_keys.append(mm_hashes[curr_mm_idx])
if end_token_idx >= offset + length:
# If this block contains the end of the current mm input,
# move to the next mm input as this block may also contain
# the next mm input.
curr_mm_idx += 1
else:
# Otherwise this block is done with mm inputs.
break
else:
# This block has not reached the current mm input.
break
return extra_keys, curr_mm_idx
def _gen_lora_extra_hash_keys(request: Request) -> list[int]:
"""Generate extra keys related to LoRA for block hash computation.
Args:
request: The request object.
Returns:
Return LoRA id of the request if it is a LoRA request. Return empty
list otherwise.
"""
if not request.lora_request:
return []
return [request.lora_request.lora_int_id]
def generate_block_hash_extra_keys(
request: Request, start_token_idx: int, end_token_idx: int,
start_mm_idx: int) -> tuple[Optional[tuple[Any, ...]], int]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block.
Returns:
A tuple of extra keys and the next multi-modal index.
"""
mm_extra_keys: list[Any]
mm_extra_keys, new_start_mm_idx = _gen_mm_extra_hash_keys(
request, start_token_idx, end_token_idx, start_mm_idx)
lora_extra_keys: list[int] = _gen_lora_extra_hash_keys(request)
cache_salt_keys: list[str] = [request.cache_salt] if (
start_token_idx == 0 and request.cache_salt) else []
extra_keys: list[Any] = lora_extra_keys + mm_extra_keys + cache_salt_keys
if not extra_keys:
return None, new_start_mm_idx
return tuple(extra_keys), new_start_mm_idx
def hash_block_tokens(
hash_function: Callable,
parent_block_hash: Optional[int],
curr_block_token_ids: Sequence[int],
extra_keys: Optional[tuple[Any, ...]] = None) -> BlockHash:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching.
Args:
parent_block_hash: The hash of the parent block. None
if this is the first block.
curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full.
extra_keys: Extra keys for the block.
Returns:
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
if not parent_block_hash:
parent_block_hash = NONE_HASH
curr_block_token_ids_tuple = tuple(curr_block_token_ids)
return BlockHash(
hash_function(
(parent_block_hash, curr_block_token_ids_tuple, extra_keys)),
curr_block_token_ids_tuple, extra_keys)
def hash_request_tokens(hash_function: Any, block_size: int,
request: Request) -> list[BlockHash]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
Args:
block_size: The size of each block.
request: The request object.
Returns:
The list of computed hash values.
"""
token_ids = request.all_token_ids
req_need_extra_keys = need_extra_keys(request)
req_extra_keys = None
curr_mm_idx = 0
ret = []
parent_block_hash_value = None
for start in range(0, len(token_ids), block_size):
end = start + block_size
block_token_ids = token_ids[start:end]
# Do not hash the block if it is not full.
if len(block_token_ids) < block_size:
break
if req_need_extra_keys:
# MM and LoRA requests need extra keys for block-hash computation.
req_extra_keys, curr_mm_idx = generate_block_hash_extra_keys(
request, start, end, curr_mm_idx)
block_hash = hash_block_tokens(hash_function, parent_block_hash_value,
block_token_ids, req_extra_keys)
ret.append(block_hash)
parent_block_hash_value = block_hash.hash_value
return ret
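# Illustrative sketch of the hash chain: each block hash folds in the previous
# block's hash value, so identical token blocks at different prefix positions
# hash differently. The builtin hash() is used here only for brevity.
def _hash_chain_example() -> None:
    h0 = hash_block_tokens(hash, None, (1, 2, 3, 4))
    h1 = hash_block_tokens(hash, h0.hash_value, (1, 2, 3, 4))
    assert h0.token_ids == h1.token_ids == (1, 2, 3, 4)
    # Same tokens, different prefix -> different hash values (collisions aside).
    print(h0.hash_value != h1.hash_value)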
def max_memory_usage_bytes(vllm_config: VllmConfig,
kv_cache_specs: Iterable[KVCacheSpec]) -> int:
"""
Get the maximum memory usage in bytes for the given KV cache specs.
"""
return sum(
spec.max_memory_usage_bytes(vllm_config) for spec in kv_cache_specs)
def estimate_max_model_len(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> int:
"""
Estimates the maximum model length that can fit in the available memory
using binary search.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The estimated maximum model length that can fit in the available memory.
"""
# Define a function to check if a given model length fits in memory
def fits_in_memory(model_len: int) -> bool:
# Modify the max_model_len for this calculation
vllm_config.model_config.max_model_len = model_len
# Calculate memory needed for the given model length
memory_needed = max_memory_usage_bytes(vllm_config,
kv_cache_spec.values())
return memory_needed <= available_memory
# Binary search for the maximum model length
current_max = vllm_config.model_config.max_model_len
left, right = 1, current_max
# If even the smallest model length doesn't fit, return 0
if not fits_in_memory(left):
return 0
# Binary search for the maximum model length that fits
result = 1
while left <= right:
mid = (left + right) // 2
if fits_in_memory(mid):
result = mid
left = mid + 1
else:
right = mid - 1
return result
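# E.g. (hypothetical): if max_model_len is 32768 but only lengths up to 9000
# fit in the available memory, the binary search over [1, 32768] converges to
# 9000 in about log2(32768) = 15 iterations.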
def check_enough_kv_cache_memory(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int):
"""
Checks whether `available_memory` is enough for the KV cache to hold at
least one request with the model's max_model_len.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Raises:
ValueError: If there is not enough memory available for the KV cache.
"""
if available_memory <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_model_len = vllm_config.model_config.max_model_len
needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values())
if needed_memory > available_memory:
# Estimate the maximum model length that can fit in the available memory
estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
available_memory)
estimated_msg = ""
if estimated_max_len > 0:
estimated_msg = (
"Based on the available memory, "
f"the estimated maximum model length is {estimated_max_len}.")
raise ValueError(
f"To serve at least one request with the models's max seq len "
f"({max_model_len}), ({needed_memory/GiB_bytes:.2f} GiB KV "
f"cache is needed, which is larger than the available KV cache "
f"memory ({available_memory/GiB_bytes:.2f} GiB). "
f"{estimated_msg} "
f"Try increasing `gpu_memory_utilization` or decreasing "
f"`max_model_len` when initializing the engine.")
def create_kv_cache_group_specs(
kv_cache_spec: dict[str, KVCacheSpec],
grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
"""
Create a KVCacheGroupSpec object for each kv cache group.
The layers in the same group should share the same
KVCacheSpec.
Args:
kv_cache_spec:
A mapping from each layer name to its corresponding KVCacheSpec.
grouped_layer_names:
A list of kv cache groups, where each element is a list of layer
names that belong to the same group and should share the same
KVCacheSpec.
Returns:
A list of KVCacheGroupSpec objects, one for each group.
"""
kv_cache_groups = []
for layer_names_one_group in grouped_layer_names:
layer_specs = [
kv_cache_spec[layer_name] for layer_name in layer_names_one_group
]
merged_layer_spec = layer_specs[0].merge(layer_specs)
kv_cache_groups.append(
KVCacheGroupSpec(layer_names_one_group, merged_layer_spec))
return kv_cache_groups
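# E.g. (hypothetical layer names): grouped_layer_names =
# [["layer.0", "layer.2"], ["layer.1"]] produces two KVCacheGroupSpec objects,
# each carrying the merged spec of its member layers.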
def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
"""
Whether all layers in the given KVCacheSpec have the same type of KV cache.
Args:
kv_cache_spec: The kv cache spec of each attention layer in the model
Returns:
True if all layers have the same type, False otherwise.
"""
layer_keys = set(layer.type_id for layer in kv_cache_spec.values())
return len(layer_keys) == 1
def get_max_concurrency_for_kv_cache_config(
vllm_config: VllmConfig, kv_cache_config: KVCacheConfig) -> float:
"""
Get the maximum concurrency for the given KV cache configuration.
"""
num_layer_per_group = max(
len(group.layer_names) for group in kv_cache_config.kv_cache_groups)
max_memory_usage_per_request = num_layer_per_group * max_memory_usage_bytes(
vllm_config,
(group.kv_cache_spec for group in kv_cache_config.kv_cache_groups))
memory_per_block = kv_cache_config.kv_cache_groups[
0].kv_cache_spec.page_size_bytes * num_layer_per_group
num_block_per_request = cdiv(max_memory_usage_per_request,
memory_per_block)
max_concurrency = kv_cache_config.num_blocks / num_block_per_request
return max_concurrency
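# E.g. (hypothetical, single full-attention group): with max_model_len=8192 and
# block_size=16, one max-length request needs 512 blocks, so num_blocks=65536
# gives a maximum concurrency of 65536 / 512 = 128x.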
def get_num_blocks(vllm_config: VllmConfig, num_layers: int,
available_memory: int, page_size: int) -> int:
"""
Get the number of kv cache blocks.
Args:
vllm_config: The global VllmConfig
num_layers: The number of layers
available_memory: Memory available for KV cache in bytes.
page_size: The page size of the KV cache.
"""
num_blocks = int(available_memory // page_size // num_layers)
num_blocks = max(num_blocks, 0)
if vllm_config.cache_config.num_gpu_blocks_override is not None:
num_gpu_blocks_override = \
vllm_config.cache_config.num_gpu_blocks_override
logger.info(
"Overriding num_gpu_blocks=%d with "
"num_gpu_blocks_override=%d", num_blocks, num_gpu_blocks_override)
return num_blocks
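# E.g. (hypothetical): available_memory=8 GiB, page_size=2 MiB and
# num_layers=32 give 8 GiB // 2 MiB // 32 = 128 blocks, unless
# num_gpu_blocks_override replaces this value.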
def get_uniform_page_size(kv_cache_spec: dict[str, KVCacheSpec]) -> int:
"""
Get the page size of the KV cache.
"""
page_sizes = set(layer.page_size_bytes for layer in kv_cache_spec.values())
assert len(page_sizes) == 1
return page_sizes.pop()
def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model with one type of KV cache,
dividing the available memory equally among all layers.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfig
"""
page_size = get_uniform_page_size(kv_cache_spec)
num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec),
available_memory, page_size)
per_layer_size = page_size * num_blocks
# All layers have the same KV cache spec, so we create one kv cache group
# for all layers.
grouped_layer_names = [list(kv_cache_spec.keys())]
# Each layer uses a separate Tensor to store its KV cache.
kv_cache_tensors = [
KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
for layer_name in kv_cache_spec
]
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=kv_cache_tensors,
kv_cache_groups=create_kv_cache_group_specs(kv_cache_spec,
grouped_layer_names),
)
num_tokens = num_blocks * vllm_config.cache_config.block_size
num_tokens_str = f"{num_tokens:,}"
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
max_concurrency = get_max_concurrency_for_kv_cache_config(
vllm_config, kv_cache_config)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
max_model_len_str, max_concurrency)
return kv_cache_config
def is_kv_cache_page_size_uniform(
kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
"""
Whether all layers in the given KVCacheSpec have the same page size.
Args:
kv_cache_spec: The KVCacheSpec of each attention layer in the model
Returns:
True if all layers have the same page size, False otherwise.
"""
page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()}
return len(page_sizes) == 1
def _get_kv_cache_config_uniform_page_size(
vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int) -> KVCacheConfig:
"""
Generates the KV cache configuration for hybrid models with multiple
attention types but still with a uniform page size (physical memory per
block per layer) for all layers.
Detailed explanation about kv cache management of hybrid models:
The layers in the models are repeated with some patterns, e.g., a model
with 10 full attention layers and 20 sliding window attention layers can be
regarded as repeating the pattern (1 * full, 2 * sw) 10 times.
The KVCacheManager allocates different block tables for each of the 3 layers
in the pattern, and repeats each of them 10 times to generate the
block_table for the 30 layers in the model.
Therefore, we can group the layers in the model into 3 kv_cache_groups, each
of which contains 10 layers in the model.
The KVCacheManager allocates the block_table for each group based on its
kv_cache spec, and the model runner applies the block table to each layer
in the group.
For example:
1. A model only uses full attention. The pattern is
(num_hidden_layers * full), so there is only one group and the block table
is shared by all layers. It is already handled by
`_get_kv_cache_config_uniform_type`.
2. A model with 10 full attention layers and 20 sliding window
attention layers. There are 3 layers in the pattern (1 * full, 2 * sw), so
there are 3 kv_cache_groups, each of which represents 10 layers.
To simplify the implementation, we make the following assumptions:
1. Physical memory per block: Must be the same across all KV cache groups.
Breaking this assumption is non-trivial due to memory fragmentation concerns
when allocating blocks of different sizes.
2. Tokens per block (block_size): Currently, we directly use
`CacheConfig.block_size` for all layers. It can be extended to vary by KV
cache group, but within each KV cache group, all layers must share the same
block size.
3. Physical memory per token per layer: This property is decided by model
config. Currently we only support models that have the same physical memory
per token per layer for all layers. Can be relaxed with a simple extension,
but still need to keep physical memory per block the same for all groups.
4. Number of layers per group: Currently assumed the same for all groups.
Can be relaxed with a simple extension, but still need to keep physical
memory per block the same for all groups.
5. Attention type within groups: All layers in a group must share the same
attention type. One exception is that, when
`--disable-hybrid-kv-cache-manager` is true, the single group for full
attention layers may also include attention layers using sliding window or
LLaMA 4 local attention. See `unify_hybrid_kv_cache_specs` for more details.
6. Support for multiple attention types: The design for most components is
general to an arbitrary number of attention types. But
`find_longest_cache_hit` only supports either a single attention type, or
full attention combined with exactly one other attention type. The general
implementation of this function is feasible but we don't know how to
implement it cleanly yet.
As we assume tokens per block, physical memory per token per layer, and
number of layers per group are the same now, we can ensure that physical
memory per block is the same for all groups.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The KVCacheSpec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfig
"""
# Group all layers by type_id.
# E.g., 2 full attention layers and 3 sliding window attention layers,
# -> (full.0, full.1), (sw.0, sw.1, sw.2).
same_type_layers: dict[str, list[str]] = defaultdict(list)
for layer_name, layer_spec in kv_cache_spec.items():
same_type_layers[layer_spec.type_id].append(layer_name)
# Split each group into smaller groups, to make the number of layers in each
# group identical. Add padding to the last group of each type if necessary.
# E.g., (full.0, full.1), (sw.0, sw.1, sw.2)
# split to 3 groups with 2 layers each:
# (full.0, full.1), (sw.0, sw.1), (sw.2, padding).
# FIXME(Chen): At the moment of writing this code (2025-06-02), all
# open-source hybrid models follow an n:1 pattern between different attention
# types (e.g., Gemma3 5:1 between sw and full, LLaMA4 3:1 between local and
# full), so we can use the "1" in the n:1 pattern as the group size, which
# is the minimum number of layers among all attention types. Need a better
# strategy if we want to support more complex patterns (e.g., 20 full + 30
# sw, where the group size should be 10).
group_size = min([len(layers) for layers in same_type_layers.values()])
grouped_layers = []
for layers in same_type_layers.values():
num_padding_layers = group_size - len(layers) % group_size
if num_padding_layers != group_size:
logger.warning(
"Add %d padding layers, may waste at most %.2f%% KV cache memory", # noqa
num_padding_layers,
num_padding_layers / len(layers) * 100,
)
for i in range(0, len(layers), group_size):
grouped_layers.append(layers[i:i + group_size])
kv_cache_groups = create_kv_cache_group_specs(kv_cache_spec,
grouped_layers)
# Determine how model runners should initialize the KV cache tensors.
# We will have group_size memory pools, each is shared by one layer from
# each group. As layers of different groups have different block tables,
# they will use different parts of the shared Tensor.
# The memory layout in the example will be:
# full.0, sw.0, sw.2: share a Tensor with size=available_memory//2
# full.1, sw.1: share another Tensor with size=available_memory//2
page_size = get_uniform_page_size(kv_cache_spec)
num_blocks = get_num_blocks(vllm_config, group_size, available_memory,
page_size)
per_memory_pool_size = page_size * num_blocks
kv_cache_tensors = []
for i in range(group_size):
shared_by = []
for j in range(len(kv_cache_groups)):
if i < len(grouped_layers[j]):
shared_by.append(grouped_layers[j][i])
kv_cache_tensors.append(
KVCacheTensor(size=per_memory_pool_size, shared_by=shared_by))
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=kv_cache_tensors,
kv_cache_groups=kv_cache_groups,
)
# Print the KV cache size and maximum concurrency.
num_tokens = num_blocks // len(
grouped_layers) * vllm_config.cache_config.block_size
num_tokens_str = f"{num_tokens:,}"
logger.info("GPU KV cache size: %s tokens", num_tokens_str)
max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
max_concurrency = get_max_concurrency_for_kv_cache_config(
vllm_config, kv_cache_config)
logger.info("Maximum concurrency for %s tokens per request: %.2fx",
max_model_len_str, max_concurrency)
return kv_cache_config
def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
"""
This function tries to convert the KV cache specs to one type if the model
is a hybrid model with multiple types of KV cache. It will convert all
SlidingWindowSpec to FullAttentionSpec if both types are present.
Args:
kv_cache_spec: The kv cache spec of each attention layer in the model
"""
def is_hybrid(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
type_ids = set(layer_spec.type_id
for layer_spec in kv_cache_spec.values())
return len(type_ids) > 1
if not is_hybrid(kv_cache_spec):
return
logger.warning(
"Hybrid KV cache manager is disabled for this hybrid model. "
"This means we do not enable any optimizations for saving KV cache "
"memory (e.g., dropping the KV cache outside the sliding window). "
"The compute savings of layers like sliding window are still preserved.")
has_full_attention = any(
isinstance(spec, FullAttentionSpec) for spec in kv_cache_spec.values())
has_sliding_window = any(
isinstance(spec, SlidingWindowSpec) for spec in kv_cache_spec.values())
if has_full_attention and has_sliding_window:
for layer_name, spec in kv_cache_spec.items():
if isinstance(spec, SlidingWindowSpec):
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=spec.block_size,
num_kv_heads=spec.num_kv_heads,
head_size=spec.head_size,
dtype=spec.dtype,
use_mla=spec.use_mla,
sliding_window=spec.sliding_window,
)
if is_hybrid(kv_cache_spec):
raise ValueError("Hybrid KV cache manager is disabled but failed to "
"convert the KV cache specs to one unified type.")
def get_kv_cache_config(
vllm_config: VllmConfig,
kv_cache_spec: dict[str, KVCacheSpec],
available_memory: int,
) -> KVCacheConfig:
"""
Generates the KV cache configuration for a model.
Args:
vllm_config: The global VllmConfig
kv_cache_spec: The kv cache spec of each attention layer in the model
available_memory: Memory available for KV cache in bytes.
Returns:
The generated KVCacheConfig
"""
check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
if vllm_config.scheduler_config.disable_hybrid_kv_cache_manager:
unify_hybrid_kv_cache_specs(kv_cache_spec)
if is_kv_cache_type_uniform(kv_cache_spec):
# KV cache of all layers are the same, which is true for
# most models. Allocate the same amount of memory for
# each layer.
return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
available_memory)
elif is_kv_cache_page_size_uniform(kv_cache_spec):
# Model contains multiple attention types, but KV cache of all layers
# have the same physical memory per block per layer. Split the layers
# into groups with the same number of layers, and thus same total page
# size.
return _get_kv_cache_config_uniform_page_size(vllm_config,
kv_cache_spec,
available_memory)
raise NotImplementedError
def unify_kv_cache_configs(kv_cache_configs: list[KVCacheConfig]):
"""
Make the KV cache configurations for each worker consistent, so that all
workers can be controlled by the same KVCacheManager.
This function verifies that the layer groups of each worker are the same,
and changes the num_blocks of each worker to the smallest among all workers.
Args:
kv_cache_configs: The KV cache configurations for each worker. Will be
in-place modified to make them consistent.
"""
# Sort the kv cache groups by the type_id of their KV cache spec.
# This can avoid the inconsistency caused by the order of groups.
for kv_cache_config in kv_cache_configs:
kv_cache_config.kv_cache_groups.sort(
key=lambda x: x.kv_cache_spec.type_id)
# Verify that the groups of each rank are the same.
for kv_cache_config in kv_cache_configs[1:]:
for group_rank_0, group_rank_i in zip(
kv_cache_configs[0].kv_cache_groups,
kv_cache_config.kv_cache_groups):
assert group_rank_0.kv_cache_spec == group_rank_i.kv_cache_spec
# Change the num_blocks of each rank to the smallest among all ranks. We
# do not need to shrink the tensor size because it is valid to only use the
# first `num_blocks` blocks of the tensor.
min_num_blocks = min(kv_cache_config.num_blocks
for kv_cache_config in kv_cache_configs)
for kv_cache_config in kv_cache_configs:
kv_cache_config.num_blocks = min_num_blocks
return kv_cache_configs
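# E.g. (hypothetical): if two workers report 11000 and 10500 blocks
# respectively, both configs are updated to num_blocks=10500.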


@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
class SchedulerInterface(ABC):
@abstractmethod
def schedule(self) -> "SchedulerOutput":
"""Schedule the requests to process in this scheduling step.
The scheduling decision is made at the iteration level. Each scheduling
step corresponds to a single forward pass of the model. Therefore, this
method is called repeatedly by a busy loop in the engine.
Essentially, the scheduler produces a dictionary of {req_id: num_tokens}
that specifies how many tokens to process for each request in this
scheduling step. For example, num_tokens can be as large as the number
of prompt tokens for new requests, or it can be 1 for the requests that
are auto-regressively generating new tokens one by one. Otherwise, it
can be somewhere in between in case of chunked prefills, prefix caching,
speculative decoding, etc.
Additionally, the scheduler also returns useful data about each request
or the batch as a whole. The model runner will use this information in
preparing inputs to the model.
Returns:
A SchedulerOutput object containing information about the scheduled
requests.
"""
raise NotImplementedError
@abstractmethod
def update_from_output(
self,
scheduler_output: "SchedulerOutput",
model_runner_output: "ModelRunnerOutput",
) -> dict[int, "EngineCoreOutputs"]:
"""Update the scheduler state based on the model runner output.
This method is called after the model runner has processed the scheduled
requests. The model runner output includes generated token ids, draft
token ids for next step, etc. The scheduler uses this information to
update its states, checks the finished requests, and returns the output
for each request.
Returns:
A dict of client index to EngineCoreOutputs object containing the
outputs for each request originating from that client.
"""
raise NotImplementedError
@abstractmethod
def add_request(self, request: "Request") -> None:
"""Add a new request to the scheduler's internal queue.
Args:
request: The new request being added.
"""
raise NotImplementedError
@abstractmethod
def finish_requests(
self,
request_ids: Union[str, Iterable[str]],
finished_status: "RequestStatus",
) -> None:
"""Finish the requests in the scheduler's internal queue. If the request
is not in the queue, this method will do nothing.
This method is called in two cases:
1. When the request is aborted by the client.
2. When the frontend process detects a stop string of the request after
de-tokenizing its generated tokens.
Args:
request_ids: A single or a list of request IDs.
finished_status: The finished status of the given requests.
"""
raise NotImplementedError
@abstractmethod
def get_num_unfinished_requests(self) -> int:
"""Number of unfinished requests in the scheduler's internal queue."""
raise NotImplementedError
def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests in the scheduler's
internal queue."""
return self.get_num_unfinished_requests() > 0
@abstractmethod
def has_finished_requests(self) -> bool:
"""Returns True if there are finished requests that need to be cleared.
NOTE: This is different from `not self.has_unfinished_requests()`.
The scheduler maintains an internal list of the requests finished in the
previous step. This list is returned from the next call to schedule(),
to be sent to the model runner in the next step to clear cached states
for these finished requests.
This method checks if this internal list of finished requests is
non-empty. This information is useful for DP attention.
"""
raise NotImplementedError
def has_requests(self) -> bool:
"""Returns True if there are unfinished requests, or finished requests
not yet returned in SchedulerOutputs."""
return self.has_unfinished_requests() or self.has_finished_requests()
@abstractmethod
def reset_prefix_cache(self) -> bool:
"""Reset the prefix cache for KV cache.
This is particularly required when the model weights are live-updated.
"""
raise NotImplementedError
@abstractmethod
def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
raise NotImplementedError
@abstractmethod
def make_stats(self) -> Optional["SchedulerStats"]:
"""Make a SchedulerStats object for logging.
The SchedulerStats object is created for every scheduling step.
"""
raise NotImplementedError
@abstractmethod
def shutdown(self) -> None:
"""Shutdown the scheduler."""
raise NotImplementedError
def get_kv_connector(self) -> Optional["KVConnectorBase_V1"]:
return None


@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata)
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.request import Request
@dataclass
class NewRequestData:
req_id: str
prompt_token_ids: list[int]
mm_inputs: list[MultiModalKwargs]
mm_hashes: list[str]
mm_positions: list[PlaceholderRange]
sampling_params: SamplingParams
block_ids: tuple[list[int], ...]
num_computed_tokens: int
lora_request: Optional[LoRARequest]
@classmethod
def from_request(
cls,
request: Request,
block_ids: tuple[list[int], ...],
) -> NewRequestData:
return cls(
req_id=request.request_id,
prompt_token_ids=request.prompt_token_ids,
mm_inputs=request.mm_inputs,
mm_hashes=request.mm_hashes,
mm_positions=request.mm_positions,
sampling_params=request.sampling_params,
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
lora_request=request.lora_request,
)
def __repr__(self):
return (f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids={self.prompt_token_ids},"
f"mm_inputs={self.mm_inputs},"
f"mm_hashes={self.mm_hashes},"
f"mm_positions={self.mm_positions},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request}"
")")
# Version of __repr__ with the prompt data obfuscated
def anon_repr(self):
return (f"NewRequestData("
f"req_id={self.req_id},"
f"prompt_token_ids_len={len(self.prompt_token_ids)},"
f"mm_inputs={self.mm_inputs},"
f"mm_hashes={self.mm_hashes},"
f"mm_positions={self.mm_positions},"
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"lora_request={self.lora_request}"
")")
@dataclass
class CachedRequestData:
req_id: str
# If resumed_from_preemption is False, new_block_ids will be appended to
# the request's block IDs. If True, new_block_ids will be used as the
# request's block IDs instead of appending to the existing block IDs.
resumed_from_preemption: bool
new_token_ids: list[int]
new_block_ids: tuple[list[int], ...]
num_computed_tokens: int
@classmethod
def from_request(
cls,
request: Request,
resumed_from_preemption: bool,
new_token_ids: list[int],
new_block_ids: tuple[list[int], ...],
) -> CachedRequestData:
return cls(
req_id=request.request_id,
resumed_from_preemption=resumed_from_preemption,
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=request.num_computed_tokens,
)
@dataclass
class SchedulerOutput:
# list of the requests that are scheduled for the first time.
# We cache the request's data in each worker process, so that we don't
# need to re-send it every scheduling step.
scheduled_new_reqs: list[NewRequestData]
# list of the requests that have been scheduled before.
# Since the request's data is already cached in the worker processes,
# we only send the diff to minimize the communication cost.
scheduled_cached_reqs: list[CachedRequestData]
# req_id -> num_scheduled_tokens
# Number of tokens scheduled for each request.
num_scheduled_tokens: dict[str, int]
# Total number of tokens scheduled for all requests.
# Equal to sum(num_scheduled_tokens.values())
total_num_scheduled_tokens: int
# req_id -> spec_token_ids
# If a request does not have any spec decode tokens, it will not be
# included in the dictionary.
scheduled_spec_decode_tokens: dict[str, list[int]]
# req_id -> encoder input indices that need processing.
# E.g., if a request has [0, 1], it could mean the vision encoder needs
# to process that request's 0-th and 1-th images in the current step.
scheduled_encoder_inputs: dict[str, list[int]]
# Number of common prefix blocks for all requests in each KV cache group.
# This can be used for cascade attention.
num_common_prefix_blocks: list[int]
# Request IDs that are finished in between the previous and the current
# steps. This is used to notify the workers about the finished requests
# so that they can free the cached states for those requests.
finished_req_ids: set[str]
# list of (req_id, encoder_input_index) tuples.
# Used to free the encoder cache.
free_encoder_input_ids: list[tuple[str, int]]
# Dict of request ids to their index within the batch
# for filling the next token bitmask
structured_output_request_ids: dict[str, int]
# the bitmask for the whole batch
grammar_bitmask: Optional[npt.NDArray[np.int32]]
# KV Cache Connector metadata.
kv_connector_metadata: Optional[KVConnectorMetadata] = None

File diff suppressed because it is too large

@@ -0,0 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.v1.request import Request, RequestStatus
def check_stop(request: Request, max_model_len: int) -> bool:
if (request.num_tokens >= max_model_len
or request.num_output_tokens >= request.max_tokens):
request.status = RequestStatus.FINISHED_LENGTH_CAPPED
return True
sampling_params = request.sampling_params
last_token_id = request.output_token_ids[-1]
if (not sampling_params.ignore_eos
and last_token_id == request.eos_token_id):
request.status = RequestStatus.FINISHED_STOPPED
return True
if last_token_id in (sampling_params.stop_token_ids or ()):
request.status = RequestStatus.FINISHED_STOPPED
request.stop_reason = last_token_id
return True
return False


@@ -0,0 +1,403 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Callable
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
SlidingWindowSpec)
from vllm.v1.request import Request
class SingleTypeKVCacheManager(ABC):
"""
An abstract base class for a manager that handles the kv cache management
logic of one specific type of attention layer.
"""
def __init__(
self,
kv_cache_spec: KVCacheSpec,
block_pool: BlockPool,
kv_cache_group_id: int,
caching_hash_fn: Callable,
) -> None:
"""
Initializes the SingleTypeKVCacheManager.
Args:
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
kv_cache_group_id: The id of the kv cache group of this manager.
caching_hash_fn: The caching hash function.
"""
self.block_size = kv_cache_spec.block_size
self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool
# Mapping from request ID to blocks to track the blocks allocated
# for each request, so that we can free the blocks when the request
# is finished.
self.req_to_blocks: defaultdict[str,
list[KVCacheBlock]] = defaultdict(list)
# {req_id: The number of cached blocks for this given request}
# This is used to track the number of cached blocks for each request.
# This is only used to track the RUNNING requests; we do not track the
# data for preempted ones.
self.num_cached_block: dict[str, int] = {}
self.caching_hash_fn = caching_hash_fn
self.kv_cache_group_id = kv_cache_group_id
def get_num_blocks_to_allocate(
self, request_id: str, num_tokens: int,
new_computed_blocks: list[KVCacheBlock]) -> int:
"""
Get the number of blocks needed to be allocated for the request.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
new_computed_blocks: The new computed blocks just hitting the
prefix caching.
Returns:
The number of blocks.
"""
num_required_blocks = cdiv(num_tokens, self.block_size)
num_new_blocks = (num_required_blocks - len(new_computed_blocks) -
len(self.req_to_blocks[request_id]))
# If a computed block of a request is an eviction candidate (in the
# free queue and ref_cnt == 0), it will be changed from a free block
# to a computed block when the request is allocated, so we also count
# it as needed to be allocated.
num_evictable_computed_blocks = sum(
blk.ref_cnt == 0 and not blk.is_null
for blk in new_computed_blocks)
return num_new_blocks + num_evictable_computed_blocks
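    # E.g. (hypothetical): block_size=16 and num_tokens=100 require
    # cdiv(100, 16) = 7 blocks. With 3 new computed blocks (1 of them
    # evictable, ref_cnt == 0) and 2 blocks already allocated:
    # 7 - 3 - 2 + 1 = 3 blocks need to be allocated.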
def save_new_computed_blocks(
self, request_id: str,
new_computed_blocks: list[KVCacheBlock]) -> None:
"""
Add the new computed blocks to the request.
Args:
request_id: The request ID.
new_computed_blocks: The new computed blocks just hitting the
prefix cache.
"""
if request_id not in self.num_cached_block:
# A new request.
req_blocks = self.req_to_blocks[request_id]
assert len(req_blocks) == 0
req_blocks.extend(new_computed_blocks)
self.num_cached_block[request_id] = len(new_computed_blocks)
else:
# A running request. Should not have new computed blocks.
assert len(new_computed_blocks) == 0
def allocate_new_blocks(self, request_id: str,
num_tokens: int) -> list[KVCacheBlock]:
"""
Allocate new blocks for the request to give it at least `num_tokens`
token slots.
Args:
request_id: The request ID.
num_tokens: The total number of tokens that need a slot (including
tokens that are already allocated).
Returns:
The new allocated blocks.
"""
req_blocks = self.req_to_blocks[request_id]
num_required_blocks = cdiv(num_tokens, self.block_size)
num_new_blocks = num_required_blocks - len(req_blocks)
if num_new_blocks <= 0:
return []
else:
new_blocks = self.block_pool.get_new_blocks(num_new_blocks)
req_blocks.extend(new_blocks)
return new_blocks
def cache_blocks(self, request: Request, block_hashes: list[BlockHash],
num_tokens: int) -> None:
"""
Cache the blocks for the request.
Args:
request: The request.
block_hashes: The block hashes of the request.
num_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
"""
num_cached_blocks = self.num_cached_block[request.request_id]
num_full_blocks = num_tokens // self.block_size
self.block_pool.cache_full_blocks(
request=request,
blocks=self.req_to_blocks[request.request_id],
block_hashes=block_hashes,
num_cached_blocks=num_cached_blocks,
num_full_blocks=num_full_blocks,
block_size=self.block_size,
kv_cache_group_id=self.kv_cache_group_id,
hash_fn=self.caching_hash_fn,
)
self.num_cached_block[request.request_id] = num_full_blocks
def free(self, request_id: str) -> None:
"""
Free the blocks for the request.
Args:
request_id: The request ID.
"""
# Default to [] in case a request is freed (aborted) before alloc.
req_blocks = self.req_to_blocks.pop(request_id, [])
# Free blocks in reverse order so that the tail blocks are
# freed first.
ordered_blocks = reversed(req_blocks)
self.block_pool.free_blocks(ordered_blocks)
self.num_cached_block.pop(request_id, None)
@abstractmethod
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
"""
Get the number of common prefix blocks for a request.
Args:
request_id: The request ID.
num_running_requests: The total number of requests in the RUNNING state.
Returns:
The number of common prefix blocks.
"""
raise NotImplementedError
@classmethod
@abstractmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> tuple[list[KVCacheBlock], ...]:
"""
Get the longest cache hit prefix of the blocks that is not longer than
`max_length`. The prefix should be a common prefix hit for all the
kv cache groups in `kv_cache_group_ids`. If no cache hit is found,
return an empty list.
If eagle is enabled, drop the last matched block to force recomputation of
the last block, so that the hidden states required by the eagle drafting
head are available.
Need to be customized for each attention type.
Args:
block_hashes: The block hashes of the request.
max_length: The maximum length of the cache hit prefix.
kv_cache_group_ids: The ids of the kv cache groups.
block_pool: The block pool.
kv_cache_spec: The kv cache spec.
use_eagle: Whether to use eagle.
Returns:
A list of cached blocks with skipped blocks replaced by null block
for each kv cache group in `kv_cache_group_ids`.
Return a list of length `len(kv_cache_group_ids)`, where the i-th
element is a list of cached blocks for the i-th kv cache group
in `kv_cache_group_ids`.
For example, sliding window manager should return a list like
([NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)]) for block size 4
and sliding window 8 and len(kv_cache_group_ids) = 1.
"""
raise NotImplementedError
@abstractmethod
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
"""
Remove the blocks that are no longer needed from `blocks` and free the
blocks. The removed blocks should be replaced by null_block.
Need to be customized for each attention type.
Args:
request_id: The request ID.
num_computed_tokens: The number of tokens that have been computed.
"""
raise NotImplementedError
class FullAttentionManager(SingleTypeKVCacheManager):
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, FullAttentionSpec), (
"FullAttentionManager can only be used for full attention groups")
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
[] for _ in range(len(kv_cache_group_ids)))
max_num_blocks = max_length // kv_cache_spec.block_size
for i, block_hash in zip(range(max_num_blocks), block_hashes):
# block_hashes is a chain of block hashes. If a block hash is not
# in cached_block_hash_to_block, the following block hashes are
# not computed yet for sure.
if cached_block := block_pool.get_cached_block(
block_hash, kv_cache_group_ids):
for computed, cached in zip(computed_blocks, cached_block):
computed.append(cached)
else:
break
if use_eagle and computed_blocks[0]:
for computed in computed_blocks:
computed.pop()
return computed_blocks
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
# No need to remove blocks for full attention.
pass
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
blocks = self.req_to_blocks[request_id]
num_common_blocks = 0
for block in blocks:
if block.ref_cnt == num_running_requests:
num_common_blocks += 1
else:
break
return num_common_blocks
class SlidingWindowManager(SingleTypeKVCacheManager):
def __init__(self, kv_cache_spec: SlidingWindowSpec, block_pool: BlockPool,
**kwargs) -> None:
super().__init__(kv_cache_spec, block_pool, **kwargs)
self.sliding_window = kv_cache_spec.sliding_window
self._null_block = block_pool.null_block
@classmethod
def find_longest_cache_hit(
cls,
block_hashes: list[BlockHash],
max_length: int,
kv_cache_group_ids: list[int],
block_pool: BlockPool,
kv_cache_spec: KVCacheSpec,
use_eagle: bool,
) -> tuple[list[KVCacheBlock], ...]:
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
"SlidingWindowManager can only be used for sliding window groups")
# The number of contiguous blocks needed for prefix cache hit.
# -1 since the input token itself is also included in the window
sliding_window_contiguous_blocks = cdiv(
kv_cache_spec.sliding_window - 1, kv_cache_spec.block_size)
if use_eagle:
# Need to drop the last matched block if eagle is enabled. For
# sliding window layer, we achieve this by increasing the number of
# contiguous blocks needed for prefix cache hit by one and dropping
# the last matched block.
sliding_window_contiguous_blocks += 1
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
# optimize the time complexity from O(max_num_blocks) to
# O(max_num_blocks / sliding_window_contiguous_blocks +
# sliding_window_contiguous_blocks),
# which is good for low cache hit rate scenarios.
max_num_blocks = max_length // kv_cache_spec.block_size
computed_blocks = tuple([block_pool.null_block] * max_num_blocks
for _ in range(len(kv_cache_group_ids)))
num_contiguous_blocks = 0
match_found = False
# Search from right to left and early stop when a match is found.
for i in range(max_num_blocks - 1, -1, -1):
if cached_block := block_pool.get_cached_block(
block_hashes[i], kv_cache_group_ids):
for computed, cached in zip(computed_blocks, cached_block):
computed[i] = cached
num_contiguous_blocks += 1
if num_contiguous_blocks >= sliding_window_contiguous_blocks:
# Trim the trailing blocks.
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
# when sliding_window_contiguous_blocks=2.
for computed in computed_blocks:
del computed[i + num_contiguous_blocks:]
match_found = True
break
else:
num_contiguous_blocks = 0
if not match_found:
# The first `num_contiguous_blocks` blocks are a cache hit even if
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
for computed in computed_blocks:
del computed[num_contiguous_blocks:]
if use_eagle and computed_blocks[0]:
for computed in computed_blocks:
computed.pop()
return computed_blocks
def remove_skipped_blocks(self, request_id: str,
num_computed_tokens: int) -> None:
# Remove the blocks that are no longer in the sliding window and are
# skipped during the attention computation.
last_useful_token = num_computed_tokens - self.sliding_window + 1
last_useful_block = last_useful_token // self.block_size
blocks = self.req_to_blocks[request_id]
removed_blocks: list[KVCacheBlock] = []
for i in range(last_useful_block - 1, -1, -1):
if blocks[i] == self._null_block:
# If the block is already a null block, the blocks before it
# should also have been set to null blocks by the previous calls
# to this function.
break
removed_blocks.append(blocks[i])
blocks[i] = self._null_block
self.block_pool.free_blocks(removed_blocks)
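    # E.g. (hypothetical): sliding_window=8, block_size=4 and
    # num_computed_tokens=20 give last_useful_token=13 and last_useful_block=3,
    # so blocks[0:3] are replaced by the null block and freed.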
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
"""
NOTE(Chen): The prefix blocks are null blocks for sliding window layers.
So it's not correct to count ref_cnt like FullAttentionManager. Return
0 here for correctness. Need to support cascade attention + sliding
window in the future.
"""
return 0
spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
FullAttentionSpec: FullAttentionManager,
SlidingWindowSpec: SlidingWindowManager,
}
def get_manager_for_kv_cache_spec(kv_cache_spec: KVCacheSpec,
**kwargs) -> SingleTypeKVCacheManager:
manager_class = spec_manager_map[type(kv_cache_spec)]
manager = manager_class(kv_cache_spec, **kwargs)
return manager
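# E.g., a FullAttentionSpec is dispatched to FullAttentionManager and a
# SlidingWindowSpec to SlidingWindowManager; a spec type that is not in
# spec_manager_map raises a KeyError.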