Iluvatar-mrv100 SDK 4.3.0
This commit is contained in:
161
vllm/v1/core/specialized_manager.py
Normal file
161
vllm/v1/core/specialized_manager.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from vllm.utils import cdiv
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
from vllm.v1.core.kv_cache_utils import BlockHashType, KVCacheBlock
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheSpec,
|
||||
SlidingWindowSpec)
|
||||
|
||||
|
||||
class SpecializedManager(ABC):
|
||||
"""
|
||||
An abstract base class for specialized managers that handle the kv
|
||||
cache management logic of different attention layers.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
block_pool: BlockPool,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the SpecializedManager.
|
||||
Args:
|
||||
kv_cache_spec: The kv_cache_spec for this manager.
|
||||
block_pool: The block pool.
|
||||
"""
|
||||
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_pool = block_pool
|
||||
|
||||
@abstractmethod
|
||||
def find_longest_cache_hit(
|
||||
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
|
||||
"""
|
||||
Get the longest cache hit prefix of the blocks. If no cache hit is
|
||||
found, return an empty list.
|
||||
|
||||
Args:
|
||||
block_hashes: The block hashes of the request.
|
||||
Returns:
|
||||
A list of cached blocks with skipped blocks replaced by null block.
|
||||
For example, sliding window manager should return a list like
|
||||
[NULL, NULL, KVCacheBlock(7), KVCacheBlock(8)] for block size 4 and
|
||||
sliding window 8.
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
|
||||
num_computed_tokens: int) -> list[KVCacheBlock]:
|
||||
"""
|
||||
Remove the blocks that are no longer needed from `blocks`. The removed
|
||||
blocks should be replaced by null_block. Return the removed blocks in
|
||||
eviction order, where the first returned block should be evicted first.
|
||||
Don't free the removed blocks in this function.
|
||||
|
||||
Args:
|
||||
blocks: The list of blocks to be updated.
|
||||
num_computed_tokens: The number of tokens that have been computed.
|
||||
Returns:
|
||||
The removed blocks in eviction order.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FullAttentionManager(SpecializedManager):
|
||||
|
||||
def find_longest_cache_hit(
|
||||
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
|
||||
computed_blocks: list[KVCacheBlock] = []
|
||||
for block_hash in block_hashes:
|
||||
# block_hashes is a chain of block hashes. If a block hash is not
|
||||
# in the cached_block_hash_to_id, the following block hashes are
|
||||
# not computed yet for sure.
|
||||
if cached_block := self.block_pool.get_cached_block(block_hash):
|
||||
computed_blocks.append(cached_block)
|
||||
else:
|
||||
break
|
||||
return computed_blocks
|
||||
|
||||
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
|
||||
num_computed_tokens: int) -> list[KVCacheBlock]:
|
||||
# No need to remove blocks for full attention.
|
||||
return []
|
||||
|
||||
|
||||
class SlidingWindowManager(SpecializedManager):
|
||||
|
||||
def __init__(self, kv_cache_spec: SlidingWindowSpec,
|
||||
block_pool: BlockPool):
|
||||
super().__init__(kv_cache_spec, block_pool)
|
||||
self.sliding_window = kv_cache_spec.sliding_window
|
||||
# The number of contiguous blocks needed for prefix cache hit.
|
||||
# -1 since the input token itself is also included in the window
|
||||
self.sliding_window_contiguous_blocks = cdiv(
|
||||
(kv_cache_spec.sliding_window - 1), self.block_size)
|
||||
self._null_block = block_pool.null_block
|
||||
|
||||
def find_longest_cache_hit(
|
||||
self, block_hashes: list[BlockHashType]) -> list[KVCacheBlock]:
|
||||
# TODO: reduce i by sliding_window_contiguous_blocks when cache miss, to
|
||||
# optimize the time complexity from O(len(block_hashes)) to
|
||||
# O(len(block_hashes) / sliding_window_contiguous_blocks +
|
||||
# sliding_window_contiguous_blocks),
|
||||
# which is good for low cache hit rate scenarios.
|
||||
computed_blocks = [self._null_block] * len(block_hashes)
|
||||
num_contiguous_blocks = 0
|
||||
|
||||
# Search from right to left and early stop when a match is found.
|
||||
for i in range(len(block_hashes) - 1, -1, -1):
|
||||
if cached_block := self.block_pool.get_cached_block(
|
||||
block_hashes[i]):
|
||||
computed_blocks[i] = cached_block
|
||||
num_contiguous_blocks += 1
|
||||
if (num_contiguous_blocks
|
||||
>= self.sliding_window_contiguous_blocks):
|
||||
# Trim the trailing blocks.
|
||||
# E.g., [NULL, NULL, 8, 3, NULL, 9] -> [NULL, NULL, 8, 3]
|
||||
# when sliding_window_contiguous_blocks=2.
|
||||
del computed_blocks[i + num_contiguous_blocks:]
|
||||
return computed_blocks
|
||||
else:
|
||||
num_contiguous_blocks = 0
|
||||
# The first `num_contiguous_blocks` is a cache hit even if
|
||||
# `num_contiguous_blocks < sliding_window_contiguous_blocks`.
|
||||
del computed_blocks[num_contiguous_blocks:]
|
||||
return computed_blocks
|
||||
|
||||
def remove_skipped_blocks(self, blocks: list[KVCacheBlock],
|
||||
num_computed_tokens: int) -> list[KVCacheBlock]:
|
||||
# Remove the blocks that are no longer be in the sliding window and
|
||||
# skipped during the attention computation.
|
||||
last_useful_token = num_computed_tokens - self.sliding_window + 1
|
||||
last_useful_block = last_useful_token // self.block_size
|
||||
|
||||
removed_blocks: list[KVCacheBlock] = []
|
||||
for i in range(last_useful_block - 1, -1, -1):
|
||||
if blocks[i] == self._null_block:
|
||||
# If the block is already a null block, the blocks before it
|
||||
# should also have been set to null blocks by the previous calls
|
||||
# to this function.
|
||||
break
|
||||
removed_blocks.append(blocks[i])
|
||||
blocks[i] = self._null_block
|
||||
return removed_blocks
|
||||
|
||||
|
||||
spec_manager_map: dict[type[KVCacheSpec], type[SpecializedManager]] = {
|
||||
FullAttentionSpec: FullAttentionManager,
|
||||
SlidingWindowSpec: SlidingWindowManager,
|
||||
}
|
||||
|
||||
|
||||
def get_specialized_manager(kv_cache_spec: KVCacheSpec,
|
||||
block_pool: BlockPool) -> SpecializedManager:
|
||||
manager_class = spec_manager_map[type(kv_cache_spec)]
|
||||
manager = manager_class(kv_cache_spec, block_pool)
|
||||
return manager
|
||||
Reference in New Issue
Block a user