init
This commit is contained in:
0
vllm/core/__init__.py
Normal file
0
vllm/core/__init__.py
Normal file
0
vllm/core/block/__init__.py
Normal file
0
vllm/core/block/__init__.py
Normal file
295
vllm/core/block/block_table.py
Normal file
295
vllm/core/block/block_table.py
Normal file
@@ -0,0 +1,295 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from vllm.core.block.interfaces import Block, DeviceAwareBlockAllocator
|
||||
from vllm.utils import Device, cdiv, chunk_list
|
||||
|
||||
|
||||
class BlockTable:
    """A class to manage blocks for a specific sequence.

    The BlockTable maps a sequence of tokens to a list of blocks, where each
    block represents a contiguous memory allocation for a portion of the
    sequence. The blocks are managed by a DeviceAwareBlockAllocator, which is
    responsible for allocating and freeing memory for the blocks.

    Args:
        block_size (int): The maximum number of tokens that can be stored in a
            single block.
        block_allocator (DeviceAwareBlockAllocator): The block allocator used to
            manage memory for the blocks.
        _blocks (Optional[List[Block]], optional): An optional list of existing
            blocks to initialize the BlockTable with. If not provided, an empty
            BlockTable is created.

    Attributes:
        _block_size (int): The maximum number of tokens that can be stored in a
            single block.
        _allocator (DeviceAwareBlockAllocator): The block allocator used to
            manage memory for the blocks.
        _blocks (Optional[List[Block]]): The list of blocks managed by this
            BlockTable.
        _num_full_slots (int): The number of tokens currently stored in the
            blocks.
    """

    def __init__(
        self,
        block_size: int,
        block_allocator: DeviceAwareBlockAllocator,
        _blocks: Optional[List[Block]] = None,
    ):
        self._block_size = block_size
        self._allocator = block_allocator
        if _blocks is None:
            _blocks = []
        self._blocks: List[Block] = _blocks

        # Use helper method instead of directly calculating, as blocks
        # may not be allocated.
        self._num_full_slots = len(self._get_all_token_ids())

    @staticmethod
    def get_num_required_blocks(token_ids: List[int], block_size: int) -> int:
        """Calculates the minimum number of blocks required to store a given
        sequence of token IDs.

        This assumes worst-case scenario, where every block requires a new
        allocation (e.g. ignoring prefix caching).

        Args:
            token_ids (List[int]): The sequence of token IDs to be stored.
            block_size (int): The maximum number of tokens that can be stored
                in a single block.

        Returns:
            int: The minimum number of blocks required to store the given
                sequence of token IDs.
        """
        return cdiv(len(token_ids), block_size)

    def allocate(self,
                 token_ids: List[int],
                 device: Device = Device.GPU) -> None:
        """Allocates memory blocks for storing the given sequence of token IDs.

        This method allocates the required number of blocks to store the given
        sequence of token IDs.

        Args:
            token_ids (List[int]): The sequence of token IDs to be stored.
            device (Device, optional): The device on which the blocks should be
                allocated. Defaults to Device.GPU.
        """
        assert not self._is_allocated
        assert token_ids
        self._blocks = self._allocate_blocks_for_token_ids(prev_block=None,
                                                           token_ids=token_ids,
                                                           device=device)
        self._num_full_slots = len(token_ids)

    def append_token_ids(self,
                         token_ids: List[int],
                         num_lookahead_slots: int = 0) -> None:
        """Appends a sequence of token IDs to the existing blocks in the
        BlockTable.

        This method appends the given sequence of token IDs to the existing
        blocks in the BlockTable. If there is not enough space in the existing
        blocks, new blocks are allocated using the `ensure_num_empty_slots`
        method to accommodate the additional tokens.

        The token IDs are divided into chunks of size `block_size` (except for
        the first chunk, which may be smaller), and each chunk is appended to a
        separate block.

        Args:
            token_ids (List[int]): The sequence of token IDs to be appended.
            num_lookahead_slots (int, optional): Number of additional empty
                slots to reserve beyond the appended tokens (e.g. for
                speculative decoding). Defaults to 0.
        """
        assert self._is_allocated
        assert len(self._blocks) > 0

        self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
                                    num_lookahead_slots)

        blocks = self._blocks[self._num_full_slots // self._block_size:]
        token_blocks = self._chunk_token_blocks_for_append(token_ids)

        for block, token_block in zip(blocks, token_blocks):
            block.append_token_ids(token_block)

        self._num_full_slots += len(token_ids)

    def ensure_num_empty_slots(self, num_empty_slots: int) -> None:
        """Ensures that the BlockTable has at least the specified number of
        empty slots available.

        This method checks if the BlockTable has enough empty slots (i.e.,
        available space) to accommodate the requested number of tokens. If not,
        it allocates additional blocks on the GPU to ensure that the required
        number of empty slots is available.

        Args:
            num_empty_slots (int): The minimum number of empty slots required.
        """
        # Currently the block table only supports
        # appending tokens to GPU blocks.
        device = Device.GPU
        assert self._is_allocated

        if self._num_empty_slots >= num_empty_slots:
            return

        slots_to_allocate = num_empty_slots - self._num_empty_slots
        blocks_to_allocate = cdiv(slots_to_allocate, self._block_size)

        for _ in range(blocks_to_allocate):
            assert len(self._blocks) > 0
            self._blocks.append(
                self._allocator.allocate_mutable(prev_block=self._blocks[-1],
                                                 device=device))

    def fork(self) -> "BlockTable":
        """Creates a new BlockTable instance with a copy of the blocks from the
        current instance.

        This method creates a new BlockTable instance with the same block size,
        block allocator, and a copy of the blocks from the current instance.
        The new BlockTable has its own independent set of blocks, but shares
        the same underlying memory allocation with the original BlockTable.

        Returns:
            BlockTable: A new BlockTable instance with a copy of the blocks
                from the current instance.
        """
        assert self._is_allocated
        assert len(self._blocks) > 0
        forked_blocks = self._allocator.fork(self._blocks[-1])
        return BlockTable(
            block_size=self._block_size,
            block_allocator=self._allocator,
            _blocks=forked_blocks,
        )

    def free(self) -> None:
        """Frees the memory occupied by the blocks in the BlockTable.

        This method iterates over all the blocks in the `_blocks` list and
        calls the `free` method of the `_allocator` object to release the
        memory occupied by each block. After freeing all the blocks, the
        `_blocks` list is reset to an empty list.
        """
        assert self._is_allocated
        for block in self._blocks:
            self._allocator.free(block)
        self._blocks = []

    @property
    def physical_block_ids(self) -> List[Optional[int]]:
        """Returns a list of physical block indices for the blocks in the
        BlockTable.

        This property returns a list of integers, where each integer represents
        the physical block index of a corresponding block in the `_blocks`
        list. The physical block index is a unique identifier for the memory
        location occupied by the block.

        Returns:
            List[int]: A list of physical block indices for the blocks in the
                BlockTable.
        """
        assert self._is_allocated
        return [block.block_id for block in self._blocks]

    def get_unseen_token_ids(self, sequence_token_ids: List[int]) -> List[int]:
        """Get the number of "unseen" tokens in the sequence.

        Unseen tokens are tokens in the sequence corresponding to this block
        table, but are not yet appended to this block table.

        Args:
            sequence_token_ids (List[int]): The list of token ids in the
                sequence.

        Returns:
            List[int]: The postfix of sequence_token_ids that has not yet been
                appended to the block table.
        """
        # Since the block table is append-only, the unseen token ids are the
        # ones after the appended ones.
        return sequence_token_ids[self.num_full_slots:]

    def _allocate_blocks_for_token_ids(self, prev_block: Optional[Block],
                                       token_ids: List[int],
                                       device: Device) -> List[Block]:
        # Chunk the tokens block by block: full chunks become immutable
        # blocks, while a trailing partial chunk becomes a mutable block
        # that is then partially filled.
        blocks = []
        for block_token_ids in chunk_list(token_ids, self._block_size):
            if len(block_token_ids) == self._block_size:
                # If the block is full, create an immutable block.
                prev_block = self._allocator.allocate_immutable(
                    prev_block, token_ids=block_token_ids, device=device)
            else:
                # Else, partially fill a mutable block with token ids.
                prev_block = self._allocator.allocate_mutable(
                    prev_block=prev_block, device=device)
                prev_block.append_token_ids(block_token_ids)
            blocks.append(prev_block)

        return blocks

    def _get_all_token_ids(self) -> List[int]:
        # NOTE: This function is O(seq_len); use sparingly.
        token_ids: List[int] = []

        if not self._is_allocated:
            return token_ids

        for block in self._blocks:
            token_ids.extend(block.token_ids)

        return token_ids

    @property
    def _is_allocated(self) -> bool:
        return len(self._blocks) > 0

    @property
    def _num_empty_slots(self) -> int:
        assert self._is_allocated
        return len(self._blocks) * self._block_size - self._num_full_slots

    @property
    def num_full_slots(self) -> int:
        """Returns the total number of tokens currently stored in the
        BlockTable.

        Returns:
            int: The total number of tokens currently stored in the BlockTable.
        """
        return self._num_full_slots

    def get_num_blocks_touched_by_append_slots(
            self, token_ids: List[int], num_lookahead_slots: int) -> int:
        """Determine how many blocks will be "touched" by appending the token
        ids.

        This is required for the scheduler to determine whether a sequence can
        continue generation, or if it must be preempted.
        """
        all_token_ids = token_ids + [-1] * num_lookahead_slots
        token_blocks = self._chunk_token_blocks_for_append(all_token_ids)
        return len(token_blocks)

    def _chunk_token_blocks_for_append(
            self, token_ids: List[int]) -> List[List[int]]:
        """Split the token ids into block-sized chunks so they can be easily
        appended to blocks. The first such "token block" may have less token
        ids than the block size, since the last allocated block may be
        partially full.
        """
        # No tokens -> no chunks. Without this guard the slice below would
        # produce a single empty chunk, making
        # get_num_blocks_touched_by_append_slots overcount an empty append
        # as touching one block.
        if not token_ids:
            return []
        first_chunk_size = self._block_size - (self._num_full_slots %
                                               self._block_size)
        token_blocks = [token_ids[:first_chunk_size]] + chunk_list(
            token_ids[first_chunk_size:], self._block_size)
        return token_blocks
|
||||
199
vllm/core/block/common.py
Normal file
199
vllm/core/block/common.py
Normal file
@@ -0,0 +1,199 @@
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Iterable, List, Optional, Protocol
|
||||
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator
|
||||
|
||||
BlockId = int
|
||||
RefCount = int
|
||||
|
||||
|
||||
class RefCounterProtocol(Protocol):
    """Structural interface for reference-counter implementations.

    Both the mutable RefCounter and its read-only view satisfy this
    protocol; code that only needs lookups can accept either.
    """

    def incr(self, block_id: BlockId) -> RefCount:
        """Increment the refcount of `block_id`; return the new count."""
        raise NotImplementedError

    def decr(self, block_id: BlockId) -> RefCount:
        """Decrement the refcount of `block_id`; return the new count."""
        raise NotImplementedError

    def get(self, block_id: BlockId) -> RefCount:
        """Return the current refcount of `block_id` without modifying it."""
        raise NotImplementedError
|
||||
|
||||
|
||||
class RefCounter(RefCounterProtocol):
    """A class for managing reference counts for a set of block indices.

    The RefCounter class maintains a dictionary that maps block indices to
    their corresponding reference counts. It provides methods to increment,
    decrement, and retrieve the reference count for a given block index.

    Args:
        all_block_indices (Iterable[BlockId]): An iterable of block indices
            to initialize the reference counter with.
    """

    def __init__(self, all_block_indices: Iterable[BlockId]):
        # Deduplicate up front; every known block starts with a count of 0.
        self._refcounts: Dict[BlockId, RefCount] = dict.fromkeys(
            set(all_block_indices), 0)

    def incr(self, block_id: BlockId) -> RefCount:
        """Increment the refcount of `block_id` and return the new value."""
        assert block_id in self._refcounts
        current = self._refcounts[block_id]

        # A negative count would indicate corrupted bookkeeping.
        assert current >= 0

        updated = current + 1
        self._refcounts[block_id] = updated
        return updated

    def decr(self, block_id: BlockId) -> RefCount:
        """Decrement the refcount of `block_id` and return the new value."""
        assert block_id in self._refcounts
        updated = self._refcounts[block_id]

        # Cannot release a block that holds no references.
        assert updated > 0
        updated -= 1

        self._refcounts[block_id] = updated

        return updated

    def get(self, block_id: BlockId) -> RefCount:
        """Return the current refcount of `block_id`."""
        assert block_id in self._refcounts
        return self._refcounts[block_id]

    def as_readonly(self) -> "ReadOnlyRefCounter":
        """Return an immutable view over this counter's state."""
        return ReadOnlyRefCounter(self)
|
||||
|
||||
|
||||
class ReadOnlyRefCounter(RefCounterProtocol):
    """A read-only view of the RefCounter class.

    Lookups are delegated to the wrapped RefCounter; any attempt to mutate
    the counts through this view raises ValueError.

    Args:
        refcounter (RefCounter): The RefCounter instance to create a
            read-only view for.
    """

    def __init__(self, refcounter: RefCounter):
        # Keep a reference to the live counter; reads always reflect its
        # current state.
        self._refcounter = refcounter

    def incr(self, block_id: BlockId) -> RefCount:
        """Always raises: this view forbids mutation."""
        raise ValueError("Incr not allowed")

    def decr(self, block_id: BlockId) -> RefCount:
        """Always raises: this view forbids mutation."""
        raise ValueError("Decr not allowed")

    def get(self, block_id: BlockId) -> RefCount:
        """Return the current refcount from the underlying counter."""
        return self._refcounter.get(block_id)
|
||||
|
||||
|
||||
class CopyOnWriteTracker:
    """A class for tracking and managing copy-on-write operations for blocks.

    The CopyOnWriteTracker class maintains a mapping of source block indices
    to their corresponding copy-on-write destination block indices. It works
    in conjunction with a RefCounter and a BlockAllocator to handle reference
    counting and block allocation.

    Args:
        refcounter (RefCounter): The reference counter used to track block
            reference counts.
        allocator (BlockAllocator): The block allocator used to allocate and
            free blocks.
    """

    def __init__(
        self,
        refcounter: RefCounterProtocol,
        allocator: BlockAllocator,
    ):
        # src block id -> list of dst block ids produced by CoW since the
        # last call to clear_cows().
        self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
        self._refcounter = refcounter
        self._allocator = allocator

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """Performs a copy-on-write operation on the given block if it is not
        appendable.

        This method checks the reference count of the given block. If the
        reference count is greater than 1, indicating that the block is
        shared, a copy-on-write operation is performed. The original block is
        freed, and a new block is allocated with the same content. The new
        block index is returned.

        Args:
            block (Block): The block to check for copy-on-write.

        Returns:
            Optional[BlockId]: The block index of the new block if a copy-on
                -write operation was performed, or the original block index if
                no copy-on-write was necessary.
        """
        src_block_id = block.block_id
        if src_block_id is None:
            return None

        refcount = self._refcounter.get(src_block_id)
        assert refcount != 0
        if refcount == 1:
            # Unshared block: the caller may append in place.
            return src_block_id

        # Shared block: drop our reference to the old block...
        self._allocator.free(block)

        # ...and allocate a fresh block in its place.
        dst_block_id = self._allocator.allocate_mutable(
            prev_block=block.prev_block).block_id

        # Record the src/dst pair so the engine can copy the contents.
        assert src_block_id is not None
        assert dst_block_id is not None
        self._copy_on_writes[src_block_id].append(dst_block_id)

        return dst_block_id

    def clear_cows(self) -> Dict[BlockId, List[BlockId]]:
        """Clears the copy-on-write tracking information and returns the
        current state.

        This method returns a dictionary mapping source block indices to
        lists of destination block indices for the current copy-on-write
        operations. It then clears the internal tracking information.

        Returns:
            Dict[BlockId, List[BlockId]]: A dictionary mapping source
                block indices to lists of destination block indices for the
                current copy-on-write operations.
        """
        snapshot = dict(self._copy_on_writes)
        self._copy_on_writes.clear()
        return snapshot
|
||||
|
||||
|
||||
def get_all_blocks_recursively(last_block: Block) -> List[Block]:
    """Retrieves all the blocks in a sequence starting from the last block.

    Walks the prev_block chain backwards from the given last block and
    returns the full sequence of blocks in forward order.

    Args:
        last_block (Block): The last block in the sequence.

    Returns:
        List[Block]: A list of all the blocks in the sequence, in the order
            they appear.
    """
    # Collect the chain back-to-front, then reverse once; equivalent to the
    # recursive traversal but without consuming stack proportional to the
    # sequence length.
    chain: List[Block] = []
    node: Optional[Block] = last_block
    while node is not None:
        chain.append(node)
        node = node.prev_block
    chain.reverse()
    return chain
|
||||
228
vllm/core/block/cpu_gpu_block_allocator.py
Normal file
228
vllm/core/block/cpu_gpu_block_allocator.py
Normal file
@@ -0,0 +1,228 @@
|
||||
from typing import Dict, FrozenSet, List, Optional
|
||||
|
||||
from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
|
||||
DeviceAwareBlockAllocator)
|
||||
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
|
||||
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
|
||||
from vllm.utils import Device
|
||||
|
||||
|
||||
class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
    """A block allocator that can allocate blocks on both CPU and GPU memory.

    This class implements the `DeviceAwareBlockAllocator` interface and
    provides functionality for allocating and managing blocks of memory on
    both CPU and GPU devices.

    The `CpuGpuBlockAllocator` maintains separate memory pools for CPU and
    GPU blocks, and allows for allocation, deallocation, forking, and
    swapping of blocks across these memory pools.
    """

    @staticmethod
    def create(
        allocator_type: str,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        block_size: int,
    ) -> DeviceAwareBlockAllocator:
        """Creates a CpuGpuBlockAllocator instance with the specified
        configuration.

        Args:
            allocator_type (str): The type of block allocator to use for CPU
                and GPU blocks. Currently supported values are "naive" and
                "prefix_caching".
            num_gpu_blocks (int): The number of blocks to allocate for GPU
                memory.
            num_cpu_blocks (int): The number of blocks to allocate for CPU
                memory.
            block_size (int): The size of each block in number of tokens.

        Returns:
            DeviceAwareBlockAllocator: A CpuGpuBlockAllocator instance with
                the specified configuration.

        Raises:
            ValueError: If `allocator_type` is not a supported value.

        Notes:
            - The block IDs are assigned contiguously, with GPU block IDs
              coming before CPU block IDs.
        """
        # One contiguous id space: [0, num_gpu_blocks) for GPU, then
        # [num_gpu_blocks, num_gpu_blocks + num_cpu_blocks) for CPU.
        gpu_block_ids = list(range(num_gpu_blocks))
        cpu_block_ids = list(
            range(num_gpu_blocks, num_gpu_blocks + num_cpu_blocks))

        if allocator_type == "naive":
            gpu_allocator: BlockAllocator = NaiveBlockAllocator(
                create_block=NaiveBlock,  # type: ignore
                num_blocks=num_gpu_blocks,
                block_size=block_size,
                block_ids=gpu_block_ids,
            )
            cpu_allocator: BlockAllocator = NaiveBlockAllocator(
                create_block=NaiveBlock,  # type: ignore
                num_blocks=num_cpu_blocks,
                block_size=block_size,
                block_ids=cpu_block_ids,
            )
        elif allocator_type == "prefix_caching":
            gpu_allocator = PrefixCachingBlockAllocator(
                num_blocks=num_gpu_blocks,
                block_size=block_size,
                block_ids=gpu_block_ids,
            )
            cpu_allocator = PrefixCachingBlockAllocator(
                num_blocks=num_cpu_blocks,
                block_size=block_size,
                block_ids=cpu_block_ids,
            )
        else:
            raise ValueError(f"Unknown allocator type {allocator_type=}")

        return CpuGpuBlockAllocator(
            cpu_block_allocator=cpu_allocator,
            gpu_block_allocator=gpu_allocator,
        )

    def __init__(
        self,
        cpu_block_allocator: BlockAllocator,
        gpu_block_allocator: BlockAllocator,
    ):
        assert not (
            cpu_block_allocator.all_block_ids
            & gpu_block_allocator.all_block_ids
        ), "cpu and gpu block allocators can't have intersection of block ids"

        self._allocators = {
            Device.CPU: cpu_block_allocator,
            Device.GPU: gpu_block_allocator,
        }

        # Reverse map so free()/fork() can find the owning allocator from a
        # block id alone.
        self._block_ids_to_allocator: Dict[int, BlockAllocator] = {
            block_id: allocator
            for allocator in self._allocators.values()
            for block_id in allocator.all_block_ids
        }

    def allocate_mutable(self, prev_block: Optional[Block],
                         device: Device) -> Block:
        """Allocates a new mutable block on the specified device.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
                Used for prefix hashing.
            device (Device): The device on which to allocate the new block.

        Returns:
            Block: The newly allocated mutable block.
        """
        return self._allocators[device].allocate_mutable(prev_block)

    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int], device: Device) -> Block:
        """Allocates a new immutable block with the provided token IDs on the
        specified device.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
                Used for prefix hashing.
            token_ids (List[int]): The list of token IDs to be stored in the
                new block.
            device (Device): The device on which to allocate the new block.

        Returns:
            Block: The newly allocated immutable block containing the
                provided token IDs.
        """
        return self._allocators[device].allocate_immutable(
            prev_block, token_ids)

    def free(self, block: Block) -> None:
        """Frees the memory occupied by the given block.

        Args:
            block (Block): The block to be freed.
        """
        released_id = block.block_id
        assert released_id is not None
        owner = self._block_ids_to_allocator[released_id]
        return owner.free(block)

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
        memory as the original sequence.

        Args:
            last_block (Block): The last block in the original sequence.

        Returns:
            List[Block]: A new list of blocks that shares the same memory as
                the original sequence.
        """
        forked_id = last_block.block_id
        assert forked_id is not None
        owner = self._block_ids_to_allocator[forked_id]
        return owner.fork(last_block)

    def get_num_free_blocks(self, device: Device) -> int:
        """Returns the number of free blocks available on the specified
        device.

        Args:
            device (Device): The device for which to query the number of free
                blocks. AssertionError is raised if None is passed.

        Returns:
            int: The number of free blocks available on the specified device.
        """
        return self._allocators[device].get_num_free_blocks()

    def get_num_total_blocks(self, device: Device) -> int:
        """Returns the total number of blocks managed for the given device."""
        return self._allocators[device].get_num_total_blocks()

    def clear_copy_on_writes(self) -> Dict[int, List[int]]:
        """Clears the copy-on-write (CoW) state and returns the mapping of
        source to destination block IDs.

        Returns:
            Dict[int, List[int]]: A dictionary mapping source block IDs to
                lists of destination block IDs.
        """
        # CoW only supported on GPU
        device = Device.GPU
        return self._allocators[device].clear_copy_on_writes()

    def mark_blocks_as_accessed(self, block_ids: List[int],
                                now: float) -> None:
        """Mark blocks as accessed, only use for prefix caching."""
        # Prefix caching only supported on GPU.
        device = Device.GPU
        return self._allocators[device].mark_blocks_as_accessed(block_ids, now)

    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
        """Mark blocks as computed, only use for prefix caching."""
        # Prefix caching only supported on GPU.
        device = Device.GPU
        return self._allocators[device].mark_blocks_as_computed(block_ids)

    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Delegate to the GPU allocator (prefix caching is GPU-only)."""
        # Prefix caching only supported on GPU.
        device = Device.GPU
        return self._allocators[device].get_common_computed_block_ids(
            seq_block_ids)

    @property
    def all_block_ids(self) -> FrozenSet[int]:
        """All block ids managed across both devices."""
        return frozenset(self._block_ids_to_allocator.keys())

    def promote_to_immutable_block(self, block: Block) -> BlockId:
        """Not supported at the device-aware level."""
        raise NotImplementedError

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """Not supported at the device-aware level."""
        raise NotImplementedError
|
||||
205
vllm/core/block/interfaces.py
Normal file
205
vllm/core/block/interfaces.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, FrozenSet, List, Optional, Protocol
|
||||
|
||||
from vllm.utils import Device
|
||||
|
||||
BlockId = int
|
||||
|
||||
|
||||
class Block(ABC):
    """Abstract interface for a block of token storage.

    A block holds up to a fixed number of token ids and is linked to the
    previous block in its sequence via `prev_block` (used for prefix
    hashing by implementations that support it).
    """

    @abstractmethod
    def append_token_ids(self, token_ids: List[int]) -> None:
        """Append token ids into this block's free slots."""
        pass

    @property
    @abstractmethod
    def block_id(self) -> Optional[int]:
        """Physical block id, or None if the block is not backed by memory."""
        pass

    @block_id.setter
    @abstractmethod
    def block_id(self, value: Optional[int]) -> None:
        """NOTE: Do not use this API outside Block."""
        self._block_id = value

    @property
    @abstractmethod
    def token_ids(self) -> List[int]:
        """The token ids currently stored in this block."""
        pass

    @property
    @abstractmethod
    def num_empty_slots(self) -> int:
        """Number of token slots still free in this block."""
        pass

    @property
    @abstractmethod
    def is_full(self) -> bool:
        """Whether the block has no empty slots remaining."""
        pass

    @property
    @abstractmethod
    def prev_block(self) -> Optional["Block"]:
        """The preceding block in the sequence, or None for the first block."""
        pass

    @property
    @abstractmethod
    def computed(self) -> bool:
        """Whether this block's contents have been computed (prefix caching)."""
        raise NotImplementedError

    @computed.setter
    @abstractmethod
    def computed(self, value) -> bool:
        """Should be only used by PrefixCacingAllocator"""
        raise NotImplementedError

    @property
    @abstractmethod
    def last_accessed(self) -> float:
        """Timestamp of the last access to this block (prefix caching)."""
        raise NotImplementedError

    @last_accessed.setter
    @abstractmethod
    def last_accessed(self, last_accessed_ts: float):
        # Setter counterpart; used by allocators to drive eviction policy.
        raise NotImplementedError

    class Factory(Protocol):
        """Callable protocol for constructing Block instances.

        Lets an allocator be parameterized over the concrete Block type it
        creates (e.g. naive vs. prefix-caching blocks).
        """

        @abstractmethod
        def __call__(
            self,
            prev_block: Optional["Block"],
            token_ids: List[int],
            block_size: int,
            allocator: "BlockAllocator",
            block_id: Optional[int] = None,
        ) -> "Block":
            pass

    @property
    @abstractmethod
    def content_hash(self) -> Optional[int]:
        """Return the content-based hash of the current block, or None if it is
        not yet defined or not supported.

        For the content-based hash to be defined, the current block must be
        full.
        """
        return None
|
||||
|
||||
|
||||
class BlockAllocator(ABC):
    """Abstract interface for a single-device block allocator.

    Implementations manage a fixed pool of physical block ids and hand out
    Block objects backed by them.
    """

    @abstractmethod
    def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
        """Allocate a writable block chained after `prev_block`."""
        pass

    @abstractmethod
    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int]) -> Block:
        """Allocate a full, read-only block holding `token_ids`."""
        pass

    @abstractmethod
    def free(self, block: Block) -> None:
        """Release a block back to the pool."""
        pass

    @abstractmethod
    def fork(self, last_block: Block) -> List[Block]:
        """Create a new block chain sharing memory with `last_block`'s chain."""
        pass

    @abstractmethod
    def get_num_total_blocks(self) -> int:
        """Total number of blocks managed by this allocator."""
        pass

    @abstractmethod
    def get_num_free_blocks(self) -> int:
        """Number of blocks currently available for allocation."""
        pass

    @property
    @abstractmethod
    def all_block_ids(self) -> FrozenSet[int]:
        """The set of physical block ids this allocator manages."""
        pass

    @abstractmethod
    def clear_copy_on_writes(self) -> Dict[int, List[int]]:
        """Return and reset the pending src -> dst copy-on-write mapping."""
        pass

    @abstractmethod
    def mark_blocks_as_accessed(self, block_ids: List[int],
                                now: float) -> None:
        """Record an access time for the given blocks (prefix caching)."""
        pass

    @abstractmethod
    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
        """Flag the given blocks as computed (prefix caching)."""
        pass

    @abstractmethod
    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Return computed block ids shared by all given sequences."""
        pass

    @abstractmethod
    def cow_block_if_not_appendable(self, block: Block) -> Optional["BlockId"]:
        """NOTE: This should not be used besides Block"""
        pass

    @abstractmethod
    def promote_to_immutable_block(self, block: Block) -> BlockId:
        """NOTE: This should not be used besides Block"""
        pass

    class NoFreeBlocksError(ValueError):
        """Raised when an allocation is requested but the pool is exhausted."""
        pass
|
||||
|
||||
|
||||
class DeviceAwareBlockAllocator(ABC):
    """Abstract interface for an allocator spanning multiple devices.

    Mirrors BlockAllocator, but allocation and capacity queries take an
    explicit `device` argument selecting the target memory pool.
    """

    @abstractmethod
    def allocate_mutable(self, prev_block: Optional[Block],
                         device: Device) -> Block:
        """Allocate a writable block on `device`, chained after `prev_block`."""
        pass

    @abstractmethod
    def allocate_immutable(self, prev_block: Optional[Block],
                           token_ids: List[int], device: Device) -> Block:
        """Allocate a read-only block holding `token_ids` on `device`."""
        pass

    @abstractmethod
    def get_num_free_blocks(self, device: Device) -> int:
        """Number of blocks currently free on `device`."""
        pass

    @abstractmethod
    def get_num_total_blocks(self, device: Device) -> int:
        """Total number of blocks managed for `device`."""
        pass

    @abstractmethod
    def free(self, block: Block) -> None:
        """Release a block back to whichever device pool owns it."""
        pass

    @abstractmethod
    def fork(self, last_block: Block) -> List[Block]:
        """Create a new block chain sharing memory with `last_block`'s chain."""
        pass

    @property
    @abstractmethod
    def all_block_ids(self) -> FrozenSet[int]:
        """The union of block ids managed across all devices."""
        pass

    @abstractmethod
    def clear_copy_on_writes(self) -> Dict[int, List[int]]:
        """Return and reset the pending src -> dst copy-on-write mapping."""
        pass

    @abstractmethod
    def mark_blocks_as_accessed(self, block_ids: List[int],
                                now: float) -> None:
        """Record an access time for the given blocks (prefix caching)."""
        pass

    @abstractmethod
    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
        """Flag the given blocks as computed (prefix caching)."""
        pass

    @abstractmethod
    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Return computed block ids shared by all given sequences."""
        pass
|
||||
318
vllm/core/block/naive_block.py
Normal file
318
vllm/core/block/naive_block.py
Normal file
@@ -0,0 +1,318 @@
|
||||
from typing import Dict, FrozenSet, Iterable, List, Optional, Set
|
||||
|
||||
from vllm.core.block.common import (CopyOnWriteTracker, RefCounter,
|
||||
get_all_blocks_recursively)
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
|
||||
|
||||
Refcount = int
|
||||
|
||||
|
||||
class NaiveBlockAllocator(BlockAllocator):
    """A simple block allocator that manages blocks of memory without prefix
    caching.

    Args:
        create_block (Block.Factory): A factory function for creating new
            blocks. This is used when a NaiveBlockAllocator is composed within
            a prefix caching allocator -- the naive block allocator must
            construct prefix caching blocks (but shouldn't know anything else
            about them).
        num_blocks (int): The total number of blocks to manage.
        block_size (int): The size of each block in tokens.
        block_ids (Optional[Iterable[int]], optional): An optional iterable of
            block IDs. If not provided, block IDs will be assigned sequentially
            from 0 to num_blocks - 1.
    """

    def __init__(
        self,
        create_block: Block.Factory,
        num_blocks: int,
        block_size: int,
        block_ids: Optional[Iterable[int]] = None,
    ):
        if block_ids is None:
            block_ids = range(num_blocks)

        # Ids currently available for allocation.
        self._free_block_indices: Set[BlockId] = set(block_ids)
        # Immutable record of every id this allocator owns.
        self._all_block_indices = frozenset(block_ids)
        assert len(self._all_block_indices) == num_blocks

        self._refcounter = RefCounter(
            all_block_indices=self._free_block_indices)
        self._create_block = create_block
        self._block_size = block_size

        # Tracks which blocks must be copied before an in-place append.
        self._cow_tracker = CopyOnWriteTracker(
            refcounter=self._refcounter.as_readonly(),
            allocator=self,
        )

    def allocate_immutable(self,
                           prev_block: Optional[Block],
                           token_ids: List[int],
                           device: Optional[Device] = None) -> Block:
        """Allocates a new immutable block with the given token IDs, linked to
        the previous block.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence. If
                None, then the block to be allocated is the first block in the
                sequence.
            token_ids (List[int]): The token IDs to be stored in the new block.

        Returns:
            Block: The newly allocated immutable block.
        """
        # ``device`` exists only for interface compatibility; the naive
        # allocator is single-device.
        assert device is None
        block = self.allocate_mutable(prev_block=prev_block)
        block.append_token_ids(token_ids)
        return block

    def allocate_mutable(self,
                         prev_block: Optional[Block],
                         device: Optional[Device] = None) -> Block:
        """Allocates a new mutable block, linked to the previous block.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence. If
                None, then the block to be allocated is the first block in the
                sequence.

        Returns:
            Block: The newly allocated mutable block.
        """
        assert device is None
        block_id = self._allocate_new_block_id()
        return self._create_block(
            prev_block=prev_block,
            token_ids=[],
            block_id=block_id,
            block_size=self._block_size,
            allocator=self,
        )

    def free(self, block: Block) -> None:
        """Release ``block``'s physical id (refcount permitting) and detach
        the block from it.
        """
        assert block.block_id is not None
        self._free_block_id(block.block_id)

        # Mark the block as having no allocation.
        block.block_id = None

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
        memory as the original sequence.

        Args:
            last_block (Block): The last block in the original sequence.

        Returns:
            List[Block]: The new sequence of blocks that shares the same memory
                as the original sequence.
        """
        source_blocks = get_all_blocks_recursively(last_block)

        forked_blocks = []
        prev_block = None
        for block in source_blocks:

            # Increment refcount for each block.
            assert block.block_id is not None
            refcount = self._refcounter.incr(block.block_id)
            assert refcount != 1, "can't fork free'd block"

            # The forked block reuses the same physical id; writes will be
            # resolved via copy-on-write.
            forked_blocks.append(
                self._create_block(
                    prev_block=prev_block,
                    token_ids=block.token_ids,
                    block_id=block.block_id,
                    block_size=self._block_size,
                    allocator=self,
                ))
            prev_block = forked_blocks[-1]

        return forked_blocks

    def get_num_free_blocks(self) -> int:
        """Return the number of blocks currently available for allocation."""
        return len(self._free_block_indices)

    def get_num_total_blocks(self) -> int:
        """Return the total number of blocks managed by this allocator."""
        return len(self._all_block_indices)

    def _allocate_new_block_id(self) -> BlockId:
        """Take an arbitrary id from the free set, incrementing its refcount.

        Raises:
            BlockAllocator.NoFreeBlocksError: If no free block is available.
        """
        if not self._free_block_indices:
            raise BlockAllocator.NoFreeBlocksError()

        # Arbitrary member of the free set (set iteration order).
        block_id = next(iter(self._free_block_indices))
        self._refcounter.incr(block_id)
        self._free_block_indices.remove(block_id)
        return block_id

    def _free_block_id(self, block_id: BlockId) -> None:
        """Decrement ``block_id``'s refcount, returning it to the free set
        once no references remain.
        """
        refcount = self._refcounter.decr(block_id)
        if refcount == 0:
            self._free_block_indices.add(block_id)

    @property
    def refcounter(self):
        # Shared with a composing PrefixCachingBlockAllocator, if any.
        return self._refcounter

    @property
    def all_block_ids(self) -> FrozenSet[int]:
        """All physical block ids owned by this allocator."""
        return self._all_block_indices

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """Performs a copy-on-write operation on the given block if it is not
        appendable.

        Args:
            block (Block): The block to check for copy-on-write.

        Returns:
            Optional[BlockId]: The block index of the new block if a copy-on
                -write operation was performed, or the original block index if
                no copy-on-write was necessary.
        """
        return self._cow_tracker.cow_block_if_not_appendable(block)

    def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
        """Returns the copy-on-write source->destination mapping and clears it.

        Returns:
            Dict[BlockId, List[BlockId]]: A dictionary mapping source
                block indices to lists of destination block indices.
        """
        return self._cow_tracker.clear_cows()

    def mark_blocks_as_accessed(self, block_ids: List[int],
                                now: float) -> None:
        """Mark blocks as accessed, used in prefix caching.

        Since the naive allocator does not implement prefix caching, we do
        nothing.
        """
        pass

    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
        """Mark blocks as computed, used in prefix caching.

        Since the naive allocator does not implement prefix caching, we do
        nothing.
        """
        pass

    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Determine blocks that can be skipped in prefill.

        Since the naive allocator does not support prefix caching, always return
        an empty list.
        """
        return []

    def promote_to_immutable_block(self, block: Block) -> BlockId:
        # Promotion only makes sense with prefix caching.
        raise NotImplementedError
|
||||
|
||||
|
||||
class NaiveBlock(Block):
    """A Block implementation without prefix-caching support.

    A NaiveBlock holds up to ``block_size`` token IDs and links back to the
    preceding block of its sequence via ``prev_block``. Appending token IDs
    may trigger a copy-on-write through the owning allocator. The prefix
    caching accessors (``computed``, ``last_accessed``) are unsupported and
    raise.

    Args:
        prev_block (Block): The previous block in the sequence.
        token_ids (List[int]): The initial token IDs to be stored in the block.
        block_size (int): The maximum number of token IDs that can be stored in
            the block.
        allocator (BlockAllocator): The block allocator associated with this
            block.
        block_id (Optional[int], optional): The physical block index
            of this block. Defaults to None, which means no allocation has been
            made.
        _cow_target (Optional[Block], optional): The copy-on-write target block.
            If not provided, it defaults to self.
    """

    def __init__(self,
                 prev_block: Optional[Block],
                 token_ids: List[int],
                 block_size: int,
                 allocator: BlockAllocator,
                 block_id: Optional[int] = None,
                 _cow_target: Optional[Block] = None):
        self._block_size = block_size
        self._prev_block = prev_block
        self._allocator = allocator
        self._block_id = block_id
        self._token_ids: List[int] = []
        # CoW is reported against the outermost wrapper (e.g. a
        # PrefixCachingBlock embedding this NaiveBlock), not always self.
        self._cow_target = self if _cow_target is None else _cow_target

        self._append_token_ids_no_cow(token_ids)

    def append_token_ids(self, token_ids: List[int]) -> None:
        """Append ``token_ids``, instructing the allocator to copy-on-write
        this block first if it cannot be appended to in place.

        Args:
            token_ids (List[int]): The token IDs to be appended to the block.
        """
        self._append_token_ids_no_cow(token_ids)

        if self._block_id is None:
            # Not physically allocated yet; no copy-on-write bookkeeping.
            return
        self._block_id = self._allocator.cow_block_if_not_appendable(
            self._cow_target)

    def _append_token_ids_no_cow(self, token_ids: List[int]) -> None:
        """Append without any copy-on-write bookkeeping."""
        assert len(token_ids) <= self.num_empty_slots
        self._token_ids.extend(token_ids)

    @property
    def block_id(self) -> Optional[int]:
        return self._block_id

    @block_id.setter
    def block_id(self, value: Optional[int]) -> None:
        self._block_id = value

    @property
    def is_full(self) -> bool:
        return len(self._token_ids) == self._block_size

    @property
    def num_empty_slots(self) -> int:
        return self._block_size - len(self._token_ids)

    @property
    def token_ids(self) -> List[int]:
        return self._token_ids

    @property
    def block_size(self) -> int:
        return self._block_size

    @property
    def prev_block(self) -> Optional["Block"]:
        return self._prev_block

    @property
    def content_hash(self) -> Optional[int]:
        # Naive blocks are never content-addressed.
        return None

    @property
    def computed(self) -> bool:
        raise NotImplementedError

    @computed.setter
    def computed(self, value) -> None:
        raise NotImplementedError

    @property
    def last_accessed(self) -> float:
        raise NotImplementedError

    @last_accessed.setter
    def last_accessed(self, last_accessed_ts: float):
        raise NotImplementedError
|
||||
606
vllm/core/block/prefix_caching_block.py
Normal file
606
vllm/core/block/prefix_caching_block.py
Normal file
@@ -0,0 +1,606 @@
|
||||
"""Token blocks."""
|
||||
from itertools import takewhile
|
||||
from os.path import commonprefix
|
||||
from typing import Dict, FrozenSet, Iterable, List, Optional
|
||||
|
||||
from vllm.core.block.common import (CopyOnWriteTracker,
|
||||
get_all_blocks_recursively)
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
|
||||
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
|
||||
from vllm.core.evictor_v2 import EvictionPolicy, Evictor, make_evictor
|
||||
|
||||
PrefixHash = int
|
||||
|
||||
# By default, we init our block access time as _DEFAULT_LAST_ACCESSED_TIME
|
||||
# so that if we find one block is still hold _DEFAULT_LAST_ACCESSED_TIME,
|
||||
# then we know this block hasn't been accessed yet.
|
||||
_DEFAULT_LAST_ACCESSED_TIME = -1
|
||||
|
||||
|
||||
class PrefixCachingBlockAllocator(BlockAllocator):
    """A block allocator that implements prefix caching.

    The PrefixCachingBlockAllocator maintains a cache of blocks based on their
    content hash. It reuses blocks with the same content hash to avoid redundant
    memory allocation. The allocator also supports copy-on-write operations.

    Args:
        num_blocks (int): The total number of blocks to manage.
        block_size (int): The size of each block in tokens.
        block_ids(Optional[Iterable[int]], optional): An optional iterable of
            block IDs. If not provided, block IDs will be assigned sequentially
            from 0 to num_blocks - 1.
    """

    def __init__(
        self,
        num_blocks: int,
        block_size: int,
        block_ids: Optional[Iterable[int]] = None,
        eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
    ):
        # A mapping of prefix hash to block index. All blocks which have a
        # prefix hash will be in this dict, even if they have refcount 0.
        self._cached_blocks: Dict[PrefixHash, BlockId] = {}

        # A mapping of blockId to Block, tracking blocks that are currently
        # referenced (not sitting in the evictor).
        self._blocks: Dict[BlockId, Block] = {}

        # An allocator for blocks that do not have prefix hashes.
        self._hashless_allocator = NaiveBlockAllocator(
            create_block=self._create_block,  # type: ignore
            num_blocks=num_blocks,
            block_size=block_size,
            block_ids=block_ids,
        )

        self._block_size = block_size

        # Evictor used to decide which computed blocks to reclaim when
        # memory pressure is high.
        self.evictor: Evictor = make_evictor(eviction_policy)

        # We share the refcounter between allocators. This allows us to promote
        # blocks originally allocated in the hashless allocator to immutable
        # blocks.
        self._refcounter = self._hashless_allocator.refcounter

        self._cow_tracker = CopyOnWriteTracker(
            refcounter=self._refcounter.as_readonly(),
            allocator=self,
        )

    # Implements Block.Factory.
    def _create_block(
        self,
        prev_block: Optional[Block],
        token_ids: List[int],
        block_size: int,
        allocator: BlockAllocator,
        block_id: Optional[int] = None,
        computed: bool = False,
    ) -> Block:
        """Factory passed to the hashless allocator so that every block it
        creates is a PrefixCachingBlock bound to this allocator.
        """
        # Bind block to self.
        allocator = self

        return PrefixCachingBlock(
            prev_block=prev_block,
            token_ids=token_ids,
            block_size=block_size,
            block_id=block_id,
            prefix_caching_allocator=allocator,
            computed=computed,
        )

    def allocate_immutable(self,
                           prev_block: Optional[Block],
                           token_ids: List[int],
                           device: Optional[Device] = None) -> Block:
        """Allocates an immutable block with the given token IDs, reusing cached
        blocks if possible.

        Args:
            prev_block (Optional[Block]): The previous block in the sequence.
            token_ids (List[int]): The token IDs to be stored in the block.

        Returns:
            Block: The allocated immutable block.
        """
        assert device is None
        assert_prefix_caching_block_or_none(prev_block)

        # Build an unallocated block first, purely to compute its content hash.
        block = self._create_block(
            prev_block=prev_block,
            token_ids=token_ids,
            block_size=self._block_size,
            allocator=self,
        )
        assert block.content_hash is not None

        # Cache hit: reuse the existing physical block.
        cached_block_id = self._cached_blocks.get(block.content_hash, None)
        if cached_block_id is not None:
            block.block_id = cached_block_id
            self._incr_refcount_cached_block(block, block.block_id)
            return block

        # Cache miss: allocate a fresh mutable block; appending the full
        # token list promotes it to immutable (registers the hash).
        block = self.allocate_mutable(prev_block)
        block.append_token_ids(token_ids)
        assert block.content_hash is not None

        return block

    def allocate_mutable(self,
                         prev_block: Optional[Block],
                         device: Optional[Device] = None) -> Block:
        """Allocates a mutable block. If there are no free blocks, this will
        evict unused cached blocks.

        Args:
            prev_block (Block): The previous block in the sequence.
                None is not allowed unlike it is super class.

        Returns:
            Block: The allocated mutable block.
        """
        assert device is None
        assert_prefix_caching_block_or_none(prev_block)

        try:
            block = self._hashless_allocator.allocate_mutable(
                prev_block=prev_block)

            assert block.block_id not in self._blocks
            assert block.block_id is not None
            self._blocks[block.block_id] = block
            return block
        except BlockAllocator.NoFreeBlocksError:
            # We must check the unused cached blocks before raising OOM.
            pass

        # If the evictor has blocks available for eviction, evict a block
        # and return it.
        if self.evictor.num_blocks > 0:
            block_id, content_hash_to_evict = self.evictor.evict()

            # Several physical blocks may carry the same content hash (a
            # later block travelling the mutable->immutable path lands in
            # the evictor too). In that case we must not pop _cached_blocks,
            # since the same content is still referenced elsewhere — so
            # check the refcount before deciding to remove the mapping.

            _block_id = self._cached_blocks[content_hash_to_evict]
            refcount = self._refcounter.get(_block_id)
            if refcount == 1:
                self._cached_blocks.pop(content_hash_to_evict)
                assert _block_id == block_id

            self._refcounter.incr(block_id)

            # The block coming from the evictor already contains a computed
            # result.
            block = self._create_block(
                prev_block=prev_block,
                token_ids=[],
                block_size=self._block_size,
                allocator=self,
                block_id=block_id,
                computed=True,
            )
            assert block.content_hash is None

            assert block.block_id not in self._blocks
            assert block.block_id is not None
            self._blocks[block.block_id] = block
            return block

        # No block available in hashless allocator, nor in unused cache blocks.
        raise BlockAllocator.NoFreeBlocksError()

    def _incr_refcount_cached_block(self, block: Block,
                                    block_id: BlockId) -> None:
        """Take a reference on a cached physical block on behalf of ``block``."""
        # Since the block is already computed, mark it.
        block.computed = True

        refcount = self._refcounter.incr(block_id)
        if refcount == 1:
            # First reference: the block must leave the evictor (if present)
            # and be tracked in _blocks.
            if block_id in self.evictor:
                self.evictor.remove(block_id)
            self._blocks[block_id] = block

    def free(self, block: Block) -> None:
        """Decrement the refcount of the block. If the decremented refcount is
        zero, store the block in the freelist.

        If the block has a content hash (meaning it is immutable), then we will
        keep the block around in case future allocations require it.
        """
        assert (block.block_id
                is not None), "freeing unallocated block is undefined"

        self._free_block_id_for_block(block.block_id, block)

        block.block_id = None

    def _free_block_id_for_block(self, block_id: BlockId,
                                 block: Block) -> None:
        """Release one reference to ``block_id``; hashless blocks return to the
        naive allocator, cached blocks move to the evictor when unreferenced.
        """
        assert isinstance(block, PrefixCachingBlock)

        if block.content_hash is None:
            refcount = self._refcounter.get(block_id)
            # In the fork case a block can hold more than one ref, so we can
            # only drop it from tracking when this is the last reference.
            if refcount <= 1:
                assert block.block_id is not None
                del self._blocks[block.block_id]
            return self._hashless_allocator.free(block)

        refcount = self._refcounter.decr(block_id)

        # If no longer used, add the block to the evictor.
        if refcount == 0:
            assert block.content_hash in self._cached_blocks
            assert block.block_id is not None
            del self._blocks[block.block_id]
            self.evictor.add(block.block_id, block.content_hash,
                             block.num_tokens_total, block.last_accessed)

    def fork(self, last_block: Block) -> List[Block]:
        """Creates a new sequence of blocks that shares the same underlying
        memory as the original sequence.

        Args:
            last_block (Block): The last block in the original sequence.

        Returns:
            List[Block]: The new sequence of blocks that shares the same memory
                as the original sequence.
        """
        source_blocks = get_all_blocks_recursively(last_block)

        forked_blocks = []
        prev_block = None
        for block in source_blocks:
            refcount = self._refcounter.incr(block.block_id)
            assert refcount != 1, "can't fork free'd block"

            forked_blocks.append(
                self._create_block(
                    prev_block=prev_block,
                    token_ids=block.token_ids,
                    block_id=block.block_id,
                    block_size=self._block_size,
                    allocator=self,
                ))
            prev_block = forked_blocks[-1]

        return forked_blocks

    def get_num_free_blocks(self, device: Optional[Device] = None) -> int:
        """Return the number of blocks that could be allocated right now."""
        assert device is None
        # The number of free blocks is the number of hashless free blocks
        # plus the number of blocks evictor could free from its list.
        return self._hashless_allocator.get_num_free_blocks(
        ) + self.evictor.num_blocks

    def get_num_total_blocks(self) -> int:
        """Return the total number of blocks managed by this allocator."""
        return self._hashless_allocator.get_num_total_blocks()

    @property
    def all_block_ids(self) -> FrozenSet[int]:
        """All physical block ids (owned by the underlying naive allocator)."""
        return self._hashless_allocator.all_block_ids

    def promote_to_immutable_block(self, block: Block) -> BlockId:
        """Once a mutable block is full, it can be promoted to an immutable
        block. This means that its content can be referenced by future blocks
        having the same prefix.

        Note that if we already have a cached block with the same content, we
        will replace the newly-promoted block's mapping with the existing cached
        block.

        Args:
            block: The mutable block to be promoted.

        Returns:
            BlockId: Either the original block index, or the block index of
                the previously cached block matching the same content.
        """
        assert block.content_hash is not None
        assert block.block_id is not None
        assert self._refcounter.get(block.block_id) > 0

        # If the content hash does not have a corresponding cached block,
        # set this block as the cached block.
        if block.content_hash not in self._cached_blocks:
            self._cached_blocks[block.content_hash] = block.block_id
        else:
            # Duplicate content: drop our physical block and take a reference
            # on the canonical cached one.
            self._free_block_id_for_block(block.block_id, block)
            self._incr_refcount_cached_block(
                block, self._cached_blocks[block.content_hash])

        return self._cached_blocks[block.content_hash]

    def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
        """Performs a copy-on-write operation on the given block if it is not
        appendable.

        Args:
            block (Block): The block to check for copy-on-write.

        Returns:
            Optional[BlockId]: The block index of the new block if a copy-on
                -write operation was performed, or the original block index if
                no copy-on-write was necessary.
        """
        return self._cow_tracker.cow_block_if_not_appendable(block)

    def clear_copy_on_writes(self) -> Dict[BlockId, List[BlockId]]:
        """Returns the copy-on-write source->destination mapping and clears it.

        Returns:
            Dict[BlockId, List[BlockId]]: A dictionary mapping source
                block indices to lists of destination block indices.
        """
        return self._cow_tracker.clear_cows()

    def mark_blocks_as_accessed(self, block_ids: List[int],
                                now: float) -> None:
        """Mark blocks as accessed, used in prefix caching.

        If the block has been added to the evictor, we need to update the
        corresponding metadata in the evictor instead.
        """

        for block_id in block_ids:
            if block_id in self._blocks:
                self._blocks[block_id].last_accessed = now
            elif block_id in self.evictor:
                self.evictor.update(block_id, now)
            else:
                raise ValueError(
                    "Mark block as accessed which is not belonged to GPU")

    def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
        """Mark blocks as computed, used in prefix caching."""

        for block_id in block_ids:
            if block_id in self._blocks:
                # Only full blocks are valid prefix-caching candidates.
                if self._blocks[block_id].is_full:
                    self._blocks[block_id].computed = True
            elif block_id not in self.evictor:
                raise ValueError(f"Mark {block_id=} as computed which "
                                 "is not belonged to GPU")

    def block_is_computed(self, block_id: int) -> bool:
        """Return whether ``block_id``'s contents are already computed (blocks
        sitting in the evictor are computed by construction).
        """
        if block_id in self._blocks:
            return self._blocks[block_id].computed
        else:
            return block_id in self.evictor

    def get_common_computed_block_ids(
            self, seq_block_ids: List[List[int]]) -> List[int]:
        """Return the block ids that are common for a given sequence group.

        Only blocks that are immutable and already marked computed are taken
        into consideration.
        """

        # NOTE We exclude the last block to avoid the case where the entire
        # prompt is cached. This would cause erroneous behavior in model
        # runner.

        ids_list = [
            list(
                takewhile(lambda block_id: self.block_is_computed(block_id),
                          seq[:-1])) for seq in seq_block_ids
        ]
        # commonprefix operates element-wise on sequences, so it returns a
        # list of ints here despite its str-oriented type annotation.
        return commonprefix([
            ids for ids in ids_list  # type: ignore
            if ids != []
        ])
|
||||
|
||||
|
||||
class PrefixCachingBlock(Block):
|
||||
"""A block implementation that supports prefix caching.
|
||||
|
||||
The PrefixCachingBlock class represents a block of token IDs with prefix
|
||||
caching capabilities. It wraps a NaiveBlock internally and provides
|
||||
additional functionality for content hashing and promoting immutable blocks
|
||||
with the prefix caching allocator.
|
||||
|
||||
Args:
|
||||
prev_block (Optional[PrefixCachingBlock]): The previous block in the
|
||||
sequence.
|
||||
token_ids (List[int]): The initial token IDs to be stored in the block.
|
||||
block_size (int): The maximum number of token IDs that can be stored in
|
||||
the block.
|
||||
prefix_caching_allocator (BlockAllocator): The prefix
|
||||
caching block allocator associated with this block.
|
||||
block_id (Optional[int], optional): The physical block index
|
||||
of this block. Defaults to None.
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        prev_block: Optional[Block],
        token_ids: List[int],
        block_size: int,
        prefix_caching_allocator: BlockAllocator,
        block_id: Optional[int] = None,
        computed: bool = False,
    ):
        assert isinstance(prefix_caching_allocator,
                          PrefixCachingBlockAllocator), (
                              "Currently this class is only tested with "
                              "PrefixCachingBlockAllocator.")
        assert_prefix_caching_block_or_none(prev_block)

        self._prev_block = prev_block
        # Lazily-computed caches; see content_hash / num_tokens_total.
        self._cached_content_hash: Optional[int] = None
        self._cached_num_tokens_total: Optional[int] = None
        self._prefix_caching_allocator = prefix_caching_allocator
        self._last_accessed: float = _DEFAULT_LAST_ACCESSED_TIME
        self._computed = computed

        # Token storage and copy-on-write are delegated to an embedded
        # NaiveBlock; _cow_target=self makes CoW operate on this wrapper.
        self._block = NaiveBlock(
            prev_block=prev_block,
            token_ids=token_ids,
            block_size=block_size,
            block_id=block_id,
            allocator=prefix_caching_allocator,
            _cow_target=self,
        )
|
||||
|
||||
    @property
    def computed(self) -> bool:
        """Whether this block's KV contents have been computed."""
        return self._computed

    @computed.setter
    def computed(self, value) -> None:
        self._computed = value
|
||||
|
||||
    @property
    def last_accessed(self) -> float:
        """Timestamp of the last access (eviction-policy metadata)."""
        return self._last_accessed

    @last_accessed.setter
    def last_accessed(self, last_accessed_ts: float):
        self._last_accessed = last_accessed_ts
|
||||
|
||||
    def append_token_ids(self, token_ids: List[int]) -> None:
        """Appends the given token IDs to the block and registers the block as
        immutable if the block becomes full.

        Internally, the naive block handles CoW.

        Args:
            token_ids (List[int]): The token IDs to be appended to the block.
        """
        assert token_ids

        # naive block handles CoW.
        self._block.append_token_ids(token_ids)

        # If the content hash is present, then the block can be made immutable.
        # Register ourselves with the allocator, potentially replacing the
        # physical block index.
        if self.content_hash is not None:
            self.block_id = (self._prefix_caching_allocator.
                             promote_to_immutable_block(self))
|
||||
|
||||
    @property
    def block_id(self) -> Optional[int]:
        """Physical block id, delegated to the embedded naive block."""
        return self._block.block_id

    @block_id.setter
    def block_id(self, value) -> None:
        self._block.block_id = value
|
||||
|
||||
    @property
    def is_full(self) -> bool:
        """Whether the block holds block_size tokens (delegated)."""
        return self._block.is_full

    @property
    def num_empty_slots(self) -> int:
        """Remaining token capacity of the block (delegated)."""
        return self._block.num_empty_slots
|
||||
|
||||
    @property
    def num_tokens_total(self) -> int:
        """Return the total number of tokens from the start of the sequence up
        to and including this block.

        We iterate the block chain back to the first block and cache the
        result locally to avoid repeated computation.
        """
        if self._cached_num_tokens_total is not None:
            return self._cached_num_tokens_total

        _block: Optional[Block] = self
        self._cached_num_tokens_total = 0

        # TODO: the current implementation takes O(N^2) over a chain of N
        # uncached blocks; we expect O(1) in the future.
        while _block is not None:
            self._cached_num_tokens_total += len(_block.token_ids)
            _block = _block.prev_block

        return self._cached_num_tokens_total
|
||||
|
||||
    @property
    def block_size(self) -> int:
        """Maximum number of tokens the block can hold (delegated)."""
        return self._block.block_size

    @property
    def token_ids(self) -> List[int]:
        """Token IDs currently stored in the block (delegated)."""
        return self._block.token_ids

    @property
    def prev_block(self) -> Optional[Block]:
        """The preceding block of the sequence, or None if first."""
        return self._prev_block
|
||||
|
||||
    @property
    def content_hash(self) -> Optional[int]:
        """Return the content-based hash of the current block, or None if it is
        not yet defined.

        For the content-based hash to be defined, the current block must be
        full, and every preceding block must already have a defined hash.
        """

        # If the hash is already computed, return it.
        if self._cached_content_hash is not None:
            return self._cached_content_hash

        # We cannot compute a hash for the current block because it is not full.
        if not self.is_full:
            return None

        is_first_block = self._prev_block is None
        prev_block_hash = (
            None if is_first_block else
            self._prev_block.content_hash  # type: ignore
        )

        # Previous block exists but does not yet have a hash.
        # Return no hash in this case.
        if prev_block_hash is None and not is_first_block:
            return None

        # Hash chains through the previous block's hash, so equal hashes
        # imply equal full prefixes, not just equal block contents.
        self._cached_content_hash = PrefixCachingBlock.hash_block_tokens(
            is_first_block,
            prev_block_hash,
            cur_block_token_ids=self.token_ids)
        return self._cached_content_hash
|
||||
|
||||
@staticmethod
|
||||
def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
|
||||
cur_block_token_ids: List[int]) -> int:
|
||||
"""Computes a hash value corresponding to the contents of a block and
|
||||
the contents of the preceding block(s). The hash value is used for
|
||||
prefix caching.
|
||||
|
||||
NOTE: Content-based hashing does not yet support LoRA.
|
||||
|
||||
Parameters:
|
||||
- is_first_block (bool): A flag indicating if the block is the first in
|
||||
the sequence.
|
||||
- prev_block_hash (Optional[int]): The hash of the previous block. None
|
||||
if this is the first block.
|
||||
- cur_block_token_ids (List[int]): A list of token ids in the current
|
||||
block. The current block is assumed to be full.
|
||||
|
||||
Returns:
|
||||
- int: The computed hash value for the block.
|
||||
"""
|
||||
assert (prev_block_hash is None) == is_first_block
|
||||
return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
|
||||
|
||||
|
||||
def assert_prefix_caching_block_or_none(block: Optional[Block]):
    """Assert that *block* is a PrefixCachingBlock, tolerating None."""
    if block is not None:
        assert isinstance(block, PrefixCachingBlock)
|
||||
625
vllm/core/block_manager_v1.py
Normal file
625
vllm/core/block_manager_v1.py
Normal file
@@ -0,0 +1,625 @@
|
||||
"""A block manager that manages token blocks."""
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import count, takewhile
|
||||
from os.path import commonprefix
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set
|
||||
|
||||
from vllm.block import BlockTable, PhysicalTokenBlock
|
||||
from vllm.core.evictor_v1 import EvictionPolicy, Evictor, make_evictor
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.utils import Device
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BlockAllocatorBase(ABC):
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free list.
    """

    @abstractmethod
    def __init__(self,
                 device: Device,
                 block_size: int,
                 num_blocks: int,
                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU):
        pass

    @abstractmethod
    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        # Return a block for `block_hash`, creating or reusing physical
        # storage as the implementation sees fit.
        pass

    @abstractmethod
    def free(self, block: PhysicalTokenBlock) -> None:
        # Drop one reference to `block`.
        pass

    @abstractmethod
    def get_num_free_blocks(self) -> int:
        # Number of blocks currently available for allocation.
        pass

    @abstractmethod
    def get_num_total_blocks(self) -> int:
        # Total number of blocks managed by this allocator.
        pass

    @abstractmethod
    def contains_block(self, block_hash: int) -> bool:
        # Whether a block with this content hash is tracked.
        pass

    @abstractmethod
    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        # Re-key `block` under a new content hash.
        pass
|
||||
|
||||
|
||||
class CachedBlockAllocator(BlockAllocatorBase):
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free list.

    Blocks are keyed by content hash so identical prefixes share physical
    storage; zero-ref blocks are parked in an evictor rather than destroyed,
    so they can be revived on a later cache hit.
    """

    def __init__(self,
                 device: Device,
                 block_size: int,
                 num_blocks: int,
                 eviction_policy: EvictionPolicy = EvictionPolicy.LRU) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Physical blocks created so far; grown lazily up to num_blocks
        # before eviction is needed.
        self.current_num_blocks = 0
        # hash -> in-use block. Zero-ref blocks live in the evictor instead.
        self.cached_blocks: Dict[int, PhysicalTokenBlock] = {}

        self.evictor: Evictor = make_evictor(eviction_policy)

        # Fallback hash source for blocks allocated without a content hash
        # (gives each such block a unique, never-shared key).
        self.default_hash_ctr = count()

    def allocate_block(self, block_hash: int,
                       num_hashed_tokens: int) -> PhysicalTokenBlock:
        """Create a new physical block, evicting one if the pool is full."""
        if self.current_num_blocks == self.num_blocks:
            # Pool exhausted: recycle a zero-ref block under the new hash.
            block = self.evictor.evict()
            block.block_hash = block_hash
            block.num_hashed_tokens = num_hashed_tokens
            return block
        block = PhysicalTokenBlock(device=self.device,
                                   block_number=self.current_num_blocks,
                                   block_size=self.block_size,
                                   block_hash=block_hash,
                                   num_hashed_tokens=num_hashed_tokens)
        self.current_num_blocks += 1
        return block

    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        """Return a (possibly shared) block for `block_hash`, ref-counted."""
        if block_hash is None:
            block_hash = next(self.default_hash_ctr)
        # Cache hit on a still-resident evicted block: revive it.
        if block_hash in self.evictor:
            assert block_hash not in self.cached_blocks
            block = self.evictor.remove(block_hash)
            assert block.ref_count == 0
            self.cached_blocks[block_hash] = block
            block.ref_count += 1
            assert block.block_hash == block_hash
            return block
        if block_hash not in self.cached_blocks:
            self.cached_blocks[block_hash] = self.allocate_block(
                block_hash, num_hashed_tokens)
        block = self.cached_blocks[block_hash]
        assert block.block_hash == block_hash
        block.ref_count += 1
        return block

    def free(self, block: PhysicalTokenBlock) -> None:
        """Drop one reference; a zero-ref block moves to the evictor."""
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            assert block.block_hash not in self.evictor
            self.evictor.add(block)

            # Remove the block from the cached_blocks
            del self.cached_blocks[block.block_hash]

    def get_num_free_blocks(self) -> int:
        # Unallocated slots plus evictable (zero-ref) resident blocks.
        return (self.num_blocks - self.current_num_blocks +
                self.evictor.num_blocks)

    def get_num_total_blocks(self) -> int:
        return self.num_blocks

    def contains_block(self, block_hash: int) -> bool:
        # A block counts as present whether it is in use or merely evictable.
        return block_hash in self.cached_blocks or block_hash in self.evictor

    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        # Update the hash of block and the cached_blocks dictionary.
        assert not self.contains_block(block_hash)
        old_hash = block.block_hash
        block.block_hash = block_hash
        del self.cached_blocks[old_hash]
        self.cached_blocks[block_hash] = block
|
||||
|
||||
|
||||
class UncachedBlockAllocator(BlockAllocatorBase):
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free list.
    """

    def __init__(
        self,
        device: Device,
        block_size: int,
        num_blocks: int,
    ) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Initialize the free blocks: one physical block per slot, all free.
        self.free_blocks: BlockTable = [
            PhysicalTokenBlock(device=device,
                               block_number=block_number,
                               block_size=block_size,
                               block_hash=-1,
                               num_hashed_tokens=0)
            for block_number in range(num_blocks)
        ]

    def allocate(self,
                 block_hash: Optional[int] = None,
                 num_hashed_tokens: int = 0) -> PhysicalTokenBlock:
        """Pop a free block and hand it out with a single reference.

        The hash arguments are accepted only for interface compatibility;
        this allocator never shares blocks by content.
        """
        if not self.free_blocks:
            raise ValueError("Out of memory! No free blocks are available.")
        allocated = self.free_blocks.pop()
        allocated.ref_count = 1
        return allocated

    def free(self, block: PhysicalTokenBlock) -> None:
        """Drop one reference; a zero-ref block rejoins the free pool."""
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_blocks.append(block)

    def get_num_free_blocks(self) -> int:
        """Number of blocks currently sitting in the free pool."""
        return len(self.free_blocks)

    def get_num_total_blocks(self) -> int:
        """Total number of blocks owned by this allocator."""
        return self.num_blocks

    def contains_block(self, block_hash: int) -> bool:
        raise NotImplementedError(
            "Invalid codepath for uncached block allocator.")

    def update_hash(self, block_hash: int, block: PhysicalTokenBlock):
        raise NotImplementedError(
            "Invalid codepath for uncached block allocator.")
|
||||
|
||||
|
||||
class BlockSpaceManagerV1(BlockSpaceManager):
    """Manages the mapping between logical and physical token blocks."""

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
        sliding_window: Optional[int] = None,
        enable_caching: bool = False,
    ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        if enable_caching and sliding_window is not None:
            raise NotImplementedError(
                "Sliding window is not allowed with prefix caching enabled!")

        self.block_sliding_window = None
        if sliding_window is not None:
            # Round up to nearest block size to regularize sliding window
            # allocation sizes.
            self.block_sliding_window = math.ceil(sliding_window / block_size)

        self.watermark = watermark
        assert watermark >= 0.0

        self.enable_caching = enable_caching

        self.watermark_blocks = int(watermark * num_gpu_blocks)

        if self.enable_caching:
            logger.info("Automatic prefix caching is enabled.")
            self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
                Device.GPU, block_size, num_gpu_blocks)
            self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
                Device.CPU, block_size, num_cpu_blocks)
        else:
            self.gpu_allocator = UncachedBlockAllocator(
                Device.GPU, block_size, num_gpu_blocks)
            self.cpu_allocator = UncachedBlockAllocator(
                Device.CPU, block_size, num_cpu_blocks)
        # Mapping: seq_id -> BlockTable.
        self.block_tables: Dict[int, BlockTable] = {}

    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        """Check whether the prompt of `seq_group` can be allocated now."""
        # FIXME(woosuk): Here we assume that all sequences in the group share
        # the same prompt. This may not be true for preempted sequences.
        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
        num_required_blocks = len(seq.logical_token_blocks)

        if self.block_sliding_window is not None:
            num_required_blocks = min(num_required_blocks,
                                      self.block_sliding_window)
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()

        # Use watermark to avoid frequent cache eviction.
        if (self.num_total_gpu_blocks - num_required_blocks <
                self.watermark_blocks):
            return AllocStatus.NEVER
        if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK
        else:
            return AllocStatus.LATER

    def allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate physical blocks for the prompt and share them across the
        group's sequences via reference counts."""
        # NOTE: Here we assume that all sequences in the group have the same
        # prompt.
        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]

        # Allocate new physical token blocks that will store the prompt tokens.
        num_prompt_blocks = len(seq.logical_token_blocks)

        block_table: BlockTable = []
        for logical_idx in range(num_prompt_blocks):
            if (self.block_sliding_window is not None
                    and logical_idx >= self.block_sliding_window):
                # Beyond the window: reuse an earlier physical block.
                block = block_table[logical_idx % self.block_sliding_window]
                # Set the reference counts of the token blocks.
                block.ref_count = seq_group.num_seqs()
            elif self.enable_caching:
                block = self.gpu_allocator.allocate(
                    seq.hash_of_block(logical_idx),
                    seq.num_hashed_tokens_of_block(logical_idx))
            else:
                block = self.gpu_allocator.allocate()
                # Set the reference counts of the token blocks.
                block.ref_count = seq_group.num_seqs()
            block_table.append(block)

        # Assign the block table for each sequence.
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            self.block_tables[seq.seq_id] = block_table.copy()

    def can_append_slots(self,
                         seq_group: SequenceGroup,
                         num_lookahead_slots: int = 0) -> bool:
        """Heuristically check whether decoding can continue for the group."""
        assert (num_lookahead_slots == 0
                ), "lookahead allocation not supported in BlockSpaceManagerV1"

        # Simple heuristic: If there is at least one free block
        # for each sequence, we can append.
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
        return num_seqs <= num_free_gpu_blocks

    def _promote_last_block(
        self,
        seq: Sequence,
        last_block: PhysicalTokenBlock,
    ) -> PhysicalTokenBlock:
        """Give the (full) last block a content hash so other sequences can
        share it; dedupe against the cache if the hash already exists."""
        assert self.enable_caching

        # Compute a new hash for the block so that it can be shared by other
        # Sequences
        new_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)

        # if new_hash is already in the cached table, then free last_block
        # and return the cached version
        if self.gpu_allocator.contains_block(new_hash):
            self.gpu_allocator.free(last_block)
            return self.gpu_allocator.allocate(new_hash)
        else:
            self.gpu_allocator.update_hash(new_hash, last_block)
            return last_block

    def _is_last_block_full(
        self,
        seq: Sequence,
    ) -> bool:
        # True iff the sequence length is a non-zero multiple of block_size.
        token_ids_len = seq.data.get_len()
        return token_ids_len > 0 and token_ids_len % seq.block_size == 0

    def _maybe_promote_last_block(
        self,
        seq: Sequence,
        last_block: PhysicalTokenBlock,
    ) -> PhysicalTokenBlock:
        # Promotion only makes sense once the block is completely filled.
        if self._is_last_block_full(seq):
            return self._promote_last_block(seq, last_block)
        else:
            return last_block

    def _allocate_last_physical_block(
        self,
        seq: Sequence,
    ) -> PhysicalTokenBlock:
        """Allocate the physical block that will back the sequence's newest
        logical block."""
        # Called before a new block is appended.
        # This is in charge of allocating a new physical block (to be appended).

        # None if the last block is not full. Otherwise, we set it to the
        # content hash.
        if not self.enable_caching:
            return self.gpu_allocator.allocate()
        block_hash: Optional[int] = None
        if (self._is_last_block_full(seq)):
            block_hash = seq.hash_of_block(len(seq.logical_token_blocks) - 1)
        num_hashed_tokens = seq.num_hashed_tokens_of_block(
            len(seq.logical_token_blocks) - 1)

        # num_hashed_tokens is used to compute future hashes
        # (e.g. in the hashing function, it is used to ask the sequence for
        # prefix tokens)
        new_block = self.gpu_allocator.allocate(block_hash, num_hashed_tokens)

        # If the block hash is None, then the block is not full.
        # If the block is not full, then we expect it to have a refcount of 1.
        if block_hash is None:
            assert new_block.ref_count == 1
        return new_block

    def append_slots(
        self,
        seq: Sequence,
        num_lookahead_slots: int = 0,
    ) -> Dict[int, List[int]]:
        """Allocate a physical slot for a new token.

        Returns a mapping of source block number -> [destination block
        number] for any copy-on-write triggered by a shared last block.
        """
        logical_blocks = seq.logical_token_blocks
        block_table = self.block_tables[seq.seq_id]
        # If we need to allocate a new physical block
        if len(block_table) < len(logical_blocks):
            # Currently this code only supports adding one physical block
            assert len(block_table) == len(logical_blocks) - 1

            if (self.block_sliding_window
                    and len(block_table) >= self.block_sliding_window):
                # reuse a block
                block_table.append(block_table[len(block_table) %
                                               self.block_sliding_window])
            else:
                # The sequence has a new logical block.
                # Allocate a new physical block.
                new_block = self._allocate_last_physical_block(seq)
                block_table.append(new_block)
                return {}

        # We want to append the token to the last physical block.
        last_block = block_table[-1]
        assert last_block.device == Device.GPU
        if last_block.ref_count == 1:
            # Not shared with other sequences. Appendable.
            if self.enable_caching:
                # If the last block is now complete, we may reuse an old block
                # to save memory.
                maybe_new_block = self._maybe_promote_last_block(
                    seq, last_block)
                block_table[-1] = maybe_new_block
            return {}
        else:
            # The last block is shared with other sequences.
            # Copy on Write: Allocate a new block and copy the tokens.
            new_block = self._allocate_last_physical_block(seq)

            block_table[-1] = new_block
            self.gpu_allocator.free(last_block)
            return {last_block.block_number: [new_block.block_number]}

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Share the parent's block table with the child (copy-on-write)."""
        # NOTE: fork does not allocate a new physical block.
        # Thus, it is always safe from OOM.
        src_block_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = src_block_table.copy()
        # When using a sliding window, blocks will be eventually reused.
        # In this case the block tables will contain repeated blocks.
        # When forking, we must make sure that each block's `ref_count`
        # is only incremented by one, so we deduplicate them by wrapping
        # them in a set.
        for block in set(src_block_table):
            block.ref_count += 1

    def _get_physical_blocks(
            self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
        """Collect the distinct physical blocks used by unfinished seqs."""
        # NOTE: Here, we assume that the physical blocks are only shared by
        # the sequences in the same group.
        blocks: Set[PhysicalTokenBlock] = set()
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                continue
            blocks.update(self.block_tables[seq.seq_id])
        return list(blocks)

    def can_swap_in(self,
                    seq_group: SequenceGroup,
                    num_lookahead_slots: int = 0) -> AllocStatus:
        """Check whether the group's CPU-resident blocks fit on the GPU."""
        assert (num_lookahead_slots == 0
                ), "BlockSpaceManagerV1 does not support lookahead allocation"
        blocks = self._get_physical_blocks(seq_group)
        num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
        num_free_blocks = self.gpu_allocator.get_num_free_blocks()
        # NOTE: Conservatively, we assume that every sequence will allocate
        # at least one free block right after the swap-in.
        # NOTE: This should match the logic in can_append_slot().
        num_required_blocks = len(blocks) + num_swapped_seqs
        if self.gpu_allocator.get_num_total_blocks() < num_required_blocks:
            return AllocStatus.NEVER
        elif num_free_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK
        else:
            return AllocStatus.LATER

    def swap_in(self,
                seq_group: SequenceGroup,
                num_lookahead_slots: int = 0) -> Dict[int, int]:
        """Move the group's blocks CPU -> GPU; returns the block-number map."""
        assert (num_lookahead_slots == 0
                ), "BlockSpaceManagerV1 does not support lookahead allocation"

        # CPU block -> GPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]

            for cpu_block in block_table:
                if cpu_block in mapping:
                    # Already swapped for a sibling sequence: share it.
                    gpu_block = mapping[cpu_block]
                    gpu_block.ref_count += 1
                else:
                    gpu_block = self.gpu_allocator.allocate(
                        cpu_block.block_hash, cpu_block.num_hashed_tokens)
                    mapping[cpu_block] = gpu_block
                new_block_table.append(gpu_block)
                # Free the CPU block swapped in to GPU.
                self.cpu_allocator.free(cpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            cpu_block.block_number: gpu_block.block_number
            for cpu_block, gpu_block in mapping.items()
        }
        return block_number_mapping

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        """True if every block of the group fits in free CPU blocks."""
        blocks = self._get_physical_blocks(seq_group)
        return len(blocks) <= self.cpu_allocator.get_num_free_blocks()

    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Move the group's blocks GPU -> CPU; returns the block-number map."""
        # GPU block -> CPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]

            for gpu_block in block_table:
                if gpu_block in mapping:
                    # Already swapped for a sibling sequence: share it.
                    cpu_block = mapping[gpu_block]
                    cpu_block.ref_count += 1
                else:
                    cpu_block = self.cpu_allocator.allocate(
                        gpu_block.block_hash, gpu_block.num_hashed_tokens)
                    mapping[gpu_block] = cpu_block
                new_block_table.append(cpu_block)
                # Free the GPU block swapped out to CPU.
                self.gpu_allocator.free(gpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            gpu_block.block_number: cpu_block.block_number
            for gpu_block, cpu_block in mapping.items()
        }
        return block_number_mapping

    def _free_block_table(self, block_table: BlockTable) -> None:
        # when using a sliding window, each seq will only use up
        # to `self.block_sliding_window` blocks. When freeing
        # the block table, we must make sure to not free blocks more
        # than once. If no sliding window is used, there is no block
        # reuse in the block table, so we must free all blocks.
        blocks_to_free = (block_table[-self.block_sliding_window:]
                          if self.block_sliding_window is not None else
                          block_table)
        for block in set(blocks_to_free):
            if block.device == Device.GPU:
                self.gpu_allocator.free(block)
            else:
                self.cpu_allocator.free(block)

    def free(self, seq: Sequence) -> None:
        """Release all physical blocks held by `seq` (idempotent)."""
        if seq.seq_id not in self.block_tables:
            # Already freed or haven't been scheduled yet.
            return
        block_table = self.block_tables[seq.seq_id]
        self._free_block_table(block_table)
        del self.block_tables[seq.seq_id]

    def reset(self) -> None:
        """Free every tracked block table."""
        for block_table in self.block_tables.values():
            self._free_block_table(block_table)
        self.block_tables.clear()

    def get_block_table(self, seq: Sequence) -> List[int]:
        """Return the physical block numbers backing `seq`, in order."""
        block_table = self.block_tables[seq.seq_id]
        return [block.block_number for block in block_table]

    def get_num_free_gpu_blocks(self) -> int:
        return self.gpu_allocator.get_num_free_blocks()

    def get_num_free_cpu_blocks(self) -> int:
        return self.cpu_allocator.get_num_free_blocks()

    def access_all_blocks_in_seq(
        self,
        seq: Sequence,
        access_time: float,
    ) -> None:
        """Record `access_time` on all of the sequence's blocks (for LRU)."""
        if self.enable_caching:
            # Update the last accessed time of all the blocks accessed
            # in this step.
            block_table = self.block_tables[seq.seq_id]
            for block in block_table:
                block.last_accessed = access_time

    def compute_full_blocks_in_seq(self, seq: Sequence):
        """Mark the sequence's full blocks as computed.

        NOTE: the last full block is deliberately left unmarked (mirrors the
        exclusion in get_all_computed_blocks).
        """
        if seq.seq_id not in self.block_tables:
            return
        max_full_block = seq.get_len() // self.block_size - 1
        block_table = self.block_tables[seq.seq_id]
        if max_full_block == -1:
            return
        # Walk backwards; stop at the first block already marked, since all
        # earlier blocks must have been marked previously.
        for i in reversed(range(max_full_block)):
            if block_table[i].computed:
                break
            block_table[i].computed = True

    def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
        """Return the leading run of computed block numbers for `seq`."""
        if seq.seq_id not in self.block_tables:
            return []
        block_table = self.block_tables[seq.seq_id]
        # NOTE We exclude the last block to avoid the case where the entire
        # prompt is cached. This would cause erroneous behavior in model
        # runner.
        return [
            b.block_number
            for b in takewhile(lambda b: b.computed, block_table[:-1])
        ]

    def get_common_computed_block_ids(
            self, seqs: List[Sequence]) -> GenericSequence[int]:
        """Return the block ids that are common for a given sequence group.

        Used in prefill (can skip prefill of some blocks).
        """
        # Can return non-empty result only with prefix caching enabled.
        if not self.enable_caching:
            return []

        ids_list = [self.get_all_computed_blocks(seq) for seq in seqs]
        return commonprefix([ids for ids in ids_list if ids != []])

    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
        """Mark all full blocks of every sequence in the group as computed."""
        if self.enable_caching:
            for seq in seq_group.seqs_dict.values():
                self.compute_full_blocks_in_seq(seq)
|
||||
258
vllm/core/block_manager_v2.py
Normal file
258
vllm/core/block_manager_v2.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""A block manager that manages token blocks."""
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
|
||||
from vllm.core.block.block_table import BlockTable
|
||||
from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.utils import Device
|
||||
|
||||
SeqId = int
|
||||
|
||||
|
||||
class BlockSpaceManagerV2(BlockSpaceManager):
|
||||
"""BlockSpaceManager which manages the allocation of KV cache.
|
||||
|
||||
It owns responsibility for allocation, swapping, allocating memory for
|
||||
autoregressively-generated tokens, and other advanced features such as
|
||||
prefix caching, forking/copy-on-write, and sliding-window memory allocation.
|
||||
|
||||
The current implementation is partial; in particular prefix caching and
|
||||
sliding-window are not feature complete. This class implements the design
|
||||
described in https://github.com/vllm-project/vllm/pull/3492.
|
||||
|
||||
Lookahead slots
|
||||
The block manager has the notion of a "lookahead slot". These are slots
|
||||
in the KV cache that are allocated for a sequence. Unlike the other
|
||||
allocated slots, the content of these slots is undefined -- the worker
|
||||
may use the memory allocations in any way.
|
||||
|
||||
In practice, a worker could use these lookahead slots to run multiple
|
||||
forward passes for a single scheduler invocation. Each successive
|
||||
forward pass would write KV activations to the corresponding lookahead
|
||||
slot. This allows low inter-token latency use-cases, where the overhead
|
||||
of continuous batching scheduling is amortized over >1 generated tokens.
|
||||
|
||||
Speculative decoding uses lookahead slots to store KV activations of
|
||||
proposal tokens.
|
||||
|
||||
See https://github.com/vllm-project/vllm/pull/3250 for more information
|
||||
on lookahead scheduling.
|
||||
|
||||
Args:
|
||||
block_size (int): The size of each memory block.
|
||||
num_gpu_blocks (int): The number of memory blocks allocated on GPU.
|
||||
num_cpu_blocks (int): The number of memory blocks allocated on CPU.
|
||||
watermark (float, optional): The threshold used for memory swapping.
|
||||
Defaults to 0.01.
|
||||
sliding_window (Optional[int], optional): The size of the sliding
|
||||
window. Defaults to None.
|
||||
enable_caching (bool, optional): Flag indicating whether caching is
|
||||
enabled. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    block_size: int,
    num_gpu_blocks: int,
    num_cpu_blocks: int,
    watermark: float = 0.01,
    sliding_window: Optional[int] = None,
    enable_caching: bool = False,
) -> None:
    """Set up the device-aware allocator and the per-sequence table map."""
    self.block_size = block_size
    self.num_total_gpu_blocks = num_gpu_blocks
    self.num_total_cpu_blocks = num_cpu_blocks

    assert sliding_window is None, "Sliding window not yet supported"

    # Kept for interface parity with V1; always None until sliding-window
    # support lands.
    self.block_sliding_window = None

    self.watermark = watermark
    assert watermark >= 0.0

    self.enable_caching = enable_caching

    # Free-block headroom kept to avoid frequent cache eviction.
    self.watermark_blocks = int(watermark * num_gpu_blocks)

    self.block_allocator = CpuGpuBlockAllocator.create(
        allocator_type="prefix_caching" if enable_caching else "naive",
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=num_cpu_blocks,
        block_size=block_size,
    )

    # Mapping: seq_id -> BlockTable.
    self.block_tables: Dict[SeqId, BlockTable] = {}
|
||||
|
||||
def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
    """Determine whether the prompt of `seq_group` fits in the GPU KV cache.

    Returns:
        AllocStatus.NEVER if the request can never fit within the watermark
        even with an empty cache, AllocStatus.OK if it fits now, and
        AllocStatus.LATER otherwise.
    """
    # FIXME(woosuk): Here we assume that all sequences in the group share
    # the same prompt. This may not be true for preempted sequences.
    seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]

    num_required_blocks = BlockTable.get_num_required_blocks(
        seq.get_token_ids(),
        block_size=self.block_size,
    )

    # Sliding window is not yet supported in V2 (enforced in __init__).
    # The branch that clamped num_required_blocks to the window size was
    # unreachable behind this assert and has been removed.
    assert self.block_sliding_window is None

    num_free_gpu_blocks = self.block_allocator.get_num_free_blocks(
        device=Device.GPU)

    # Use watermark to avoid frequent cache eviction.
    if (self.num_total_gpu_blocks - num_required_blocks <
            self.watermark_blocks):
        return AllocStatus.NEVER
    if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
        return AllocStatus.OK
    else:
        return AllocStatus.LATER
|
||||
|
||||
def allocate(self, seq_group: SequenceGroup) -> None:
    """Allocate a block table for the group's shared prompt.

    One table is allocated for the first waiting sequence and forked
    (copy-on-write) for the remaining ones.
    """
    waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING)
    assert not (set(seq.seq_id for seq in waiting_seqs)
                & self.block_tables.keys()), "block table already exists"

    # NOTE: Here we assume that all sequences in the group have the same
    # prompt.
    seq = waiting_seqs[0]

    block_table = BlockTable(
        block_size=self.block_size,
        block_allocator=self.block_allocator,
    )
    # Sliding window is not supported in V2 yet (enforced in __init__).
    assert self.block_sliding_window is None
    block_table.allocate(seq.get_token_ids())
    self.block_tables[seq.seq_id] = block_table

    # Assign the block table for each sequence.
    for seq in waiting_seqs[1:]:
        self.block_tables[seq.seq_id] = block_table.fork()
|
||||
|
||||
def can_append_slots(self, seq_group: SequenceGroup,
                     num_lookahead_slots: int) -> bool:
    """Determine if there is enough space in the GPU KV cache to continue
    generation of the specified sequence group.

    Worst-case heuristic: every touched block is assumed to need a fresh
    allocation (via CoW or a new block), so appending is allowed only when
    the touched-block count fits within the free-block count.

    "Lookahead slots" are slots allocated in addition to the slots for
    known tokens; their contents are undefined. Speculative decoding uses
    them when speculating future tokens.
    """
    running_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)

    def _num_touched(seq) -> int:
        # Blocks this sequence would touch when its unseen tokens (plus
        # lookahead slots) are appended.
        table = self.block_tables[seq.seq_id]
        unseen = table.get_unseen_token_ids(seq.get_token_ids())
        return table.get_num_blocks_touched_by_append_slots(
            token_ids=unseen,
            num_lookahead_slots=num_lookahead_slots,
        )

    total_touched = sum(_num_touched(seq) for seq in running_seqs)
    free_gpu_blocks = self.block_allocator.get_num_free_blocks(Device.GPU)
    return total_touched <= free_gpu_blocks
|
||||
|
||||
def append_slots(
    self,
    seq: Sequence,
    num_lookahead_slots: int,
) -> Dict[int, List[int]]:
    """Append slots for `seq`'s unseen tokens plus `num_lookahead_slots`,
    and return the allocator's fresh copy-on-write mappings."""
    table = self.block_tables[seq.seq_id]
    unseen_tokens = table.get_unseen_token_ids(seq.get_token_ids())
    table.append_token_ids(token_ids=unseen_tokens,
                           num_lookahead_slots=num_lookahead_slots)
    # Return any new copy-on-writes.
    return self.block_allocator.clear_copy_on_writes()
|
||||
|
||||
def free(self, seq: Sequence) -> None:
    """Release `seq`'s block table, if it has one.

    A no-op when the sequence was already freed or was never scheduled.
    """
    try:
        table = self.block_tables[seq.seq_id]
    except KeyError:
        # Already freed or haven't been scheduled yet.
        return
    table.free()
    del self.block_tables[seq.seq_id]
|
||||
|
||||
def get_block_table(self, seq: Sequence) -> List[int]:
    """Return the physical block ids backing `seq`.

    The sequence must own a block table and every block must be mapped.
    """
    assert seq.seq_id in self.block_tables
    ids = self.block_tables[seq.seq_id].physical_block_ids
    for block_id in ids:
        assert block_id is not None
    return ids  # type: ignore
|
||||
|
||||
def access_all_blocks_in_seq(self, seq: Sequence, now: float):
    """Stamp `now` as the last-access time on all of `seq`'s blocks.

    Only meaningful when prefix caching is enabled: the access time feeds
    the evictor's policy so that reusable cached blocks are reclaimed
    last.
    """
    if self.enable_caching:
        block_table = self.block_tables[seq.seq_id]
        # Materialize the ids in one pass instead of a manual append loop.
        block_ids = list(block_table.physical_block_ids)
        self.block_allocator.mark_blocks_as_accessed(
            block_ids,  # type: ignore
            now)
|
||||
|
||||
def mark_blocks_as_computed(self, seq_group: SequenceGroup):
    """No-op in block manager v2.

    Marking blocks as computed only matters for prefix caching, and in v2
    whether a block is computed can be derived from whether it carries a
    content hash, so nothing needs to be recorded here.
    """
    pass
|
||||
|
||||
def get_common_computed_block_ids(
        self, seqs: List[Sequence]) -> GenericSequence[int]:
    """Determine which blocks for which we skip prefill.

    With prefix caching we can skip prefill for previously-generated
    blocks, but the attention implementation only supports skipping
    cached blocks when they form a contiguous prefix; the allocator
    computes that common prefix across all sequences in the group.
    """
    all_block_ids = [
        self.block_tables[seq.seq_id].physical_block_ids for seq in seqs
    ]
    # NOTE(sang): This assumes seq_block_ids doesn't contain any None.
    return self.block_allocator.get_common_computed_block_ids(
        all_block_ids)  # type: ignore
|
||||
|
||||
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
    """Give `child_seq` a fork of `parent_seq`'s block table."""
    parent_table = self.block_tables[parent_seq.seq_id]
    self.block_tables[child_seq.seq_id] = parent_table.fork()
|
||||
|
||||
def can_swap_in(self, seq_group: SequenceGroup,
                num_lookahead_slots: int) -> AllocStatus:
    """Whether `seq_group` can be swapped into the GPU.

    Swap-in is not implemented in this manager version (see swap_in), so
    the request is always deferred.
    """
    return AllocStatus.LATER
|
||||
|
||||
def swap_in(self, seq_group: SequenceGroup,
            num_lookahead_slots: int) -> Dict[int, int]:
    """Swap `seq_group`'s blocks into the GPU.

    Raises:
        NotImplementedError: always; not supported by this manager
            version.
    """
    raise NotImplementedError
|
||||
|
||||
def can_swap_out(self, seq_group: SequenceGroup) -> bool:
    """Whether `seq_group` can be swapped out to the CPU.

    Always False: swap-out is not implemented in this manager version
    (see swap_out).
    """
    return False
|
||||
|
||||
def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
    """Swap `seq_group`'s blocks out of the GPU.

    Raises:
        NotImplementedError: always; not supported by this manager
            version.
    """
    raise NotImplementedError
|
||||
|
||||
def get_num_free_gpu_blocks(self) -> int:
    """Return the count of currently free GPU KV-cache blocks."""
    return self.block_allocator.get_num_free_blocks(device=Device.GPU)
|
||||
|
||||
def get_num_free_cpu_blocks(self) -> int:
    """Return the count of currently free CPU KV-cache blocks."""
    return self.block_allocator.get_num_free_blocks(device=Device.CPU)
|
||||
105
vllm/core/evictor_v1.py
Normal file
105
vllm/core/evictor_v1.py
Normal file
@@ -0,0 +1,105 @@
|
||||
import enum
|
||||
from abc import ABC, abstractmethod, abstractproperty
|
||||
from typing import OrderedDict
|
||||
|
||||
from vllm.block import PhysicalTokenBlock
|
||||
|
||||
|
||||
class EvictionPolicy(enum.Enum):
    """Identifies which Evictor subclass make_evictor should instantiate."""
    LRU = enum.auto()
|
||||
|
||||
|
||||
class Evictor(ABC):
    """Interface used by the BlockAllocator class to handle eviction of
    freed PhysicalTokenBlocks.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __contains__(self, block_hash: int) -> bool:
        pass

    @abstractmethod
    def evict(self) -> PhysicalTokenBlock:
        """Runs the eviction algorithm and returns the evicted block"""
        pass

    @abstractmethod
    def add(self, block: PhysicalTokenBlock):
        """Adds block to the evictor, making it a candidate for eviction"""
        pass

    @abstractmethod
    def remove(self, block_hash: int) -> PhysicalTokenBlock:
        """Simply removes the block with the hash value block_hash from the
        evictor. Caller is responsible for making sure that block_hash is
        contained in the evictor before calling remove. Should be used to
        "bring back" blocks that have been freed but not evicted yet.
        """
        pass

    # `abstractproperty` is deprecated since Python 3.3; the stacked
    # @property/@abstractmethod pair is the supported replacement.
    @property
    @abstractmethod
    def num_blocks(self) -> int:
        """Number of blocks currently eligible for eviction."""
        pass
|
||||
|
||||
|
||||
class LRUEvictor(Evictor):
    """Evicts in a least-recently-used order using the last_accessed timestamp
    that's recorded in the PhysicalTokenBlock. If there are multiple blocks with
    the same last_accessed time, then the one with the largest num_hashed_tokens
    will be evicted. If two blocks each have the lowest last_accessed time and
    highest num_hashed_tokens value, then one will be chose arbitrarily
    """

    def __init__(self):
        # Maps block_hash -> freed block. Insertion order approximates
        # recency: blocks enter when freed, so earlier entries are the
        # older eviction candidates.
        self.free_table: OrderedDict[int, PhysicalTokenBlock] = OrderedDict()

    def __contains__(self, block_hash: int) -> bool:
        return block_hash in self.free_table

    def evict(self) -> PhysicalTokenBlock:
        """Pop and return the LRU block, breaking last_accessed ties in
        favor of the larger num_hashed_tokens.

        Raises:
            ValueError: if the evictor holds no blocks.
        """
        if len(self.free_table) == 0:
            raise ValueError("No usable cache memory left")

        evicted_block = next(iter(self.free_table.values()))
        # The blocks with the lowest timestamps should be placed consecutively
        # at the start of OrderedDict. Loop through all these blocks to
        # find the one with maximum number of hashed tokens.
        for _, block in self.free_table.items():
            if evicted_block.last_accessed < block.last_accessed:
                # Past the leading run of minimal timestamps; later
                # entries are more recently used, so stop scanning.
                break
            if evicted_block.num_hashed_tokens < block.num_hashed_tokens:
                evicted_block = block

        self.free_table.pop(evicted_block.block_hash)

        # The block is being reclaimed, so its cached content no longer
        # counts as computed.
        evicted_block.computed = False
        return evicted_block

    def add(self, block: PhysicalTokenBlock):
        """Adds block to the evictor, making it a candidate for eviction"""
        self.free_table[block.block_hash] = block

    def remove(self, block_hash: int) -> PhysicalTokenBlock:
        """Remove (without evicting) the block with the given hash.

        Raises:
            ValueError: if block_hash is not tracked by the evictor.
        """
        if block_hash not in self.free_table:
            raise ValueError(
                "Attempting to remove block that's not in the evictor")
        block: PhysicalTokenBlock = self.free_table[block_hash]
        self.free_table.pop(block_hash)
        return block

    @property
    def num_blocks(self) -> int:
        # Count of blocks currently eligible for eviction.
        return len(self.free_table)
|
||||
|
||||
|
||||
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
    """Build the Evictor implementation for the given policy."""
    if eviction_policy is not EvictionPolicy.LRU:
        raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
    return LRUEvictor()
|
||||
127
vllm/core/evictor_v2.py
Normal file
127
vllm/core/evictor_v2.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import enum
|
||||
from abc import ABC, abstractmethod, abstractproperty
|
||||
from typing import OrderedDict, Tuple
|
||||
|
||||
|
||||
class EvictionPolicy(enum.Enum):
    """Selects which Evictor subclass make_evictor builds."""
    LRU = enum.auto()
|
||||
|
||||
|
||||
class Evictor(ABC):
    """Interface used by the BlockAllocator class to handle eviction of
    freed blocks, keyed by physical block id.
    """

    @abstractmethod
    def __init__(self):
        pass

    @abstractmethod
    def __contains__(self, block_id: int) -> bool:
        pass

    @abstractmethod
    def evict(self) -> Tuple[int, int]:
        """Runs the eviction algorithm and returns the evicted block's
        physical block id along with its content hash.
        """
        pass

    @abstractmethod
    def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
            last_accessed: float):
        """Adds block to the evictor, making it a candidate for eviction"""
        pass

    @abstractmethod
    def update(self, block_id: int, last_accessed: float):
        """Update corresponding block's access time in metadata"""
        pass

    @abstractmethod
    def remove(self, block_id: int):
        """Remove a given block id from the cache."""
        pass

    # `abstractproperty` is deprecated since Python 3.3; the stacked
    # @property/@abstractmethod pair is the supported replacement.
    @property
    @abstractmethod
    def num_blocks(self) -> int:
        """Number of blocks currently eligible for eviction."""
        pass
|
||||
|
||||
|
||||
class BlockMetaData():
    """Per-block bookkeeping the evictor consults when picking a victim.

    Entries are keyed externally by physical block id, since several
    blocks may share one content hash while physical ids are unique.
    """

    def __init__(self, content_hash: int, num_hashed_tokens: int,
                 last_accessed: float):
        # Hash of the block's token content (may collide across blocks).
        self.content_hash = content_hash
        # How many tokens contributed to the hash; ties break on this.
        self.num_hashed_tokens = num_hashed_tokens
        # Timestamp of the most recent access.
        self.last_accessed = last_accessed
|
||||
|
||||
|
||||
class LRUEvictor(Evictor):
    """Least-recently-used eviction driven by the last_accessed timestamps
    recorded per block. Timestamp ties are broken in favor of the block
    with the larger num_hashed_tokens; if both fields tie, the victim is
    chosen arbitrarily.
    """

    def __init__(self):
        self.free_table: OrderedDict[int, BlockMetaData] = OrderedDict()

    def __contains__(self, block_id: int) -> bool:
        return block_id in self.free_table

    def evict(self) -> Tuple[int, int]:
        """Pop the LRU block; return (physical block id, content hash).

        Raises:
            ValueError: if no freed block is available.
        """
        if not self.free_table:
            raise ValueError("No usable cache memory left")

        # Start from the oldest entry (OrderedDict keeps insertion order),
        # then replace the candidate whenever a strictly older block — or
        # an equally old one with more hashed tokens — is found.
        entries = iter(self.free_table.items())
        victim_id, victim = next(entries)
        for candidate_id, candidate in entries:
            is_older = candidate.last_accessed < victim.last_accessed
            is_tie_bigger = (
                candidate.last_accessed == victim.last_accessed
                and candidate.num_hashed_tokens > victim.num_hashed_tokens)
            if is_older or is_tie_bigger:
                victim_id, victim = candidate_id, candidate

        del self.free_table[victim_id]

        return victim_id, victim.content_hash

    def add(self, block_id: int, content_hash: int, num_hashed_tokens: int,
            last_accessed: float):
        """Register a freed block as an eviction candidate."""
        meta = BlockMetaData(content_hash, num_hashed_tokens, last_accessed)
        self.free_table[block_id] = meta

    def update(self, block_id: int, last_accessed: float):
        """Refresh the recorded access time of a tracked block."""
        self.free_table[block_id].last_accessed = last_accessed

    def remove(self, block_id: int):
        """Drop a block from the candidate set without evicting it.

        Raises:
            ValueError: if block_id is not tracked by the evictor.
        """
        if block_id not in self.free_table:
            raise ValueError(
                "Attempting to remove block that's not in the evictor")
        del self.free_table[block_id]

    @property
    def num_blocks(self) -> int:
        return len(self.free_table)
|
||||
|
||||
|
||||
def make_evictor(eviction_policy: EvictionPolicy) -> Evictor:
    """Instantiate the Evictor implementation for the given policy."""
    if eviction_policy is not EvictionPolicy.LRU:
        raise ValueError(f"Unknown cache eviction policy: {eviction_policy}")
    return LRUEvictor()
|
||||
113
vllm/core/interfaces.py
Normal file
113
vllm/core/interfaces.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List
|
||||
from typing import Sequence as GenericSequence
|
||||
|
||||
from vllm.sequence import Sequence, SequenceGroup
|
||||
|
||||
|
||||
class AllocStatus(enum.Enum):
    """Result for BlockSpaceManager.can_allocate

    1. Ok: seq_group can be allocated now.
    2. Later: seq_group cannot be allocated now, but the allocator's total
       capacity is larger than seq_group requires, so it may succeed once
       blocks free up.
    3. Never: seq_group can never be allocated; it is too large to fit in
       GPU memory even with an empty cache.
    """
    OK = enum.auto()
    LATER = enum.auto()
    NEVER = enum.auto()
|
||||
|
||||
|
||||
class BlockSpaceManager(ABC):
    """Abstract interface for mapping sequences onto KV-cache blocks.

    Concrete implementations (v1/v2) decide how a sequence's tokens are
    placed into physical blocks and how allocation, appending, forking,
    swapping, and freeing of those blocks is performed.
    """

    @staticmethod
    def get_block_space_manager_class(version: str):
        """Return the concrete manager class for `version` ("v1" or "v2").

        Raises:
            ValueError: if the version string is not recognized.
        """
        version = version.lower()

        if version == "v1":
            from vllm.core.block_manager_v1 import BlockSpaceManagerV1
            return BlockSpaceManagerV1

        if version == "v2":
            from vllm.core.block_manager_v2 import BlockSpaceManagerV2
            return BlockSpaceManagerV2

        raise ValueError(f"Unknown version {version=}")

    @abstractmethod
    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        """Report whether seq_group can be allocated (OK/LATER/NEVER)."""
        pass

    @abstractmethod
    def allocate(self, seq_group: SequenceGroup) -> None:
        """Allocate blocks for seq_group's sequences."""
        pass

    @abstractmethod
    def can_append_slots(self, seq_group: SequenceGroup,
                         num_lookahead_slots: int) -> bool:
        """Report whether slots (plus lookahead) can be appended now."""
        pass

    @abstractmethod
    def append_slots(
        self,
        seq: Sequence,
        num_lookahead_slots: int,
    ) -> Dict[int, List[int]]:
        """Append slots for seq; return copy-on-write block mappings."""
        pass

    @abstractmethod
    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        """Make child_seq share parent_seq's block table."""
        pass

    @abstractmethod
    def can_swap_in(self, seq_group: SequenceGroup,
                    num_lookahead_slots: int) -> AllocStatus:
        """Report whether seq_group can be swapped into the GPU."""
        pass

    @abstractmethod
    def swap_in(self, seq_group: SequenceGroup,
                num_lookahead_slots: int) -> Dict[int, int]:
        """Swap seq_group in; returns a block-id mapping (presumably
        source->destination block ids — confirm against implementations)."""
        pass

    @abstractmethod
    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        """Report whether seq_group can be swapped out to the CPU."""
        pass

    @abstractmethod
    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
        """Swap seq_group out; returns a block-id mapping (presumably
        source->destination block ids — confirm against implementations)."""
        pass

    @abstractmethod
    def free(self, seq: Sequence) -> None:
        """Release all blocks owned by seq."""
        pass

    @abstractmethod
    def get_block_table(self, seq: Sequence) -> List[int]:
        """Return the physical block ids backing seq."""
        pass

    @abstractmethod
    def get_num_free_gpu_blocks(self) -> int:
        """Number of free GPU blocks."""
        pass

    @abstractmethod
    def get_num_free_cpu_blocks(self) -> int:
        """Number of free CPU blocks."""
        pass

    @abstractmethod
    def access_all_blocks_in_seq(
        self,
        seq: Sequence,
        access_time: float,
    ) -> None:
        """Record access_time on all of seq's blocks (prefix caching)."""
        pass

    @abstractmethod
    def get_common_computed_block_ids(
            self, seqs: List[Sequence]) -> GenericSequence[int]:
        """Return block ids already computed for all of seqs."""
        pass

    @abstractmethod
    def mark_blocks_as_computed(self, seq_group: SequenceGroup):
        """Mark seq_group's blocks as computed (prefix caching)."""
        pass
|
||||
45
vllm/core/policy.py
Normal file
45
vllm/core/policy.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from collections import deque
|
||||
from typing import Deque
|
||||
|
||||
from vllm.sequence import SequenceGroup
|
||||
|
||||
|
||||
class Policy:
    """Scheduling policy: orders SequenceGroups by a priority score."""

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        """Score `seq_group` at time `now`; higher scores sort first."""
        raise NotImplementedError

    def sort_by_priority(
        self,
        now: float,
        seq_groups: Deque[SequenceGroup],
    ) -> Deque[SequenceGroup]:
        """Return a new deque ordered by descending priority."""
        ordered = sorted(
            seq_groups,
            key=lambda group: self.get_priority(now, group),
            reverse=True,
        )
        return deque(ordered)
|
||||
|
||||
|
||||
class FCFS(Policy):
    """First-come-first-served: priority equals time spent waiting."""

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        # Earlier arrivals accumulate a larger (now - arrival_time) and
        # therefore sort first.
        waited = now - seq_group.metrics.arrival_time
        return waited
|
||||
|
||||
|
||||
class PolicyFactory:
    """Registry mapping policy names to Policy classes."""

    _POLICY_REGISTRY = {'fcfs': FCFS}

    @classmethod
    def get_policy(cls, policy_name: str, **kwargs) -> Policy:
        """Instantiate the policy registered under `policy_name`.

        Raises:
            KeyError: if the name is not registered.
        """
        policy_cls = cls._POLICY_REGISTRY[policy_name]
        return policy_cls(**kwargs)
|
||||
1163
vllm/core/scheduler.py
Normal file
1163
vllm/core/scheduler.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user