Sync from v0.13
This commit is contained in:
0
vllm/v1/kv_offload/__init__.py
Normal file
0
vllm/v1/kv_offload/__init__.py
Normal file
161
vllm/v1/kv_offload/abstract.py
Normal file
161
vllm/v1/kv_offload/abstract.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
OffloadingManager class for managing KV data offloading in vLLM v1
|
||||
|
||||
This class runs in the scheduler, tracks which blocks are offloaded
|
||||
and their address.
|
||||
|
||||
The class provides the following primitives:
|
||||
lookup() - find the length of the maximal series of blocks,
|
||||
starting from the first one, that are all offloaded.
|
||||
prepare_load() - prepare given blocks to be read.
|
||||
The given blocks will be protected from eviction.
|
||||
This function returns a LoadSpec which encapsulates
|
||||
information required for performing the load.
|
||||
touch() - marks the given blocks as recently used. Can be used
|
||||
to track block's LRU. This function is separated from the
|
||||
prepare_load function to allow setting block recency even
|
||||
for blocks which do not need reading from the cache, such as
|
||||
blocks that are cached by the GPU prefix cache.
|
||||
complete_load() - mark blocks which were previously prepared to be
|
||||
loaded as done loading. This is to re-allow their eviction.
|
||||
prepare_store() - prepare the given blocks to be written.
|
||||
Returns a StoreSpec encapsulating offloading information,
|
||||
as well as a list of blocks that were evicted as a result.
|
||||
complete_store() - marks a previous store as completed.
|
||||
Following this call, the given blocks will become loadable.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
|
||||
|
||||
class LoadStoreSpec(ABC):
    """
    Backend-specific addressing metadata for KV block transfers.

    Concrete subclasses carry whatever a worker needs in order to load,
    and optionally also store, blocks of KV data on a particular medium.
    """

    @staticmethod
    @abstractmethod
    def medium() -> str:
        """
        Return a short string identifying the medium type
        (e.g. "CPU", "GPU") this store/load targets.
        """
        ...
|
||||
|
||||
|
||||
@dataclass
class PrepareStoreOutput:
    """Result of OffloadingManager.prepare_store()."""

    # blocks that still need to be written (i.e. were not already offloaded)
    block_hashes_to_store: list[BlockHash]
    # backend-specific information on where/how to write those blocks
    store_spec: LoadStoreSpec
    # blocks that were evicted to make room for this store
    block_hashes_evicted: list[BlockHash]
|
||||
|
||||
|
||||
@dataclass
class OffloadingEvent:
    """A batch of blocks that were stored to, or removed from, a medium."""

    # the hashes of the affected blocks
    block_hashes: list[BlockHash]
    # number of tokens per block
    block_size: int
    # the medium the blocks were stored to / removed from
    medium: str
    # True if blocks are removed, False if stored
    removed: bool
|
||||
|
||||
|
||||
class OffloadingManager(ABC):
    """
    Scheduler-side bookkeeping for offloaded KV blocks.

    Tracks which blocks are offloaded and where, protects in-flight
    blocks from eviction, and reports offloading events.
    """

    @abstractmethod
    def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
        """
        Finds the length of the maximal series of blocks, starting from the
        first one, that are all offloaded.

        Args:
            block_hashes: the hashes identifying the blocks to lookup.

        Returns:
            An integer representing the maximal number of blocks that
            are currently offloaded.
        """
        pass

    @abstractmethod
    def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec:
        """
        Prepare the given blocks to be read.
        The given blocks will be protected from eviction until
        complete_load is called.
        It assumes all given blocks are offloaded.

        Args:
            block_hashes: the hashes identifying the blocks.

        Returns:
            A LoadStoreSpec that can be used by a worker to locate and load
            the actual offloaded KV data.
        """
        pass

    def touch(self, block_hashes: Iterable[BlockHash]):
        """
        Mark the given blocks as recently used.
        This could in practice mean moving them to the end of an LRU list.
        The base implementation is a no-op.

        Args:
            block_hashes: the hashes identifying the blocks.
        """
        return

    def complete_load(self, block_hashes: Iterable[BlockHash]):
        """
        Marks previous blocks that were prepared to load as done loading,
        re-allowing their eviction. The base implementation is a no-op.

        Args:
            block_hashes: the hashes identifying the blocks.
        """
        return

    @abstractmethod
    def prepare_store(
        self, block_hashes: Iterable[BlockHash]
    ) -> PrepareStoreOutput | None:
        """
        Prepare the given blocks to be offloaded.
        The given blocks will be protected from eviction until
        complete_store is called.

        Args:
            block_hashes: the hashes identifying the blocks.

        Returns:
            A PrepareStoreOutput indicating which blocks need storing,
            where to store them (LoadStoreSpec), and list of blocks that
            were evicted as a result.
            None is returned if the blocks cannot be stored.
        """
        pass

    def complete_store(self, block_hashes: Iterable[BlockHash], success: bool = True):
        """
        Marks blocks which were previously prepared to be stored, as stored.
        Following this call, the blocks become loadable.
        If success is False, blocks that were not marked as stored will be
        removed. The base implementation is a no-op.

        Args:
            block_hashes: the hashes identifying the blocks.
            success: whether the blocks were stored successfully.
        """
        return

    def take_events(self) -> Iterable[OffloadingEvent]:
        """
        Take the offloading events from the manager.
        The base implementation yields no events.

        Yields:
            New OffloadingEvents collected since the last call.
        """
        return ()
|
||||
237
vllm/v1/kv_offload/arc_manager.py
Normal file
237
vllm/v1/kv_offload/arc_manager.py
Normal file
@@ -0,0 +1,237 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Iterable
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.kv_offload.abstract import (
|
||||
LoadStoreSpec,
|
||||
OffloadingEvent,
|
||||
OffloadingManager,
|
||||
PrepareStoreOutput,
|
||||
)
|
||||
from vllm.v1.kv_offload.backend import Backend, BlockStatus
|
||||
|
||||
|
||||
class ARCOffloadingManager(OffloadingManager):
    """
    An OffloadingManager implementing the ARC (Adaptive Replacement Cache)
    eviction policy with a pluggable backend.

    Data structures:
        t1: recency cache - blocks accessed once (OrderedDict, LRU->MRU).
        t2: frequency cache - blocks accessed multiple times.
        b1 / b2: ghost lists of hashes recently evicted from t1 / t2.
        target_t1_size: adaptive target size for the t1 partition.

    Adaptive behavior:
        - A ghost hit in b1 means recency matters more -> grow target_t1_size.
        - A ghost hit in b2 means frequency matters more -> shrink it.
    """

    def __init__(self, backend: Backend, enable_events: bool = False):
        self.backend: Backend = backend
        self.target_t1_size: float = 0.0
        self.t1: OrderedDict[BlockHash, BlockStatus] = OrderedDict()
        self.t2: OrderedDict[BlockHash, BlockStatus] = OrderedDict()
        # ghost lists: block_hash -> None (only presence matters)
        self.b1: OrderedDict[BlockHash, None] = OrderedDict()
        self.b2: OrderedDict[BlockHash, None] = OrderedDict()
        self.events: list[OffloadingEvent] | None = [] if enable_events else None
        # NOTE(review): assumes the backend starts empty, so the initial free
        # count equals total capacity; used to bound ghost lists and target.
        self.cache_capacity: int = self.backend.get_num_free_blocks()

    def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
        """Count the prefix of block_hashes that are offloaded and ready."""
        hit_count = 0
        for block_hash in block_hashes:
            block = self.t1.get(block_hash) or self.t2.get(block_hash)
            if block is None or not block.is_ready:
                break
            hit_count += 1
        return hit_count

    def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec:
        """Pin the given (offloaded, ready) blocks and return a load spec."""
        # materialize first: a generator would be exhausted by the loop below
        # before being passed to the backend
        block_hashes = list(block_hashes)
        blocks = []
        for block_hash in block_hashes:
            block = self.t1.get(block_hash) or self.t2.get(block_hash)
            assert block is not None, f"Block {block_hash!r} not found in cache"
            assert block.is_ready, f"Block {block_hash!r} is not ready for reading"

            # pin against eviction until complete_load
            block.ref_cnt += 1
            blocks.append(block)

        return self.backend.get_load_store_spec(block_hashes, blocks)

    def touch(self, block_hashes: Iterable[BlockHash]):
        """Refresh recency and apply ARC's adaptive target adjustments."""
        # iterate in reverse so the first hash ends up most recently used
        for block_hash in reversed(list(block_hashes)):
            if block_hash in self.t1:
                block = self.t1[block_hash]
                if not block.is_ready:
                    # block was just prepared to be stored, not really touched
                    # twice; only refresh its recency within t1.
                    # (The original popped the block first and then called
                    # move_to_end on the missing key, raising KeyError.)
                    self.t1.move_to_end(block_hash)
                else:
                    # second access: promote from recent (t1) to frequent (t2)
                    del self.t1[block_hash]
                    self.t2[block_hash] = block

            elif block_hash in self.t2:
                self.t2.move_to_end(block_hash)

            elif block_hash in self.b1:
                # ghost hit in b1: recency is paying off, grow the t1 target
                delta = max(1, len(self.b2) / len(self.b1))
                self.target_t1_size = min(
                    self.target_t1_size + delta, self.cache_capacity
                )
                # move to MRU position (end) to keep it fresh in the ghost list
                self.b1.move_to_end(block_hash)

            elif block_hash in self.b2:
                # ghost hit in b2: frequency is paying off, shrink the target
                delta = max(1, len(self.b1) / len(self.b2))
                self.target_t1_size = max(self.target_t1_size - delta, 0)
                # move to MRU position (end) to keep it fresh in the ghost list
                self.b2.move_to_end(block_hash)

    def complete_load(self, block_hashes: Iterable[BlockHash]):
        """Unpin blocks previously pinned by prepare_load."""
        for block_hash in block_hashes:
            block = self.t1.get(block_hash) or self.t2.get(block_hash)
            assert block is not None, f"Block {block_hash!r} not found"
            assert block.ref_cnt > 0, f"Block {block_hash!r} ref_cnt is already 0"

            block.ref_cnt -= 1

    def _evict_one(self) -> BlockHash | None:
        """Evict a single unreferenced block, preferring the ARC-chosen side.

        Falls back to the other side if the preferred one has no evictable
        block. Returns the evicted hash (recorded in the matching ghost
        list), or None if every cached block is currently in use.
        """
        if len(self.t1) >= int(self.target_t1_size):
            sources = ((self.t1, self.b1), (self.t2, self.b2))
        else:
            sources = ((self.t2, self.b2), (self.t1, self.b1))
        for cache, ghost in sources:
            # scan from LRU (front) for the first unreferenced block
            for block_hash, block in cache.items():
                if block.ref_cnt == 0:
                    del cache[block_hash]
                    ghost[block_hash] = None
                    self.backend.free(block)
                    return block_hash
        return None

    def _emit_event(self, block_hashes: list[BlockHash], removed: bool):
        """Queue an OffloadingEvent if events are enabled and non-empty."""
        if block_hashes and self.events is not None:
            self.events.append(
                OffloadingEvent(
                    block_hashes=block_hashes,
                    block_size=self.backend.block_size,
                    medium=self.backend.medium,
                    removed=removed,
                )
            )

    def prepare_store(
        self, block_hashes: Iterable[BlockHash]
    ) -> PrepareStoreOutput | None:
        """Allocate space for the given blocks, evicting by ARC if needed."""
        # only blocks not already cached need storing
        block_hashes_to_store = [
            block_hash
            for block_hash in block_hashes
            if block_hash not in self.t1 and block_hash not in self.t2
        ]

        if not block_hashes_to_store:
            return PrepareStoreOutput(
                block_hashes_to_store=[],
                store_spec=self.backend.get_load_store_spec([], []),
                block_hashes_evicted=[],
            )

        num_blocks_to_evict = (
            len(block_hashes_to_store) - self.backend.get_num_free_blocks()
        )

        to_evict: list[BlockHash] = []
        while num_blocks_to_evict > 0:
            evicted_hash = self._evict_one()
            if evicted_hash is None:
                # cache is full of in-use blocks; still report blocks already
                # evicted in this call so listeners stay consistent
                self._emit_event(to_evict, removed=True)
                return None
            to_evict.append(evicted_hash)
            num_blocks_to_evict -= 1

        # bound each ghost list to the cache capacity
        for ghost in (self.b1, self.b2):
            while len(ghost) > self.cache_capacity:
                ghost.popitem(last=False)

        self._emit_event(to_evict, removed=True)

        blocks = self.backend.allocate_blocks(block_hashes_to_store)
        assert len(blocks) == len(block_hashes_to_store), (
            "Backend did not allocate the expected number of blocks"
        )

        # new blocks always enter t1 (recent) and leave the ghost lists;
        # they may later be promoted to t2 by touch()
        for block_hash, block in zip(block_hashes_to_store, blocks):
            self.t1[block_hash] = block
            self.b1.pop(block_hash, None)
            self.b2.pop(block_hash, None)

        store_spec = self.backend.get_load_store_spec(block_hashes_to_store, blocks)

        return PrepareStoreOutput(
            block_hashes_to_store=block_hashes_to_store,
            store_spec=store_spec,
            block_hashes_evicted=to_evict,
        )

    def complete_store(self, block_hashes: Iterable[BlockHash], success: bool = True):
        """Mark prepared blocks as stored, or roll them back on failure."""
        stored_block_hashes: list[BlockHash] = []

        if success:
            for block_hash in block_hashes:
                block = self.t1.get(block_hash) or self.t2.get(block_hash)

                if block is not None and not block.is_ready:
                    # transition "being written" (-1) -> "readable" (0)
                    block.ref_cnt = 0
                    stored_block_hashes.append(block_hash)
        else:
            for block_hash in block_hashes:
                block = self.t1.get(block_hash) or self.t2.get(block_hash)

                # only drop blocks that were still being written; the
                # original unconditionally popped, which silently removed
                # already-ready blocks from the cache without freeing them
                if block is not None and not block.is_ready:
                    self.t1.pop(block_hash, None)
                    self.t2.pop(block_hash, None)
                    self.backend.free(block)

        self._emit_event(stored_block_hashes, removed=False)

    def take_events(self) -> Iterable[OffloadingEvent]:
        """Yield and clear events collected since the last call."""
        if self.events is not None:
            yield from self.events
            self.events.clear()
|
||||
97
vllm/v1/kv_offload/backend.py
Normal file
97
vllm/v1/kv_offload/backend.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ctypes
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.kv_offload.abstract import LoadStoreSpec
|
||||
|
||||
|
||||
class BlockStatus(ctypes.Structure):
    """
    Offloading status for a single block of KV data.

    ref_cnt holds the current number of transfers using this block as a
    source; a value of -1 marks a freshly allocated block whose data has
    not been written yet, so it must not be read.
    Subclasses extend _fields_ with backend-specific addressing info.
    """

    _fields_ = [("ref_cnt", ctypes.c_int32)]

    def __init__(self):
        super().__init__()
        # freshly allocated blocks start out "not ready" to be read
        self.ref_cnt = -1

    @property
    def is_ready(self) -> bool:
        """True once the block's data was fully written (ref_cnt >= 0)."""
        return self.ref_cnt >= 0
|
||||
|
||||
|
||||
class Backend(ABC):
    """
    An abstract class for allocating and returning specs for writing
    KV blocks to some backend.
    """

    def __init__(self, block_size: int, medium: str):
        # number of tokens per offloaded block
        self.block_size = block_size
        # medium name this backend writes to (e.g. "CPU")
        self.medium = medium

    @abstractmethod
    def get_num_free_blocks(self) -> int:
        """
        Returns the current number of blocks that can be allocated.
        """
        pass

    @abstractmethod
    def allocate_blocks(self, block_hashes: list[BlockHash]) -> list[BlockStatus]:
        """
        Allocate space for writing blocks.
        This method assumes there is enough space for allocation.
        It is unsafe to use without checking get_num_free_blocks beforehand.

        Args:
            block_hashes: the hashes identifying the blocks to be written.

        Returns:
            A list of BlockStatus for the allocated blocks.
            The ref_cnt of each returned item will be -1, meaning the block
            is not yet ready to be read.
        """
        pass

    @abstractmethod
    def free(self, block: BlockStatus):
        """
        Free a previously allocated block.
        You should only call this function with blocks returned by
        allocate_blocks, and only once per each block.

        Args:
            block: The block to be freed.
        """
        pass

    def get_load_store_spec(
        self, block_hashes: Iterable[BlockHash], blocks: Iterable[BlockStatus]
    ) -> LoadStoreSpec:
        """
        Get backend-specific information on how to read/write blocks.
        Not abstract, but the base implementation raises NotImplementedError;
        concrete backends must override it.

        Args:
            block_hashes: the list of block hashes identifying the blocks.
            blocks: the list of blocks.

        Returns:
            A LoadStoreSpec that can be used by a worker
            to read/write the blocks.
        """
        raise NotImplementedError
|
||||
0
vllm/v1/kv_offload/backends/__init__.py
Normal file
0
vllm/v1/kv_offload/backends/__init__.py
Normal file
62
vllm/v1/kv_offload/backends/cpu.py
Normal file
62
vllm/v1/kv_offload/backends/cpu.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ctypes
|
||||
from collections.abc import Iterable
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.kv_offload.abstract import LoadStoreSpec
|
||||
from vllm.v1.kv_offload.backend import Backend, BlockStatus
|
||||
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec
|
||||
|
||||
|
||||
class CPUBlockStatus(BlockStatus):
    """BlockStatus extended with the CPU block index it occupies."""

    # extend the base ctypes struct with a 64-bit CPU block id
    _fields_ = BlockStatus._fields_ + [("block_id", ctypes.c_int64)]  # type: ignore

    def __init__(self, block_id: int):
        super().__init__()
        self.block_id = block_id
|
||||
|
||||
|
||||
class CPUBackend(Backend):
    """Backend that hands out CPU block ids for offloaded KV blocks."""

    def __init__(self, block_size: int, num_blocks: int):
        super().__init__(block_size=block_size, medium=CPULoadStoreSpec.medium())

        # total capacity, count of ids handed out so far, and recycled ids
        self.num_blocks: int = num_blocks
        self.num_allocated_blocks: int = 0
        self.allocated_blocks_free_list: list[int] = []

    def get_num_free_blocks(self):
        """Never-yet-allocated blocks plus recycled (freed) blocks."""
        never_allocated = self.num_blocks - self.num_allocated_blocks
        return never_allocated + len(self.allocated_blocks_free_list)

    def allocate_blocks(self, block_hashes: list[BlockHash]) -> list[BlockStatus]:
        """Allocate one CPUBlockStatus per hash; assumes enough free space."""
        requested = len(block_hashes)
        num_fresh = min(requested, self.num_blocks - self.num_allocated_blocks)
        num_reused = requested - num_fresh
        assert len(self.allocated_blocks_free_list) >= num_reused

        # fresh ids are sequential; the remainder comes off the free list
        fresh_ids = range(
            self.num_allocated_blocks, self.num_allocated_blocks + num_fresh
        )
        self.num_allocated_blocks += num_fresh
        reused_ids = [
            self.allocated_blocks_free_list.pop() for _ in range(num_reused)
        ]

        return [CPUBlockStatus(block_id) for block_id in (*fresh_ids, *reused_ids)]

    def free(self, block: BlockStatus):
        """Return a block's id to the free list for reuse."""
        assert isinstance(block, CPUBlockStatus)
        self.allocated_blocks_free_list.append(block.block_id)

    def get_load_store_spec(
        self, block_hashes: Iterable[BlockHash], blocks: Iterable[BlockStatus]
    ) -> LoadStoreSpec:
        """Build a CPU spec from the block ids of the given blocks."""
        return CPULoadStoreSpec([block.block_id for block in blocks])
|
||||
86
vllm/v1/kv_offload/cpu.py
Normal file
86
vllm/v1/kv_offload/cpu.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterator
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
|
||||
from vllm.v1.kv_offload.arc_manager import ARCOffloadingManager
|
||||
from vllm.v1.kv_offload.backends.cpu import CPUBackend
|
||||
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
|
||||
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandlers
|
||||
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
|
||||
|
||||
|
||||
class CPUOffloadingSpec(OffloadingSpec):
    """OffloadingSpec that offloads KV blocks to CPU memory."""

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

        num_cpu_blocks = self.extra_config.get("num_cpu_blocks")
        if not num_cpu_blocks:
            raise Exception(
                "num_cpu_blocks must be specified in kv_connector_extra_config"
            )
        self.num_cpu_blocks: int = num_cpu_blocks

        # scheduler-side manager, created lazily by get_manager()
        self._manager: OffloadingManager | None = None

        # worker-side handlers, created lazily by get_handlers()
        self._handlers: CpuGpuOffloadingHandlers | None = None

        self.eviction_policy: str = self.extra_config.get("eviction_policy", "lru")

    def get_manager(self) -> OffloadingManager:
        """Lazily build and cache the scheduler-side OffloadingManager."""
        if self._manager is None:
            kv_events_config = self.vllm_config.kv_events_config
            enable_events = (
                kv_events_config is not None and kv_events_config.enable_kv_cache_events
            )

            backend = CPUBackend(
                block_size=self.offloaded_block_size, num_blocks=self.num_cpu_blocks
            )

            # dispatch table keyed by the configured eviction policy
            manager_by_policy = {
                "lru": LRUOffloadingManager,
                "arc": ARCOffloadingManager,
            }
            manager_cls = manager_by_policy.get(self.eviction_policy)
            if manager_cls is None:
                raise ValueError(
                    f"Unknown eviction policy: {self.eviction_policy}. "
                    f"Supported policies: lru, arc"
                )
            self._manager = manager_cls(backend=backend, enable_events=enable_events)
        return self._manager

    def get_handlers(
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
        """Lazily build the CPU<->GPU handlers and yield both directions."""
        if self._handlers is None:
            if not current_platform.is_cuda_alike():
                raise Exception(
                    "CPU Offloading is currently only supported on CUDA-alike GPUs"
                )

            self._handlers = CpuGpuOffloadingHandlers(
                attn_backends=attn_backends,
                gpu_block_size=self.gpu_block_size,
                cpu_block_size=self.offloaded_block_size,
                num_cpu_blocks=self.num_cpu_blocks,
                gpu_caches=kv_caches,
            )

        yield GPULoadStoreSpec, CPULoadStoreSpec, self._handlers.gpu_to_cpu_handler
        yield CPULoadStoreSpec, GPULoadStoreSpec, self._handlers.cpu_to_gpu_handler
|
||||
56
vllm/v1/kv_offload/factory.py
Normal file
56
vllm/v1/kv_offload/factory.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OffloadingSpecFactory:
    """Registry and factory for OffloadingSpec implementations."""

    # spec name -> zero-arg loader returning the spec class (lazy import)
    _registry: dict[str, Callable[[], type[OffloadingSpec]]] = {}

    @classmethod
    def register_spec(cls, name: str, module_path: str, class_name: str) -> None:
        """Register a spec with a lazy-loading module and class name."""
        if name in cls._registry:
            raise ValueError(f"Connector '{name}' is already registered.")

        def loader() -> type[OffloadingSpec]:
            mod = importlib.import_module(module_path)
            return getattr(mod, class_name)

        cls._registry[name] = loader

    @classmethod
    def create_spec(
        cls,
        config: "VllmConfig",
    ) -> OffloadingSpec:
        """Instantiate the spec named in the kv-transfer extra config."""
        kv_transfer_config = config.kv_transfer_config
        assert kv_transfer_config is not None
        extra_config = kv_transfer_config.kv_connector_extra_config
        spec_name = extra_config.get("spec_name", "CPUOffloadingSpec")
        loader = cls._registry.get(spec_name)
        if loader is not None:
            spec_cls = loader()
        else:
            # fall back to dynamically importing an out-of-tree spec
            spec_module_path = extra_config.get("spec_module_path")
            if spec_module_path is None:
                raise ValueError(f"Unsupported spec type: {spec_name}")
            module = importlib.import_module(spec_module_path)
            spec_cls = getattr(module, spec_name)
        assert issubclass(spec_cls, OffloadingSpec)
        logger.info("Creating offloading spec with name: %s", spec_name)
        return spec_cls(config)


# Register various specs here.
OffloadingSpecFactory.register_spec(
    "CPUOffloadingSpec", "vllm.v1.kv_offload.cpu", "CPUOffloadingSpec"
)
|
||||
139
vllm/v1/kv_offload/lru_manager.py
Normal file
139
vllm/v1/kv_offload/lru_manager.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Iterable
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.kv_offload.abstract import (
|
||||
LoadStoreSpec,
|
||||
OffloadingEvent,
|
||||
OffloadingManager,
|
||||
PrepareStoreOutput,
|
||||
)
|
||||
from vllm.v1.kv_offload.backend import Backend, BlockStatus
|
||||
|
||||
|
||||
class LRUOffloadingManager(OffloadingManager):
    """
    An OffloadingManager with a pluggable backend, which evicts blocks by LRU.
    """

    def __init__(self, backend: Backend, enable_events: bool = False):
        self.backend: Backend = backend
        # block_hash -> BlockStatus, ordered LRU (front) -> MRU (back)
        self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict()
        self.events: list[OffloadingEvent] | None = [] if enable_events else None

    def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
        """Count the prefix of block_hashes that are offloaded and ready."""
        hit_count = 0
        for block_hash in block_hashes:
            block = self.blocks.get(block_hash)
            if block is None or not block.is_ready:
                break
            hit_count += 1
        return hit_count

    def prepare_load(self, block_hashes: Iterable[BlockHash]) -> LoadStoreSpec:
        """Pin the given (offloaded, ready) blocks and return a load spec."""
        # materialize first: a generator would be exhausted by the loop below
        # before being passed to the backend
        block_hashes = list(block_hashes)
        blocks = []
        for block_hash in block_hashes:
            block = self.blocks[block_hash]
            assert block.is_ready
            # pin against eviction until complete_load
            block.ref_cnt += 1
            blocks.append(block)

        return self.backend.get_load_store_spec(block_hashes, blocks)

    def touch(self, block_hashes: Iterable[BlockHash]):
        """Move the given blocks to the MRU end of the LRU order."""
        # iterate in reverse so the first hash ends up most recently used
        for block_hash in reversed(list(block_hashes)):
            if self.blocks.get(block_hash) is not None:
                self.blocks.move_to_end(block_hash)

    def complete_load(self, block_hashes: Iterable[BlockHash]):
        """Unpin blocks previously pinned by prepare_load."""
        for block_hash in block_hashes:
            block = self.blocks[block_hash]
            assert block.ref_cnt > 0
            block.ref_cnt -= 1

    def prepare_store(
        self, block_hashes: Iterable[BlockHash]
    ) -> PrepareStoreOutput | None:
        """Allocate space for the given blocks, evicting by LRU if needed."""
        # filter out blocks that are already stored
        block_hashes_to_store = [
            block_hash for block_hash in block_hashes if block_hash not in self.blocks
        ]

        num_blocks_to_evict = (
            len(block_hashes_to_store) - self.backend.get_num_free_blocks()
        )

        # build list of blocks to evict, scanning from LRU to MRU and
        # skipping blocks pinned by in-flight loads
        to_evict = []
        if num_blocks_to_evict > 0:
            for block_hash, block in self.blocks.items():
                if block.ref_cnt == 0:
                    to_evict.append(block_hash)
                    num_blocks_to_evict -= 1
                    if num_blocks_to_evict == 0:
                        break
            else:
                # we could not evict enough blocks
                return None

        # evict blocks
        for block_hash in to_evict:
            self.backend.free(self.blocks.pop(block_hash))

        if to_evict and self.events is not None:
            self.events.append(
                OffloadingEvent(
                    block_hashes=to_evict,
                    block_size=self.backend.block_size,
                    medium=self.backend.medium,
                    removed=True,
                )
            )

        blocks = self.backend.allocate_blocks(block_hashes_to_store)
        assert len(blocks) == len(block_hashes_to_store)

        # register new blocks at the MRU end; they start out not-ready
        for block_hash, block in zip(block_hashes_to_store, blocks):
            self.blocks[block_hash] = block

        # build store specs for allocated blocks
        store_spec = self.backend.get_load_store_spec(block_hashes_to_store, blocks)

        return PrepareStoreOutput(
            block_hashes_to_store=block_hashes_to_store,
            store_spec=store_spec,
            block_hashes_evicted=to_evict,
        )

    def complete_store(self, block_hashes: Iterable[BlockHash], success: bool = True):
        """Mark prepared blocks as stored, or roll them back on failure.

        Assumes every given hash was previously registered by prepare_store
        (raises KeyError otherwise).
        """
        stored_block_hashes: list[BlockHash] = []
        if success:
            for block_hash in block_hashes:
                block = self.blocks[block_hash]
                if not block.is_ready:
                    # transition "being written" (-1) -> "readable" (0)
                    block.ref_cnt = 0
                    stored_block_hashes.append(block_hash)
        else:
            for block_hash in block_hashes:
                block = self.blocks[block_hash]
                if not block.is_ready:
                    self.backend.free(block)
                    del self.blocks[block_hash]

        if stored_block_hashes and self.events is not None:
            self.events.append(
                OffloadingEvent(
                    block_hashes=stored_block_hashes,
                    block_size=self.backend.block_size,
                    medium=self.backend.medium,
                    removed=False,
                )
            )

    def take_events(self) -> Iterable[OffloadingEvent]:
        """Yield and clear events collected since the last call."""
        if self.events is not None:
            yield from self.events
            self.events.clear()
|
||||
39
vllm/v1/kv_offload/mediums.py
Normal file
39
vllm/v1/kv_offload/mediums.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm.v1.kv_offload.abstract import LoadStoreSpec
|
||||
|
||||
|
||||
class BlockIDsLoadStoreSpec(LoadStoreSpec, ABC):
    """
    Spec for loading/storing KV blocks from given block numbers.
    """

    def __init__(self, block_ids: list[int]):
        # store as an int64 numpy array for compact, uniform handling
        self.block_ids = np.array(block_ids, dtype=np.int64)

    def __repr__(self) -> str:
        return repr(self.block_ids)
|
||||
|
||||
|
||||
class GPULoadStoreSpec(BlockIDsLoadStoreSpec):
    """
    Spec for loading/storing a KV block to GPU memory.
    """

    @staticmethod
    def medium() -> str:
        return "GPU"
|
||||
|
||||
|
||||
class CPULoadStoreSpec(BlockIDsLoadStoreSpec):
    """
    Spec for loading/storing a KV block to CPU memory.

    Block IDs refer to blocks of the CPU (offloaded) KV cache.
    """

    @staticmethod
    def medium() -> str:
        # Medium tag used as one half of the (src, dst) transfer-type key.
        return "CPU"
||||
66
vllm/v1/kv_offload/spec.py
Normal file
66
vllm/v1/kv_offload/spec.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
|
||||
from vllm.v1.kv_offload.worker.worker import OffloadingHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OffloadingSpec(ABC):
    """Spec for an offloading connector.

    Bundles the scheduler-side manager and the worker-side transfer
    handlers that together implement one KV offloading scheme.
    """

    def __init__(self, vllm_config: "VllmConfig"):
        logger.warning(
            "Initializing OffloadingSpec. This API is experimental and "
            "subject to change in the future as we iterate the design."
        )
        self.vllm_config = vllm_config

        kv_transfer_config = vllm_config.kv_transfer_config
        assert kv_transfer_config is not None
        self.extra_config = kv_transfer_config.kv_connector_extra_config

        self.gpu_block_size = vllm_config.cache_config.block_size
        # The offloaded block size defaults to the GPU block size when the
        # connector config does not override it.
        configured_block_size = self.extra_config.get(
            "block_size", self.gpu_block_size
        )
        self.offloaded_block_size = int(configured_block_size)

        # Offloaded blocks must be whole multiples of GPU blocks.
        assert self.offloaded_block_size % self.gpu_block_size == 0

    @abstractmethod
    def get_manager(self) -> OffloadingManager:
        """
        Get an OffloadingManager that will be used
        by the scheduler-side offloading connector to track
        offloaded blocks and manage evictions.
        """

    @abstractmethod
    def get_handlers(
        self,
        kv_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec], OffloadingHandler]]:
        """
        Get offloading handlers along with their respective src and dst types.

        Args:
            kv_caches: A dictionary of layer_name -> gpu_kv_cache tensor.
            attn_backends: A dictionary of layer_name -> AttentionBackend.

        Yields:
            Tuples of (src_type, dst_type, offloading_handler).
        """
||||
0
vllm/v1/kv_offload/worker/__init__.py
Normal file
0
vllm/v1/kv_offload/worker/__init__.py
Normal file
280
vllm/v1/kv_offload/worker/cpu_gpu.py
Normal file
280
vllm/v1/kv_offload/worker/cpu_gpu.py
Normal file
@@ -0,0 +1,280 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.kv_offload.mediums import BlockIDsLoadStoreSpec
|
||||
from vllm.v1.kv_offload.worker.worker import (
|
||||
OffloadingHandler,
|
||||
TransferResult,
|
||||
TransferSpec,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def expand_block_ids(
    block_ids: np.ndarray,
    block_size_factor: int,
    output: np.ndarray,
    skip_count: int = 0,
):
    """
    Convert a list of block IDs to a list of matching sub-block IDs,
    assuming each block is composed of block_size_factor actual blocks.
    Writes the result into the output array.
    The first skip_count sub-blocks (of the first block) are skipped.
    Note that skip_count must be less than block_size_factor.

    For example, if block_ids = [0, 1, 3] and block_size_factor = 4,
    then it outputs [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]
    since 0 maps to [0, 1, 2, 3]
          1 maps to [4, 5, 6, 7]
      and 3 maps to [12, 13, 14, 15]

    Args:
        block_ids: 1D array of block IDs to expand.
        block_size_factor: number of sub-blocks per block.
        output: pre-allocated 1D array (or strided view, e.g. a matrix
            column) receiving block_ids.size * block_size_factor - skip_count
            sub-block IDs.
        skip_count: number of leading sub-blocks to omit.
    """
    assert skip_count < block_size_factor

    # Accept any array-like for robustness; callers pass int64 ndarrays.
    block_ids = np.asarray(block_ids)

    # Vectorized expansion: block b contributes the contiguous range
    # [b * factor, (b + 1) * factor). Broadcasting builds all ranges in one
    # shot (shape (n, factor)), replacing the original per-block Python loop.
    sub_ids = (
        block_ids[:, None] * block_size_factor + np.arange(block_size_factor)
    ).reshape(-1)

    # Drop the skipped prefix of the first block; like the original loop,
    # only the first (n * factor - skip_count) output slots are written.
    expanded = sub_ids[skip_count:]
    output[: expanded.size] = expanded
|
||||
class SingleDirectionOffloadingHandler(OffloadingHandler):
    """
    SingleDirectionOffloadingHandler handles transfers for a single direction,
    either CPU->GPU or GPU->CPU.
    Transfers are guaranteed to be executed in order of their submission.
    Each transfer uses a unique CUDA stream, and its stream will start
    executing only after the streams of previous transfers have finished.
    """

    def __init__(
        self,
        src_tensors: list[torch.Tensor],
        dst_tensors: list[torch.Tensor],
        kv_dim_before_num_blocks: list[bool],
        src_block_size_factor: int,
        dst_block_size_factor: int,
        priority: int,
    ):
        """
        Initialize a SingleDirectionOffloadingHandler.

        Args:
            src_tensors: list of KV cache tensors to copy from.
            dst_tensors: list of KV cache tensors to copy to.
                Order should match src_tensors.
            kv_dim_before_num_blocks: list of bools, indicating
                whether the respective KV cache tensor has a KV
                dimension before its num_blocks dimension.
                e.g. (2, num_blocks, ...)
            src_block_size_factor: The number of kernel blocks
                per KV block in a source tensor.
            dst_block_size_factor: The number of kernel blocks
                per KV block in a destination tensor.
            priority: The priority of the backing CUDA streams.
                Lower numbers indicate higher priority.
        """
        assert len(src_tensors) == len(dst_tensors) == len(kv_dim_before_num_blocks)

        self.src_tensors: list[torch.Tensor] = src_tensors
        self.dst_tensors: list[torch.Tensor] = dst_tensors
        self.kv_dim_before_num_blocks: list[bool] = kv_dim_before_num_blocks
        self.src_block_size_factor: int = src_block_size_factor
        self.dst_block_size_factor: int = dst_block_size_factor
        self.priority = priority

        # queue of in-flight transfers (job_id, stream, event),
        # ordered by submission; completion is detected FIFO in get_finished()
        self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque()
        # list of CUDA streams available for re-use
        self._stream_pool: list[torch.cuda.Stream] = []
        # list of CUDA events available for re-use
        self._event_pool: list[torch.Event] = []

    def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
        """
        Submit one (src, dst) block-IDs transfer on a dedicated CUDA stream.

        Always returns True; errors surface as exceptions to the caller
        (OffloadingWorker catches and logs them).
        """
        src_spec, dst_spec = transfer_spec
        assert isinstance(src_spec, BlockIDsLoadStoreSpec)
        assert isinstance(dst_spec, BlockIDsLoadStoreSpec)

        src_blocks = src_spec.block_ids
        dst_blocks = dst_spec.block_ids
        assert src_blocks.ndim == 1
        assert dst_blocks.ndim == 1

        # Work in "kernel block" (sub-block) units on both sides.
        src_sub_block_count = src_blocks.size * self.src_block_size_factor
        dst_sub_block_count = dst_blocks.size * self.dst_block_size_factor
        # Leading src sub-blocks not covered by the destination (the first
        # src KV block may be only partially transferred).
        # NOTE(review): this formula matches (-dst_sub_block_count %
        # src_block_size_factor) only while one of the two factors is 1,
        # which CpuGpuOffloadingHandlers asserts (gpu_block_size_factor == 1).
        # If both factors could exceed 1, this would need
        # dst_blocks.size * self.dst_block_size_factor — confirm before
        # lifting that assert.
        src_sub_blocks_to_skip = -dst_blocks.size % self.src_block_size_factor

        assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip

        # Build the (src_block, dst_block) index pairs consumed by
        # ops.swap_blocks; columns are filled in place by expand_block_ids.
        src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64)
        expand_block_ids(
            src_blocks,
            self.src_block_size_factor,
            src_to_dst[:, 0],
            skip_count=src_sub_blocks_to_skip,
        )
        expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1])
        src_to_dst_tensor = torch.from_numpy(src_to_dst)

        # Reuse a pooled stream/event when available to avoid re-allocation.
        stream = (
            self._stream_pool.pop()
            if self._stream_pool
            else torch.cuda.Stream(priority=self.priority)
        )
        event = self._event_pool.pop() if self._event_pool else torch.Event()
        if self._transfers:
            _, _, last_event = self._transfers[-1]
            # assure job will start only after the previous one completes
            stream.wait_event(last_event)
        with torch.cuda.stream(stream):
            for src_tensor, dst_tensor, kv_dim in zip(
                self.src_tensors, self.dst_tensors, self.kv_dim_before_num_blocks
            ):
                if kv_dim:
                    # Tensor layout (2, num_blocks, ...): swap the key and
                    # value planes separately.
                    src_key_cache, src_value_cache = src_tensor
                    dst_key_cache, dst_value_cache = dst_tensor
                    ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst_tensor)
                    ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst_tensor)
                else:
                    ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
        # Mark completion of this transfer on its stream; get_finished()
        # polls this event.
        event.record(stream)

        self._transfers.append((job_id, stream, event))

        # success
        return True

    def get_finished(self) -> list[TransferResult]:
        """
        Pop completed transfers (FIFO) and recycle their streams/events.

        Transfers submitted here always report success=True; failures are
        raised at submission time instead.
        """
        results: list[TransferResult] = []
        while self._transfers and self._transfers[0][2].query():
            job_id, stream, event = self._transfers.popleft()
            results.append((job_id, True))
            self._stream_pool.append(stream)
            self._event_pool.append(event)
        return results
|
||||
class CpuGpuOffloadingHandlers:
    """
    Allocates pinned CPU KV-cache tensors mirroring the given GPU caches and
    builds a pair of SingleDirectionOffloadingHandler instances
    (gpu_to_cpu_handler and cpu_to_gpu_handler) over them.
    """

    def __init__(
        self,
        gpu_block_size: int,
        cpu_block_size: int,
        num_cpu_blocks: int,
        gpu_caches: dict[str, torch.Tensor],
        attn_backends: dict[str, type[AttentionBackend]],
    ):
        """
        Args:
            gpu_block_size: number of tokens per GPU KV block.
            cpu_block_size: number of tokens per offloaded CPU KV block;
                must be a multiple of gpu_block_size.
            num_cpu_blocks: number of CPU KV blocks to allocate.
            gpu_caches: layer_name -> GPU KV cache tensor.
            attn_backends: layer_name -> AttentionBackend class, used to
                probe each tensor's layout.
        """
        assert gpu_caches
        assert cpu_block_size % gpu_block_size == 0
        block_size_factor = cpu_block_size // gpu_block_size

        pin_memory = is_pin_memory_available()

        # allocate cpu tensors
        logger.info("Allocating %d CPU tensors...", len(gpu_caches))
        gpu_tensors: list[torch.Tensor] = []
        cpu_tensors: list[torch.Tensor] = []
        kv_dim_before_num_blocks: list[bool] = []
        kernel_block_size: int | None = None
        for layer_name, gpu_tensor in gpu_caches.items():
            gpu_tensors.append(gpu_tensor)

            # Probe the backend with sentinel values (num_blocks=1234,
            # block_size=16) so the corresponding dimensions can be located
            # by value in the returned shape.
            gpu_shape = gpu_tensor.shape
            attn_backend = attn_backends[layer_name]
            test_shape = attn_backend.get_kv_cache_shape(
                num_blocks=1234, block_size=16, num_kv_heads=8, head_size=256
            )

            has_layers_dim = False
            if len(gpu_shape) != len(test_shape):
                # cross-layers tensor
                # shape is (num_blocks, ...)
                assert len(gpu_shape) == len(test_shape) + 1
                num_blocks_idx = 0
                has_layers_dim = True
                kv_dim_before_num_blocks.append(False)

                # prepend a dummy num_layers=80 to test_shape
                test_shape = (80,) + test_shape
            elif test_shape[0] == 1234:
                # shape is (num_blocks, ...)
                num_blocks_idx = 0
                kv_dim_before_num_blocks.append(False)
            else:
                # shape should be (2, num_blocks, ...)
                assert test_shape[0] == 2
                assert test_shape[1] == 1234
                assert gpu_shape[0] == 2

                num_blocks_idx = 1
                kv_dim_before_num_blocks.append(True)

            # Some backends physically reorder dimensions; fall back to the
            # identity order when the backend does not define one.
            try:
                kv_cache_stride_order = attn_backend.get_kv_cache_stride_order(
                    include_num_layers_dimension=has_layers_dim
                )
                assert len(kv_cache_stride_order) == len(gpu_shape)
            except (AttributeError, NotImplementedError):
                kv_cache_stride_order = tuple(range(len(gpu_shape)))

            # permute test_shape according to stride_order
            test_shape = tuple(test_shape[i] for i in kv_cache_stride_order)

            # find block_size (16) dimension index
            block_size_idx = test_shape.index(16)
            if kernel_block_size is not None:
                # All layers must agree on the kernel block size.
                assert kernel_block_size == gpu_shape[block_size_idx]
            else:
                kernel_block_size = gpu_shape[block_size_idx]
                assert gpu_block_size % kernel_block_size == 0

            # CPU tensor mirrors the GPU layout, differing only in the
            # number of (kernel) blocks along the num_blocks dimension.
            cpu_shape = list(gpu_shape)
            cpu_shape[num_blocks_idx] = num_cpu_blocks * block_size_factor

            logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
            cpu_tensors.append(
                torch.zeros(
                    cpu_shape,
                    dtype=gpu_tensor.dtype,
                    device="cpu",
                    pin_memory=pin_memory,
                )
            )

        assert kernel_block_size is not None
        gpu_block_size_factor = gpu_block_size // kernel_block_size
        cpu_block_size_factor = cpu_block_size // kernel_block_size

        # TODO (orozery): adapt swap_blocks to support gpu_block_size_factor
        assert gpu_block_size_factor == 1

        # Stores (GPU->CPU) run at lower stream priority than loads
        # (CPU->GPU); lower number = higher priority.
        self.gpu_to_cpu_handler = SingleDirectionOffloadingHandler(
            src_tensors=gpu_tensors,
            dst_tensors=cpu_tensors,
            kv_dim_before_num_blocks=kv_dim_before_num_blocks,
            src_block_size_factor=gpu_block_size_factor,
            dst_block_size_factor=cpu_block_size_factor,
            priority=1,
        )

        self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler(
            src_tensors=cpu_tensors,
            dst_tensors=gpu_tensors,
            kv_dim_before_num_blocks=kv_dim_before_num_blocks,
            src_block_size_factor=cpu_block_size_factor,
            dst_block_size_factor=gpu_block_size_factor,
            priority=-1,
        )
||||
144
vllm/v1/kv_offload/worker/worker.py
Normal file
144
vllm/v1/kv_offload/worker/worker.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.kv_offload.abstract import LoadStoreSpec
|
||||
|
||||
# A single transfer spec: (src_blocks_spec, dst_blocks_spec).
TransferSpec = tuple[LoadStoreSpec, LoadStoreSpec]
# Transfers are routed to handlers by their (src_medium, dst_medium) pair.
TransferType = tuple[str, str]
# Result of a transfer: (job_id, success).
TransferResult = tuple[int, bool]

logger = init_logger(__name__)
||||
|
||||
|
||||
class OffloadingHandler(ABC):
    """Worker-side interface for asynchronous KV data transfers.

    Implementations kick off async KV data transfer requests and let the
    caller collect completion statuses later.

    Primitives:
        transfer_async() - kicks off a new transfer job.
        get_finished() - returns newly finished (job_id, success) pairs.
    """

    @abstractmethod
    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """Initiate an asynchronous transfer of KV data.

        Args:
            job_id: a unique ID that will be used when notifying back on
                transfer completion.
            spec: the (src, dst) spec of the KV data transfer.

        Returns:
            True if the transfer was submitted successfully.
        """

    @abstractmethod
    def get_finished(self) -> list[TransferResult]:
        """Return (job_id, success) pairs for transfers finished since the
        last call."""
||||
|
||||
class OffloadingWorker:
    """Dispatches asynchronous KV data transfers to registered handlers.

    Runs in the worker. Each OffloadingHandler is registered under one
    (src_medium, dst_medium) transfer type; transfer_async() routes a job
    to the matching handler, and get_finished() aggregates completions
    from every registered handler.

    Primitives:
        register_handler() - registers a handler for a transfer type.
        transfer_async() - kicks off a new transfer job.
        get_finished() - returns newly finished job IDs from all handlers.
    """

    def __init__(self):
        # All distinct handlers (one handler may serve several types).
        self.handlers: set[OffloadingHandler] = set()
        # Routing table: (src_medium, dst_medium) -> handler.
        self.transfer_type_to_handler: dict[TransferType, OffloadingHandler] = {}

    def register_handler(
        self,
        src_cls: type[LoadStoreSpec],
        dst_cls: type[LoadStoreSpec],
        handler: OffloadingHandler,
    ) -> None:
        """Register a handler for (src_cls.medium(), dst_cls.medium()).

        Args:
            src_cls: the source type of transfers handled by this handler.
            dst_cls: the destination type of transfers handled by this handler.
            handler: the handler that will handle transfers.
        """
        transfer_type = (src_cls.medium(), dst_cls.medium())
        # Each transfer type may have exactly one handler.
        assert transfer_type not in self.transfer_type_to_handler
        self.handlers.add(handler)
        self.transfer_type_to_handler[transfer_type] = handler

    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
        """Initiate an asynchronous transfer of KV data.

        Routes the job to the handler registered for the spec's
        (src_medium, dst_medium); exceptions from the handler are caught,
        logged, and reported as failure.

        Args:
            job_id: a unique ID that will be used when notifying back on
                transfer completion.
            spec: the (src, dst) spec of the KV data transfer.

        Returns:
            True if the transfer was submitted successfully.
        """
        src, dst = spec
        transfer_type = (src.medium(), dst.medium())
        handler = self.transfer_type_to_handler.get(transfer_type)
        assert handler is not None

        try:
            success = handler.transfer_async(job_id, spec)
        except Exception as e:
            logger.warning(
                "Exception in %r transfer %d: %r",
                transfer_type,
                job_id,
                e,
                exc_info=True,
            )
            return False

        if success:
            logger.debug("Submitted %r transfer %d: %r", transfer_type, job_id, spec)
        else:
            logger.warning("Failed to submit %r transfer %d", transfer_type, job_id)

        return success

    def get_finished(self) -> list[TransferResult]:
        """Collect transfers finished since the last call, over all handlers.

        Returns:
            A list of (job_id, success) of transfers.
        """
        return [
            result
            for handler in self.handlers
            for result in handler.get_finished()
        ]
|
||||
Reference in New Issue
Block a user