162 lines
6.5 KiB
Python
162 lines
6.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
from abc import ABC, abstractmethod
|
|
|
|
from vllm.utils import cdiv
|
|
from vllm.v1.core.block_pool import BlockPool
|
|
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
|
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
|
SlidingWindowSpec)
|
|
|
|
|
|
class SpecializedManager(ABC):
    """
    Abstract base class for managers implementing the KV-cache management
    logic of a particular attention layer type (e.g. full attention,
    sliding-window attention).
    """

    def __init__(
        self,
        kv_cache_spec: KVCacheSpec,
        block_pool: BlockPool,
    ) -> None:
        """
        Initializes the SpecializedManager.

        Args:
            kv_cache_spec: The kv_cache_spec for this manager.
            block_pool: The block pool.
        """
        # Cache the block size separately since subclasses use it directly.
        self.block_size = kv_cache_spec.block_size
        self.kv_cache_spec = kv_cache_spec
        self.block_pool = block_pool

    @abstractmethod
    def find_longest_cache_hit(
            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
        """
        Get the longest cache-hit prefix of the blocks; an empty list if
        there is no hit at all.

        Args:
            block_hashes: The block hashes of the request.

        Returns:
            A list of cached blocks, with skipped blocks replaced by the
            null block. For example, a sliding window manager should return
            [NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4
            and sliding window 8.
        """
        raise NotImplementedError

    @abstractmethod
    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
                              num_computed_tokens: int) -> list[KVCacheBlock]:
        """
        Remove the blocks that are no longer needed from `blocks`, replacing
        each removed block with the null block. The removed blocks are
        returned in eviction order (first entry should be evicted first);
        this method must NOT free them itself.

        Args:
            blocks: The list of blocks to be updated.
            num_computed_tokens: The number of tokens that have been computed.

        Returns:
            The removed blocks in eviction order.
        """
        raise NotImplementedError
|
|
|
|
|
class FullAttentionManager(SpecializedManager):
    """Manages KV-cache blocks for full-attention layers."""

    def find_longest_cache_hit(
            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
        """
        Return the longest cached prefix of `block_hashes` as blocks.

        `block_hashes` is a chain of block hashes, so the first hash missing
        from the cache ends the prefix: all later blocks are certainly not
        computed yet.
        """
        hit_blocks: list[KVCacheBlock] = []
        for block_hash in block_hashes:
            cached = self.block_pool.get_cached_block(block_hash)
            if not cached:
                break
            hit_blocks.append(cached)
        return hit_blocks

    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
                              num_computed_tokens: int) -> list[KVCacheBlock]:
        """Full attention skips no blocks, so nothing is ever removed."""
        return []
|
|
|
|
|
|
class SlidingWindowManager(SpecializedManager):
    """
    Manages KV-cache blocks for sliding-window attention layers.

    Blocks that slide out of the attention window are replaced by the
    pool's shared null block so the real blocks can be reclaimed.
    """

    def __init__(self, kv_cache_spec: SlidingWindowSpec,
                 block_pool: BlockPool):
        """
        Initializes the SlidingWindowManager.

        Args:
            kv_cache_spec: The sliding-window spec for this manager.
            block_pool: The block pool.
        """
        super().__init__(kv_cache_spec, block_pool)
        self.sliding_window = kv_cache_spec.sliding_window
        # The number of contiguous blocks needed for prefix cache hit.
        # -1 since the input token itself is also included in the window
        self.sliding_window_contiguous_blocks = cdiv(
            (kv_cache_spec.sliding_window - 1), self.block_size)
        self._null_block = block_pool.null_block

    def find_longest_cache_hit(
            self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
        """
        Get the longest usable cache hit for a sliding-window request.

        Only the last `sliding_window_contiguous_blocks` blocks before a
        position must be cached for that position to be a hit, so earlier
        misses are represented by the null block.

        Args:
            block_hashes: The block hashes of the request.

        Returns:
            Cached blocks with skipped blocks replaced by the null block,
            e.g. [NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)].
        """
        # TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
        # optimize the time complexity from O(len(block_hashes)) to
        # O(len(block_hashes) / sliding_window_contiguous_blocks +
        # sliding_window_contiguous_blocks),
        # which is good for low cache hit rate scenarios.
        computed_blocks = [self._null_block] * len(block_hashes)
        num_contiguous_blocks = 0

        # Search from right to left and early stop when a match is found.
        for i in range(len(block_hashes) - 1, -1, -1):
            if cached_block := self.block_pool.get_cached_block(
                    block_hashes[i]):
                computed_blocks[i] = cached_block
                num_contiguous_blocks += 1
                if (num_contiguous_blocks
                        >= self.sliding_window_contiguous_blocks):
                    # Trim the trailing blocks.
                    # E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
                    # when sliding_window_contiguous_blocks=2.
                    del computed_blocks[i + num_contiguous_blocks:]
                    return computed_blocks
            else:
                num_contiguous_blocks = 0
        # The first `num_contiguous_blocks` is a cache hit even if
        # `num_contiguous_blocks < sliding_window_contiguous_blocks`.
        del computed_blocks[num_contiguous_blocks:]
        return computed_blocks

    def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
                              num_computed_tokens: int) -> list[KVCacheBlock]:
        """
        Replace blocks that have slid out of the window with the null block.

        Args:
            blocks: The list of blocks to be updated in place.
            num_computed_tokens: The number of tokens that have been computed.

        Returns:
            The removed blocks in eviction order (first entry to be evicted
            first). They are NOT freed here.
        """
        # Remove the blocks that are no longer be in the sliding window and
        # skipped during the attention computation.
        last_useful_token = num_computed_tokens - self.sliding_window + 1
        last_useful_block = last_useful_token // self.block_size

        removed_blocks: list[KVCacheBlock] = []
        for i in range(last_useful_block - 1, -1, -1):
            # FIX: the null block is a single shared sentinel instance, so
            # compare by identity (`is`) instead of `==`; equality would
            # depend on KVCacheBlock's __eq__ semantics and is both slower
            # and potentially incorrect for a sentinel check.
            if blocks[i] is self._null_block:
                # If the block is already a null block, the blocks before it
                # should also have been set to null blocks by the previous
                # calls to this function.
                break
            removed_blocks.append(blocks[i])
            blocks[i] = self._null_block
        return removed_blocks
|
|
|
|
|
|
# Registry mapping each KV-cache spec type to the manager class that
# implements its cache-management policy. Looked up by exact type, so a
# spec subclass must be registered explicitly to be supported.
spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
    FullAttentionSpec: FullAttentionManager,
    SlidingWindowSpec: SlidingWindowManager,
}
|
|
|
|
|
|
def get_specialized_manager(kv_cache_spec: KVCacheSpec,
                            block_pool: BlockPool) -> SpecializedManager:
    """
    Instantiate the SpecializedManager registered for the spec's exact type.

    Args:
        kv_cache_spec: The kv_cache_spec to build a manager for.
        block_pool: The block pool handed to the manager.

    Returns:
        A newly constructed manager for `kv_cache_spec`.

    Raises:
        KeyError: If `type(kv_cache_spec)` has no registered manager.
    """
    manager_cls = spec_manager_map[type(kv_cache_spec)]
    return manager_cls(kv_cache_spec, block_pool)
|