# File: vllm/v1/core/specialized_manager.py
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from vllm.utils import cdiv
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
SlidingWindowSpec)
class SpecializedManager(ABC):
    """Abstract base for per-attention-type KV cache management.

    Concrete subclasses encapsulate the prefix-cache lookup and block
    eviction policies that are specific to one kind of attention layer
    (e.g. full attention vs. sliding-window attention).
    """

    def __init__(
        self,
        kv_cache_spec: KVCacheSpec,
        block_pool: BlockPool,
    ) -> None:
        """Initializes the SpecializedManager.

        Args:
            kv_cache_spec: The kv_cache_spec for this manager.
            block_pool: The block pool.
        """
        self.kv_cache_spec = kv_cache_spec
        self.block_pool = block_pool
        # Cached here since every manager needs it for token<->block math.
        self.block_size = kv_cache_spec.block_size

    @abstractmethod
    def find_longest_cache_hit(
            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
        """Get the longest cache hit prefix of the blocks.

        If no cache hit is found, return an empty list.

        Args:
            block_hashes: The block hashes of the request.

        Returns:
            A list of cached blocks with skipped blocks replaced by null
            block. For example, a sliding window manager should return a
            list like [NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for
            block size 4 and sliding window 8.
        """
        raise NotImplementedError

    @abstractmethod
    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
                              num_computed_tokens: int) -> list[KVCacheBlock]:
        """Remove the blocks that are no longer needed from `blocks`.

        The removed blocks should be replaced by null_block. Return the
        removed blocks in eviction order, where the first returned block
        should be evicted first. This function must NOT free the removed
        blocks itself.

        Args:
            blocks: The list of blocks to be updated.
            num_computed_tokens: The number of tokens that have been
                computed.

        Returns:
            The removed blocks in eviction order.
        """
        raise NotImplementedError
class FullAttentionManager(SpecializedManager):
    """KV cache manager for layers that use full (global) attention."""

    def find_longest_cache_hit(
            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
        """Collect the longest run of cached blocks from the front.

        `block_hashes` is a chain of block hashes, so the first hash
        missing from the cache implies every later block is uncomputed
        as well — the scan stops at the first miss.
        """
        hit_blocks: list[KVCacheBlock] = []
        for block_hash in block_hashes:
            cached = self.block_pool.get_cached_block(block_hash)
            if not cached:
                break
            hit_blocks.append(cached)
        return hit_blocks

    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
                              num_computed_tokens: int) -> list[KVCacheBlock]:
        """Full attention never skips past blocks; nothing to remove."""
        return []
class SlidingWindowManager(SpecializedManager):
    """KV cache manager for layers that use sliding-window attention.

    Blocks that have slid out of the attention window are replaced by
    the pool's shared null block so they can be reclaimed.
    """

    def __init__(self, kv_cache_spec: SlidingWindowSpec,
                 block_pool: BlockPool):
        super().__init__(kv_cache_spec, block_pool)
        # Window length in tokens.
        self.sliding_window = kv_cache_spec.sliding_window
        # The number of contiguous blocks needed for prefix cache hit.
        # -1 since the input token itself is also included in the window
        self.sliding_window_contiguous_blocks = cdiv(
            (kv_cache_spec.sliding_window - 1), self.block_size)
        # Shared placeholder block used for positions outside the window.
        self._null_block = block_pool.null_block

    def find_longest_cache_hit(
            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
        """Find the longest usable cache hit for sliding-window attention.

        Unlike full attention, a hit only needs the last
        `sliding_window_contiguous_blocks` blocks before the current
        position to be cached; everything earlier is replaced by the
        null block in the returned list.
        """
        # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
        # optimize the time complexity from O(len(block_hashes)) to
        # O(len(block_hashes) / sliding_window_contiguous_blocks +
        # sliding_window_contiguous_blocks),
        # which is good for low cache hit rate scenarios.
        computed_blocks = [self._null_block] * len(block_hashes)
        num_contiguous_blocks = 0
        # Search from right to left and early stop when a match is found.
        for i in range(len(block_hashes) - 1, -1, -1):
            if cached_block := self.block_pool.get_cached_block(
                    block_hashes[i]):
                computed_blocks[i] = cached_block
                num_contiguous_blocks += 1
                if (num_contiguous_blocks
                        >= self.sliding_window_contiguous_blocks):
                    # Trim the trailing blocks.
                    # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
                    # when sliding_window_contiguous_blocks=2.
                    # The run spans indices [i, i + num_contiguous_blocks).
                    del computed_blocks[i + num_contiguous_blocks:]
                    return computed_blocks
            else:
                # A miss breaks the run; restart the contiguous count.
                num_contiguous_blocks = 0
        # The first `num_contiguous_blocks` is a cache hit even if
        # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
        del computed_blocks[num_contiguous_blocks:]
        return computed_blocks

    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
                              num_computed_tokens: int) -> list[KVCacheBlock]:
        """Null out blocks that have slid out of the attention window.

        Returns the replaced blocks in eviction order (oldest-first is
        NOT required here; blocks closest to the window boundary are
        returned first, i.e. evicted first).
        """
        # Remove the blocks that are no longer be in the sliding window and
        # skipped during the attention computation.
        last_useful_token = num_computed_tokens - self.sliding_window + 1
        last_useful_block = last_useful_token // self.block_size
        removed_blocks: list[KVCacheBlock] = []
        # Walk backwards so the returned list is in eviction order.
        for i in range(last_useful_block - 1, -1, -1):
            if blocks[i] == self._null_block:
                # If the block is already a null block, the blocks before it
                # should also have been set to null blocks by the previous calls
                # to this function.
                break
            removed_blocks.append(blocks[i])
            blocks[i] = self._null_block
        return removed_blocks
# Registry mapping each KV cache spec type to the manager class that
# implements its cache-management policy. Extend this map when adding a
# new attention/spec type.
spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
    FullAttentionSpec: FullAttentionManager,
    SlidingWindowSpec: SlidingWindowManager,
}
def get_specialized_manager(kv_cache_spec: KVCacheSpec,
                            block_pool: BlockPool) -> SpecializedManager:
    """Build the SpecializedManager matching the spec's concrete type.

    Args:
        kv_cache_spec: Spec whose exact type selects the manager class.
        block_pool: Block pool handed to the constructed manager.

    Returns:
        A freshly constructed manager for `kv_cache_spec`.

    Raises:
        KeyError: If `type(kv_cache_spec)` has no registered manager in
            `spec_manager_map`.
    """
    cls = spec_manager_map[type(kv_cache_spec)]
    return cls(kv_cache_spec, block_pool)