init src 0.9.2
This commit is contained in:
398
vllm/v1/core/kv_cache_manager.py
Normal file
398
vllm/v1/core/kv_cache_manager.py
Normal file
@@ -0,0 +1,398 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||
hash_request_tokens)
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class KVCacheBlocks:
|
||||
"""
|
||||
The allocation result of KVCacheManager, work as the interface between
|
||||
Scheduler and KVCacheManager, to hide KVCacheManager's internal data
|
||||
structure from the Scheduler.
|
||||
"""
|
||||
blocks: tuple[list[KVCacheBlock], ...]
|
||||
"""
|
||||
blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
|
||||
We don't use block of tokens as the outer dimension because it assumes all
|
||||
kv_cache_groups have the same number of blocks, which is true for now but
|
||||
will be broken if we want to give different block_size to different
|
||||
kv_cache_groups in the future.
|
||||
"""
|
||||
|
||||
def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
|
||||
"""Adds two KVCacheBlocks instances."""
|
||||
return KVCacheBlocks(
|
||||
tuple(blk1 + blk2
|
||||
for blk1, blk2 in zip(self.blocks, other.blocks)))
|
||||
|
||||
def get_block_ids(self) -> tuple[list[int], ...]:
|
||||
"""
|
||||
Converts the KVCacheBlocks instance to block_ids.
|
||||
|
||||
Returns:
|
||||
tuple[list[int], ...]: A tuple of lists where
|
||||
* the outer tuple corresponds to KV cache groups
|
||||
* each inner list contains the block_ids of the blocks in that group
|
||||
"""
|
||||
return tuple([blk.block_id for blk in group] for group in self.blocks)
|
||||
|
||||
def get_unhashed_block_ids(self) -> list[int]:
|
||||
"""Get block_ids of unhashed blocks from KVCacheBlocks instance."""
|
||||
assert len(self.blocks) == 1, "Only one group is supported"
|
||||
return [
|
||||
block.block_id for block in self.blocks[0]
|
||||
if block.block_hash is None
|
||||
]
|
||||
|
||||
def new_empty(self) -> "KVCacheBlocks":
|
||||
"""Creates a new KVCacheBlocks instance with no blocks."""
|
||||
return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))))
|
||||
|
||||
|
||||
class KVCacheManager:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
max_model_len: int,
|
||||
enable_caching: bool = True,
|
||||
caching_hash_algo: str = "builtin",
|
||||
use_eagle: bool = False,
|
||||
log_stats: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
|
||||
self.enable_caching = enable_caching
|
||||
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
|
||||
self.use_eagle = use_eagle
|
||||
self.log_stats = log_stats
|
||||
# FIXME: make prefix cache stats conditional on log_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
||||
|
||||
self.block_size: Optional[int] = None
|
||||
if self.enable_caching:
|
||||
assert len(
|
||||
set(g.kv_cache_spec.block_size
|
||||
for g in kv_cache_config.kv_cache_groups)
|
||||
) == 1, "Only one block size is supported for now"
|
||||
self.block_size = kv_cache_config.kv_cache_groups[
|
||||
0].kv_cache_spec.block_size
|
||||
|
||||
self.coordinator = get_kv_cache_coordinator(
|
||||
kv_cache_config=kv_cache_config,
|
||||
max_model_len=self.max_model_len,
|
||||
use_eagle=self.use_eagle,
|
||||
enable_caching=enable_caching,
|
||||
caching_hash_fn=self.caching_hash_fn,
|
||||
enable_kv_cache_events=enable_kv_cache_events,
|
||||
)
|
||||
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
||||
self.block_pool = self.coordinator.block_pool
|
||||
self.kv_cache_config = kv_cache_config
|
||||
|
||||
# Mapping from request ID to kv block hashes.
|
||||
# This is to avoid recomputing the block hashes for each call of
|
||||
# `get_computed_blocks` or `allocate_slots`.
|
||||
self.req_to_block_hashes: defaultdict[
|
||||
str, list[BlockHash]] = defaultdict(list)
|
||||
|
||||
@property
|
||||
def usage(self) -> float:
|
||||
"""Get the KV cache usage.
|
||||
|
||||
Returns:
|
||||
The KV cache usage (between 0.0 and 1.0).
|
||||
"""
|
||||
return self.block_pool.get_usage()
|
||||
|
||||
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
|
||||
"""Get (and reset) the prefix cache stats.
|
||||
|
||||
Returns:
|
||||
The current prefix caching stats, or None if logging is disabled.
|
||||
"""
|
||||
if not self.log_stats:
|
||||
return None
|
||||
stats = self.prefix_cache_stats
|
||||
self.prefix_cache_stats = PrefixCacheStats()
|
||||
return stats
|
||||
|
||||
def get_computed_blocks(self,
|
||||
request: Request) -> tuple[KVCacheBlocks, int]:
|
||||
"""Get the computed (cached) blocks for the request.
|
||||
Note that the computed blocks must be full.
|
||||
|
||||
Args:
|
||||
request: The request to get the computed blocks.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- A list of blocks that are computed for the request.
|
||||
- The number of computed tokens.
|
||||
"""
|
||||
# Prefix caching is disabled or
|
||||
# When the request requires prompt logprobs, we skip prefix caching.
|
||||
if (not self.enable_caching
|
||||
or (request.sampling_params is not None
|
||||
and request.sampling_params.prompt_logprobs is not None)):
|
||||
return self.create_empty_block_list(), 0
|
||||
|
||||
# The block hashes for the request may already be computed
|
||||
# if the scheduler has tried to schedule the request before.
|
||||
block_hashes = self.req_to_block_hashes[request.request_id]
|
||||
if not block_hashes:
|
||||
assert self.block_size is not None
|
||||
block_hashes = hash_request_tokens(self.caching_hash_fn,
|
||||
self.block_size, request)
|
||||
self.req_to_block_hashes[request.request_id] = block_hashes
|
||||
|
||||
if self.log_stats:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.requests += 1
|
||||
|
||||
# NOTE: When all tokens hit the cache, we must recompute the last token
|
||||
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
|
||||
# This can trigger recomputation of an entire block, rather than just
|
||||
# the single last token, because allocate_slots() requires
|
||||
# num_computed_tokens to be block-size aligned. Removing this limitation
|
||||
# could slightly improve performance in the future.
|
||||
max_cache_hit_length = request.num_tokens - 1
|
||||
computed_blocks, num_new_computed_tokens = (
|
||||
self.coordinator.find_longest_cache_hit(block_hashes,
|
||||
max_cache_hit_length))
|
||||
|
||||
if self.log_stats:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.queries += request.num_tokens
|
||||
self.prefix_cache_stats.hits += num_new_computed_tokens
|
||||
|
||||
return KVCacheBlocks(computed_blocks), num_new_computed_tokens
|
||||
|
||||
def allocate_slots(
|
||||
self,
|
||||
request: Request,
|
||||
num_new_tokens: int,
|
||||
num_new_computed_tokens: int = 0,
|
||||
new_computed_blocks: Optional[KVCacheBlocks] = None,
|
||||
num_draft_tokens: int = 0,
|
||||
num_lookahead_tokens: int = 0,
|
||||
delay_cache_blocks: bool = False,
|
||||
) -> Optional[KVCacheBlocks]:
|
||||
"""Add slots for a request with new tokens to append.
|
||||
|
||||
Args:
|
||||
request: The request to allocate slots.
|
||||
num_new_tokens: The number of tokens to allocate, including external
|
||||
tokens. Note that this does not include tokens that have
|
||||
already been computed locally (i.e. new_computed_blocks).
|
||||
num_new_computed_tokens: The number of new computed tokens just
|
||||
hitting the prefix caching, excluding external tokens.
|
||||
new_computed_blocks: The cached blocks for the above new computed
|
||||
tokens.
|
||||
num_lookahead_tokens: The number of speculative tokens to allocate.
|
||||
This is used by spec decode proposers with kv-cache such
|
||||
as eagle.
|
||||
delay_cache_blocks: Whether to skip caching the blocks. This is
|
||||
used by P/D when allocating blocks used in a KV transfer
|
||||
which will complete in a future step.
|
||||
|
||||
Blocks layout:
|
||||
```
|
||||
-----------------------------------------------------------------------
|
||||
| < computed > | < new computed > | < new > | < pre-allocated > |
|
||||
-----------------------------------------------------------------------
|
||||
| < required > |
|
||||
--------------------------------------------------
|
||||
| < full > |
|
||||
------------------------------------------------
|
||||
| <new full> |
|
||||
--------------
|
||||
```
|
||||
The following *_blocks are illustrated in this layout.
|
||||
|
||||
Returns:
|
||||
A list of new allocated blocks.
|
||||
"""
|
||||
if num_new_tokens == 0:
|
||||
raise ValueError("num_new_tokens must be greater than 0")
|
||||
|
||||
if new_computed_blocks is not None:
|
||||
new_computed_block_list = new_computed_blocks.blocks
|
||||
else:
|
||||
new_computed_block_list = tuple(
|
||||
[] for _ in range(len(self.kv_cache_config.kv_cache_groups)))
|
||||
|
||||
# Free the blocks that are skipped during the attention computation
|
||||
# (e.g., tokens outside the sliding window).
|
||||
# We can do this even if we cannot schedule this request due to
|
||||
# insufficient free blocks.
|
||||
# Should call this function before allocating new blocks to reduce
|
||||
# the number of evicted blocks.
|
||||
self.coordinator.remove_skipped_blocks(request.request_id,
|
||||
request.num_computed_tokens)
|
||||
|
||||
# The number of computed tokens is the number of computed tokens plus
|
||||
# the new prefix caching hits
|
||||
num_computed_tokens = (request.num_computed_tokens +
|
||||
num_new_computed_tokens)
|
||||
num_tokens_need_slot = min(
|
||||
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
|
||||
self.max_model_len)
|
||||
|
||||
num_blocks_to_allocate = self.coordinator.get_num_blocks_to_allocate(
|
||||
request_id=request.request_id,
|
||||
num_tokens=num_tokens_need_slot,
|
||||
new_computed_blocks=new_computed_block_list,
|
||||
)
|
||||
|
||||
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
||||
# Cannot allocate new blocks
|
||||
return None
|
||||
|
||||
# Touch the computed blocks to make sure they won't be evicted.
|
||||
if self.enable_caching:
|
||||
self.block_pool.touch(new_computed_block_list)
|
||||
else:
|
||||
assert not any(new_computed_block_list), (
|
||||
"Computed blocks should be empty when "
|
||||
"prefix caching is disabled")
|
||||
|
||||
# Append the new computed blocks to the request blocks until now to
|
||||
# avoid the case where the new blocks cannot be allocated.
|
||||
self.coordinator.save_new_computed_blocks(request.request_id,
|
||||
new_computed_block_list)
|
||||
|
||||
new_blocks = self.coordinator.allocate_new_blocks(
|
||||
request.request_id, num_tokens_need_slot)
|
||||
|
||||
# P/D: delay caching blocks if we have to recv from
|
||||
# remote. Update state for locally cached blocks.
|
||||
if not self.enable_caching or delay_cache_blocks:
|
||||
return KVCacheBlocks(new_blocks)
|
||||
|
||||
# Speculated tokens might be rejected in the future, so we does
|
||||
# not cache any speculated tokens. We only cache blocks with
|
||||
# generated (accepted) tokens.
|
||||
self.coordinator.cache_blocks(
|
||||
request, self.req_to_block_hashes[request.request_id],
|
||||
num_computed_tokens + num_new_tokens - num_draft_tokens)
|
||||
|
||||
return KVCacheBlocks(new_blocks)
|
||||
|
||||
def free(self, request: Request) -> None:
|
||||
"""Free the blocks allocated for the request.
|
||||
We free the blocks in reverse order so that he tail blocks are evicted
|
||||
first when caching is enabled.
|
||||
|
||||
Args:
|
||||
request: The request to free the blocks.
|
||||
"""
|
||||
self.coordinator.free(request.request_id)
|
||||
|
||||
def reset_prefix_cache(self) -> bool:
|
||||
"""Reset prefix cache. This function may be used in RLHF
|
||||
flows to invalidate prefix caching after the weights are updated,
|
||||
or used for resetting prefix caching status for benchmarking.
|
||||
|
||||
Returns:
|
||||
bool: True if the prefix cache is successfully reset,
|
||||
False otherwise.
|
||||
"""
|
||||
if not self.block_pool.reset_prefix_cache():
|
||||
return False
|
||||
if self.log_stats:
|
||||
assert self.prefix_cache_stats is not None
|
||||
self.prefix_cache_stats.reset = True
|
||||
return True
|
||||
|
||||
def get_num_common_prefix_blocks(
|
||||
self,
|
||||
request: Request,
|
||||
num_running_requests: int,
|
||||
) -> list[int]:
|
||||
"""Calculate the number of common prefix blocks shared by all requests
|
||||
in the RUNNING state for each kv cache group.
|
||||
|
||||
The function determines this by selecting any request and iterating
|
||||
through its blocks. A block is considered a common prefix block if its
|
||||
`ref_cnt` equals the total number of requests in the RUNNING state.
|
||||
|
||||
NOTE(woosuk): The number of requests in the RUNNING state is **greater
|
||||
than or equal to** the number of requests scheduled in the current step.
|
||||
This is because the RUNNING state only indicates that:
|
||||
1. The request has not yet finished, and
|
||||
2. The request holds its blocks unfreed.
|
||||
|
||||
While all scheduled requests must be in the RUNNING state, the inverse
|
||||
is not necessarily true. There may be RUNNING requests that are not
|
||||
scheduled in the current step.
|
||||
|
||||
This can result in an edge case where the number of common prefix blocks
|
||||
is 0, even though all scheduled requests share a common prefix. This
|
||||
occurs because there may be unscheduled RUNNING requests that do not
|
||||
share the common prefix. Currently, this case cannot be easily detected,
|
||||
so the function returns 0 in such cases.
|
||||
|
||||
Args:
|
||||
request: Any request in the RUNNING state, used to identify the
|
||||
common prefix blocks.
|
||||
num_running_requests: The total number of requests in the RUNNING
|
||||
state. This can be different from the number of scheduled
|
||||
requests in the current step.
|
||||
|
||||
Returns:
|
||||
list[int]: The number of common prefix blocks for each kv cache
|
||||
group.
|
||||
"""
|
||||
assert request.status == RequestStatus.RUNNING
|
||||
return self.coordinator.get_num_common_prefix_blocks(
|
||||
request.request_id, num_running_requests)
|
||||
|
||||
def free_block_hashes(self, request: Request) -> None:
|
||||
"""Discard the block hashes for the request.
|
||||
|
||||
NOTE: Unlike `free`, this method should be called only when the request
|
||||
is finished, not when it is preempted.
|
||||
"""
|
||||
self.req_to_block_hashes.pop(request.request_id, None)
|
||||
|
||||
def take_events(self) -> list[KVCacheEvent]:
|
||||
"""Take the KV cache events from the block pool.
|
||||
|
||||
Returns:
|
||||
A list of KV cache events.
|
||||
"""
|
||||
return self.block_pool.take_events()
|
||||
|
||||
def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
|
||||
"""Get the block ids of a request."""
|
||||
return KVCacheBlocks(
|
||||
self.coordinator.get_blocks(request_id)).get_block_ids()
|
||||
|
||||
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
|
||||
"""Cache the blocks for the request, if enabled."""
|
||||
if self.enable_caching:
|
||||
block_hashes = self.req_to_block_hashes[request.request_id]
|
||||
self.coordinator.cache_blocks(request, block_hashes,
|
||||
num_computed_tokens)
|
||||
|
||||
def create_empty_block_list(self) -> KVCacheBlocks:
|
||||
"""Creates a new KVCacheBlocks instance with no blocks."""
|
||||
return KVCacheBlocks(tuple([]
|
||||
for _ in range(self.num_kv_cache_groups)))
|
||||
Reference in New Issue
Block a user