739 lines
26 KiB
Python
739 lines
26 KiB
Python
import abc
|
|
import logging
|
|
import threading
|
|
from enum import IntEnum
|
|
from functools import wraps
|
|
|
|
import psutil
|
|
import torch
|
|
|
|
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
|
|
from sglang.srt.utils import is_npu
|
|
|
|
_is_npu = is_npu()
|
|
if not _is_npu:
|
|
from sgl_kernel.kvcacheio import (
|
|
transfer_kv_all_layer,
|
|
transfer_kv_all_layer_lf_pf,
|
|
transfer_kv_all_layer_mla,
|
|
transfer_kv_all_layer_mla_lf_pf,
|
|
transfer_kv_direct,
|
|
transfer_kv_per_layer,
|
|
transfer_kv_per_layer_mla,
|
|
transfer_kv_per_layer_mla_pf_lf,
|
|
transfer_kv_per_layer_pf_lf,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MemoryStateInt(IntEnum):
|
|
IDLE = 0
|
|
RESERVED = 1
|
|
PROTECTED = 2
|
|
SYNCED = 3
|
|
BACKUP = 4
|
|
|
|
|
|
def synchronized(debug_only=False):
|
|
def _decorator(func):
|
|
@wraps(func)
|
|
def wrapper(self, *args, **kwargs):
|
|
if (not debug_only) or self.debug:
|
|
with self.lock:
|
|
return func(self, *args, **kwargs)
|
|
else:
|
|
return True
|
|
|
|
return wrapper
|
|
|
|
return _decorator
|
|
|
|
|
|
class HostKVCache(abc.ABC):
|
|
|
|
def __init__(
|
|
self,
|
|
device_pool: KVCache,
|
|
host_to_device_ratio: float,
|
|
host_size: int,
|
|
page_size: int,
|
|
layout: str,
|
|
pin_memory: bool,
|
|
device: str,
|
|
):
|
|
self.device_pool = device_pool
|
|
self.page_size = page_size
|
|
self.layout = layout
|
|
self.pin_memory = pin_memory
|
|
self.device = device
|
|
|
|
self.dtype = device_pool.store_dtype
|
|
self.size_per_token = self.get_size_per_token()
|
|
if host_size > 0:
|
|
self.size = int(host_size * 1e9 // self.size_per_token)
|
|
else:
|
|
self.size = int(device_pool.size * host_to_device_ratio)
|
|
# Align the host memory pool size to the page size
|
|
self.size = self.size - (self.size % self.page_size)
|
|
self.start_layer = device_pool.start_layer
|
|
self.end_layer = device_pool.end_layer
|
|
|
|
assert (
|
|
self.size > device_pool.size
|
|
), "The host memory should be larger than the device memory with the current protocol"
|
|
|
|
# Verify there is enough available host memory.
|
|
host_mem = psutil.virtual_memory()
|
|
requested_bytes = self.size * self.size_per_token
|
|
# preserve at least 10GB for other usage
|
|
ten_gb = 10 * (1024**3)
|
|
available_bytes = host_mem.available - ten_gb
|
|
if requested_bytes > available_bytes:
|
|
raise ValueError(
|
|
f"Not enough host memory available. Requesting "
|
|
f"{requested_bytes / 1e9:.2f} GB but only have "
|
|
f"{available_bytes / 1e9:.2f} GB free. Please reduce the "
|
|
f"size of the hierarchical cache."
|
|
)
|
|
else:
|
|
logger.info(
|
|
f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
|
|
)
|
|
|
|
self.kv_buffer = self.init_kv_buffer()
|
|
|
|
# A lock for synchronized operations on memory allocation and state transitions.
|
|
self.lock = threading.RLock()
|
|
self.debug = logger.isEnabledFor(logging.DEBUG)
|
|
self.clear()
|
|
|
|
@abc.abstractmethod
|
|
def get_size_per_token(self):
|
|
raise NotImplementedError()
|
|
|
|
@abc.abstractmethod
|
|
def init_kv_buffer(self):
|
|
raise NotImplementedError()
|
|
|
|
@abc.abstractmethod
|
|
def load_to_device_per_layer(
|
|
self, device_pool, host_indices, device_indices, layer_id, io_backend
|
|
) -> None:
|
|
"""
|
|
Load KV data from the host memory pool to the device memory pool for a specific layer.
|
|
"""
|
|
raise NotImplementedError()
|
|
|
|
@abc.abstractmethod
|
|
def backup_from_device_all_layer(
|
|
self, device_pool, host_indices, device_indices, io_backend
|
|
) -> None:
|
|
"""
|
|
Backup KV data from the device memory pool to the host memory pool for all layers.
|
|
"""
|
|
raise NotImplementedError()
|
|
|
|
@abc.abstractmethod
|
|
def get_flat_data_page(self, index) -> torch.Tensor:
|
|
"""
|
|
Get a flat data page from the host memory pool.
|
|
"""
|
|
raise NotImplementedError()
|
|
|
|
@abc.abstractmethod
|
|
def get_dummy_flat_data_page(self) -> torch.Tensor:
|
|
"""
|
|
Get a dummy flat data page from the host memory pool.
|
|
This is used for prefetching or initializing empty pages.
|
|
"""
|
|
raise NotImplementedError()
|
|
|
|
@abc.abstractmethod
|
|
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
|
"""
|
|
Set a flat data page to the host memory pool.
|
|
"""
|
|
raise NotImplementedError()
|
|
|
|
@synchronized()
|
|
def clear(self):
|
|
# Initialize memory states and tracking structures.
|
|
self.mem_state = torch.zeros(
|
|
(self.size,), dtype=torch.uint8, device=self.device
|
|
)
|
|
self.free_slots = torch.arange(self.size, dtype=torch.int64)
|
|
|
|
def available_size(self):
|
|
return len(self.free_slots)
|
|
|
|
@synchronized()
|
|
def alloc(self, need_size: int) -> torch.Tensor:
|
|
assert (
|
|
need_size % self.page_size == 0
|
|
), "The requested size should be a multiple of the page size."
|
|
if need_size > self.available_size():
|
|
return None
|
|
|
|
select_index = self.free_slots[:need_size]
|
|
self.free_slots = self.free_slots[need_size:]
|
|
|
|
if self.debug:
|
|
self.mem_state[select_index] = MemoryStateInt.RESERVED
|
|
|
|
return select_index
|
|
|
|
@synchronized()
|
|
def free(self, indices: torch.Tensor) -> int:
|
|
self.free_slots = torch.cat([self.free_slots, indices])
|
|
if self.debug:
|
|
self.mem_state[indices] = MemoryStateInt.IDLE
|
|
return len(indices)
|
|
|
|
@synchronized(debug_only=True)
|
|
def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
|
|
assert len(indices) > 0, "The indices should not be empty"
|
|
states = self.mem_state[indices]
|
|
assert (
|
|
states == states[0]
|
|
).all(), "The memory slots should have the same state {}".format(states)
|
|
return MemoryStateInt(states[0].item())
|
|
|
|
@synchronized(debug_only=True)
|
|
def is_reserved(self, indices: torch.Tensor) -> bool:
|
|
return self.get_state(indices) == MemoryStateInt.RESERVED
|
|
|
|
@synchronized(debug_only=True)
|
|
def is_protected(self, indices: torch.Tensor) -> bool:
|
|
return self.get_state(indices) == MemoryStateInt.PROTECTED
|
|
|
|
@synchronized(debug_only=True)
|
|
def is_synced(self, indices: torch.Tensor) -> bool:
|
|
return self.get_state(indices) == MemoryStateInt.SYNCED
|
|
|
|
@synchronized(debug_only=True)
|
|
def is_backup(self, indices: torch.Tensor) -> bool:
|
|
return self.get_state(indices) == MemoryStateInt.BACKUP
|
|
|
|
@synchronized(debug_only=True)
|
|
def update_backup(self, indices: torch.Tensor):
|
|
if not self.is_synced(indices):
|
|
raise ValueError(
|
|
f"The host memory slots should be in SYNCED state before turning into BACKUP. "
|
|
f"Current state: {self.get_state(indices)}"
|
|
)
|
|
self.mem_state[indices] = MemoryStateInt.BACKUP
|
|
|
|
@synchronized(debug_only=True)
|
|
def update_prefetch(self, indices: torch.Tensor):
|
|
if not self.is_reserved(indices):
|
|
raise ValueError(
|
|
f"The host memory slots should be in RESERVED state before turning into BACKUP. "
|
|
f"Current state: {self.get_state(indices)}"
|
|
)
|
|
self.mem_state[indices] = MemoryStateInt.BACKUP
|
|
|
|
@synchronized(debug_only=True)
|
|
def update_synced(self, indices: torch.Tensor):
|
|
self.mem_state[indices] = MemoryStateInt.SYNCED
|
|
|
|
@synchronized(debug_only=True)
|
|
def protect_write(self, indices: torch.Tensor):
|
|
if not self.is_reserved(indices):
|
|
raise ValueError(
|
|
f"The host memory slots should be RESERVED before write operations. "
|
|
f"Current state: {self.get_state(indices)}"
|
|
)
|
|
self.mem_state[indices] = MemoryStateInt.PROTECTED
|
|
|
|
@synchronized(debug_only=True)
|
|
def protect_load(self, indices: torch.Tensor):
|
|
if not self.is_backup(indices):
|
|
raise ValueError(
|
|
f"The host memory slots should be in BACKUP state before load operations. "
|
|
f"Current state: {self.get_state(indices)}"
|
|
)
|
|
self.mem_state[indices] = MemoryStateInt.PROTECTED
|
|
|
|
@synchronized(debug_only=True)
|
|
def complete_io(self, indices: torch.Tensor):
|
|
if not self.is_protected(indices):
|
|
raise ValueError(
|
|
f"The host memory slots should be PROTECTED during I/O operations. "
|
|
f"Current state: {self.get_state(indices)}"
|
|
)
|
|
self.mem_state[indices] = MemoryStateInt.SYNCED
|
|
|
|
|
|
class MHATokenToKVPoolHost(HostKVCache):
|
|
device_pool: MHATokenToKVPool
|
|
|
|
def __init__(
|
|
self,
|
|
device_pool: MHATokenToKVPool,
|
|
host_to_device_ratio: float,
|
|
host_size: int,
|
|
page_size: int,
|
|
layout: str,
|
|
pin_memory: bool = True,
|
|
device: str = "cpu",
|
|
):
|
|
super().__init__(
|
|
device_pool,
|
|
host_to_device_ratio,
|
|
host_size,
|
|
page_size,
|
|
layout,
|
|
pin_memory,
|
|
device,
|
|
)
|
|
self.k_data_refs = [self.k_buffer[i] for i in range(self.layer_num)]
|
|
self.v_data_refs = [self.v_buffer[i] for i in range(self.layer_num)]
|
|
self.k_data_ptrs = torch.tensor(
|
|
[x.data_ptr() for x in self.k_data_refs],
|
|
dtype=torch.uint64,
|
|
device=self.device_pool.device,
|
|
)
|
|
self.v_data_ptrs = torch.tensor(
|
|
[x.data_ptr() for x in self.v_data_refs],
|
|
dtype=torch.uint64,
|
|
device=self.device_pool.device,
|
|
)
|
|
|
|
def get_size_per_token(self):
|
|
self.head_num = self.device_pool.head_num
|
|
self.head_dim = self.device_pool.head_dim
|
|
self.layer_num = self.device_pool.layer_num
|
|
|
|
return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
|
|
|
|
def get_ksize_per_token(self):
|
|
return self.get_size_per_token() // 2
|
|
|
|
def init_kv_buffer(self):
|
|
if self.layout == "layer_first":
|
|
dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
|
|
elif self.layout == "page_first":
|
|
dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
|
|
else:
|
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
|
self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
|
|
self.layout_dim = self.token_stride_size * self.layer_num
|
|
return torch.empty(
|
|
dims,
|
|
dtype=self.dtype,
|
|
device=self.device,
|
|
pin_memory=self.pin_memory,
|
|
)
|
|
|
|
@property
|
|
def k_buffer(self):
|
|
return self.kv_buffer[0]
|
|
|
|
@property
|
|
def v_buffer(self):
|
|
return self.kv_buffer[1]
|
|
|
|
def load_to_device_per_layer(
|
|
self,
|
|
device_pool,
|
|
host_indices,
|
|
device_indices,
|
|
layer_id,
|
|
io_backend,
|
|
):
|
|
if io_backend == "kernel":
|
|
if self.layout == "layer_first":
|
|
transfer_kv_per_layer(
|
|
src_k=self.k_buffer[layer_id],
|
|
dst_k=device_pool.k_buffer[layer_id],
|
|
src_v=self.v_buffer[layer_id],
|
|
dst_v=device_pool.v_buffer[layer_id],
|
|
src_indices=host_indices,
|
|
dst_indices=device_indices,
|
|
item_size=self.token_stride_size,
|
|
)
|
|
elif self.layout == "page_first":
|
|
transfer_kv_per_layer_pf_lf(
|
|
src_k=self.k_buffer,
|
|
dst_k=device_pool.k_buffer[layer_id],
|
|
src_v=self.v_buffer,
|
|
dst_v=device_pool.v_buffer[layer_id],
|
|
src_indices=host_indices,
|
|
dst_indices=device_indices,
|
|
layer_id=layer_id,
|
|
item_size=self.token_stride_size,
|
|
src_layout_dim=self.layout_dim,
|
|
)
|
|
else:
|
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
|
elif io_backend == "direct":
|
|
assert (
|
|
self.layout == "layer_first"
|
|
), f"Direct IO backend only supports layer_first layout."
|
|
transfer_kv_direct(
|
|
src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
|
|
dst_layers=[
|
|
device_pool.k_buffer[layer_id],
|
|
device_pool.v_buffer[layer_id],
|
|
],
|
|
src_indices=host_indices,
|
|
dst_indices=device_indices,
|
|
page_size=self.page_size,
|
|
)
|
|
else:
|
|
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
|
|
|
def backup_from_device_all_layer(
|
|
self, device_pool, host_indices, device_indices, io_backend
|
|
):
|
|
if io_backend == "kernel":
|
|
if self.layout == "layer_first":
|
|
transfer_kv_all_layer(
|
|
src_k_layers=device_pool.k_data_ptrs,
|
|
dst_k_layers=self.k_data_ptrs,
|
|
src_v_layers=device_pool.v_data_ptrs,
|
|
dst_v_layers=self.v_data_ptrs,
|
|
src_indices=device_indices,
|
|
dst_indices=host_indices,
|
|
item_size=self.token_stride_size,
|
|
num_layers=self.layer_num,
|
|
)
|
|
elif self.layout == "page_first":
|
|
transfer_kv_all_layer_lf_pf(
|
|
src_k_layers=device_pool.k_data_ptrs,
|
|
dst_k=self.k_buffer,
|
|
src_v_layers=device_pool.v_data_ptrs,
|
|
dst_v=self.v_buffer,
|
|
src_indices=device_indices,
|
|
dst_indices=host_indices,
|
|
item_size=self.token_stride_size,
|
|
dst_layout_dim=self.layout_dim,
|
|
num_layers=self.layer_num,
|
|
)
|
|
else:
|
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
|
elif io_backend == "direct":
|
|
assert (
|
|
self.layout == "layer_first"
|
|
), f"Direct IO backend only supports layer_first layout."
|
|
transfer_kv_direct(
|
|
src_layers=device_pool.k_buffer + device_pool.v_buffer,
|
|
dst_layers=self.k_data_refs + self.v_data_refs,
|
|
src_indices=device_indices,
|
|
dst_indices=host_indices,
|
|
page_size=self.page_size,
|
|
)
|
|
else:
|
|
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
|
|
|
def get_flat_data_page(self, index) -> torch.Tensor:
|
|
if self.layout == "layer_first":
|
|
return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
|
|
elif self.layout == "page_first":
|
|
return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
|
|
else:
|
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
|
|
|
def get_dummy_flat_data_page(self) -> torch.Tensor:
|
|
return torch.zeros(
|
|
(2, self.layer_num, self.page_size, self.head_num, self.head_dim),
|
|
dtype=self.dtype,
|
|
device=self.device,
|
|
pin_memory=self.pin_memory,
|
|
).flatten()
|
|
|
|
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
|
if self.layout == "layer_first":
|
|
self.kv_buffer[:, :, index : index + self.page_size, :, :] = (
|
|
data_page.reshape(
|
|
2,
|
|
self.layer_num,
|
|
self.page_size,
|
|
self.head_num,
|
|
self.head_dim,
|
|
)
|
|
)
|
|
elif self.layout == "page_first":
|
|
self.kv_buffer[:, index : index + self.page_size, :, :, :] = (
|
|
data_page.reshape(
|
|
2, self.page_size, self.layer_num, self.head_num, self.head_dim
|
|
)
|
|
)
|
|
else:
|
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
|
|
|
def get_buffer_meta(self, keys, indices):
|
|
ptr_list = []
|
|
key_list = []
|
|
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
|
|
v_offset = (
|
|
self.layer_num
|
|
* self.size
|
|
* self.head_num
|
|
* self.head_dim
|
|
* self.dtype.itemsize
|
|
)
|
|
for index in range(0, len(indices), self.page_size):
|
|
k_ptr = (
|
|
kv_buffer_data_ptr
|
|
+ indices[index]
|
|
* self.layer_num
|
|
* self.head_num
|
|
* self.head_dim
|
|
* self.dtype.itemsize
|
|
)
|
|
v_ptr = k_ptr + v_offset
|
|
ptr_list.append(k_ptr)
|
|
ptr_list.append(v_ptr)
|
|
key_ = keys[index // self.page_size]
|
|
key_list.append(f"{key_}_k")
|
|
key_list.append(f"{key_}_v")
|
|
element_size = (
|
|
self.layer_num
|
|
* self.dtype.itemsize
|
|
* self.page_size
|
|
* self.head_num
|
|
* self.head_dim
|
|
)
|
|
element_size_list = [element_size] * len(key_list)
|
|
return key_list, ptr_list, element_size_list
|
|
|
|
def get_buffer_with_hash(self, keys, indices):
|
|
assert self.layout == "page_first"
|
|
assert len(keys) == (len(indices) // self.page_size)
|
|
|
|
key_list = []
|
|
buf_list = []
|
|
|
|
for key, i in zip(keys, range(0, len(indices), self.page_size)):
|
|
key_list.append(f"{key}-k")
|
|
buf_list.append(self.k_buffer[i : i + self.page_size])
|
|
key_list.append(f"{key}-v")
|
|
buf_list.append(self.v_buffer[i : i + self.page_size])
|
|
|
|
return key_list, buf_list
|
|
|
|
|
|
class MLATokenToKVPoolHost(HostKVCache):
|
|
device_pool: MLATokenToKVPool
|
|
|
|
def __init__(
|
|
self,
|
|
device_pool: MLATokenToKVPool,
|
|
host_to_device_ratio: float,
|
|
host_size: int,
|
|
page_size: int,
|
|
layout: str,
|
|
pin_memory: bool = True,
|
|
device: str = "cpu",
|
|
):
|
|
super().__init__(
|
|
device_pool,
|
|
host_to_device_ratio,
|
|
host_size,
|
|
page_size,
|
|
layout,
|
|
pin_memory,
|
|
device,
|
|
)
|
|
self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)]
|
|
self.data_ptrs = torch.tensor(
|
|
[x.data_ptr() for x in self.data_refs],
|
|
dtype=torch.uint64,
|
|
device=self.device_pool.device,
|
|
)
|
|
|
|
def get_size_per_token(self):
|
|
self.kv_lora_rank = self.device_pool.kv_lora_rank
|
|
self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
|
|
self.layer_num = self.device_pool.layer_num
|
|
|
|
return (
|
|
(self.kv_lora_rank + self.qk_rope_head_dim)
|
|
* 1
|
|
* self.dtype.itemsize
|
|
* self.layer_num
|
|
)
|
|
|
|
def get_ksize_per_token(self):
|
|
return self.get_size_per_token()
|
|
|
|
def init_kv_buffer(self):
|
|
if self.layout == "layer_first":
|
|
dims = (
|
|
self.layer_num,
|
|
self.size,
|
|
1,
|
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
|
)
|
|
elif self.layout == "page_first":
|
|
dims = (
|
|
self.size,
|
|
self.layer_num,
|
|
1,
|
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
|
)
|
|
else:
|
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
|
self.token_stride_size = (
|
|
self.kv_lora_rank + self.qk_rope_head_dim
|
|
) * self.dtype.itemsize
|
|
self.layout_dim = self.token_stride_size * self.layer_num
|
|
|
|
return torch.empty(
|
|
dims,
|
|
dtype=self.dtype,
|
|
device=self.device,
|
|
pin_memory=self.pin_memory,
|
|
)
|
|
|
|
def load_to_device_per_layer(
|
|
self, device_pool, host_indices, device_indices, layer_id, io_backend
|
|
):
|
|
if io_backend == "kernel":
|
|
if self.layout == "layer_first":
|
|
transfer_kv_per_layer_mla(
|
|
src=self.kv_buffer[layer_id],
|
|
dst=device_pool.kv_buffer[layer_id],
|
|
src_indices=host_indices,
|
|
dst_indices=device_indices,
|
|
item_size=self.token_stride_size,
|
|
)
|
|
elif self.layout == "page_first":
|
|
transfer_kv_per_layer_mla_pf_lf(
|
|
src=self.kv_buffer,
|
|
dst=device_pool.kv_buffer[layer_id],
|
|
src_indices=host_indices,
|
|
dst_indices=device_indices,
|
|
layer_id=layer_id,
|
|
item_size=self.token_stride_size,
|
|
src_layout_dim=self.layout_dim,
|
|
)
|
|
else:
|
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
|
elif io_backend == "direct":
|
|
assert (
|
|
self.layout == "layer_first"
|
|
), f"Direct IO backend only supports layer_first layout."
|
|
transfer_kv_direct(
|
|
src_layers=[self.kv_buffer[layer_id]],
|
|
dst_layers=[device_pool.kv_buffer[layer_id]],
|
|
src_indices=host_indices,
|
|
dst_indices=device_indices,
|
|
page_size=self.page_size,
|
|
)
|
|
|
|
def backup_from_device_all_layer(
|
|
self, device_pool, host_indices, device_indices, io_backend
|
|
):
|
|
if io_backend == "kernel":
|
|
if self.layout == "layer_first":
|
|
transfer_kv_all_layer_mla(
|
|
src_layers=device_pool.data_ptrs,
|
|
dst_layers=self.data_ptrs,
|
|
src_indices=device_indices,
|
|
dst_indices=host_indices,
|
|
item_size=self.token_stride_size,
|
|
num_layers=self.layer_num,
|
|
)
|
|
elif self.layout == "page_first":
|
|
transfer_kv_all_layer_mla_lf_pf(
|
|
src_layers=device_pool.data_ptrs,
|
|
dst=self.kv_buffer,
|
|
src_indices=device_indices,
|
|
dst_indices=host_indices,
|
|
item_size=self.token_stride_size,
|
|
dst_layout_dim=self.layout_dim,
|
|
num_layers=self.layer_num,
|
|
)
|
|
else:
|
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
|
elif io_backend == "direct":
|
|
assert (
|
|
self.layout == "layer_first"
|
|
), f"Direct IO backend only supports layer_first layout."
|
|
transfer_kv_direct(
|
|
src_layers=device_pool.kv_buffer,
|
|
dst_layers=self.data_refs,
|
|
src_indices=device_indices,
|
|
dst_indices=host_indices,
|
|
page_size=self.page_size,
|
|
)
|
|
else:
|
|
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
|
|
|
def get_flat_data_page(self, index) -> torch.Tensor:
|
|
if self.layout == "layer_first":
|
|
return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
|
|
elif self.layout == "page_first":
|
|
return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
|
|
else:
|
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
|
|
|
def get_dummy_flat_data_page(self) -> torch.Tensor:
|
|
return torch.zeros(
|
|
(
|
|
self.layer_num,
|
|
self.page_size,
|
|
1,
|
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
|
),
|
|
dtype=self.dtype,
|
|
device=self.device,
|
|
pin_memory=self.pin_memory,
|
|
).flatten()
|
|
|
|
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
|
if self.layout == "layer_first":
|
|
self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
|
|
self.layer_num,
|
|
self.page_size,
|
|
1,
|
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
|
)
|
|
elif self.layout == "page_first":
|
|
self.kv_buffer[index : index + self.page_size, :, :, :] = data_page.reshape(
|
|
self.page_size,
|
|
self.layer_num,
|
|
1,
|
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
|
)
|
|
else:
|
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
|
|
|
def get_buffer_meta(self, keys, indices):
|
|
ptr_list = []
|
|
key_list = []
|
|
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
|
|
for index in range(0, len(indices), self.page_size):
|
|
k_ptr = (
|
|
kv_buffer_data_ptr
|
|
+ indices[index]
|
|
* self.layer_num
|
|
* (self.kv_lora_rank + self.qk_rope_head_dim)
|
|
* self.dtype.itemsize
|
|
)
|
|
ptr_list.append(k_ptr)
|
|
key_ = keys[index // self.page_size]
|
|
key_list.append(f"{key_}_k")
|
|
element_size = (
|
|
self.layer_num
|
|
* self.dtype.itemsize
|
|
* self.page_size
|
|
* (self.kv_lora_rank + self.qk_rope_head_dim)
|
|
)
|
|
element_size_list = [element_size] * len(key_list)
|
|
return key_list, ptr_list, element_size_list
|
|
|
|
def get_buffer_with_hash(self, keys, indices):
|
|
assert self.layout == "page_first"
|
|
assert len(keys) == (len(indices) // self.page_size)
|
|
|
|
buf_list = []
|
|
|
|
for i in range(0, len(indices), self.page_size):
|
|
buf_list.append(self.kv_buffer[i : i + self.page_size])
|
|
|
|
return keys, buf_list
|