Interface change for kvcache io to support page first layout (#8318)
This commit is contained in:
@@ -231,16 +231,7 @@ class HiCacheController:
|
||||
self.mem_pool_host = mem_pool_host
|
||||
self.write_policy = write_policy
|
||||
self.page_size = page_size
|
||||
# using kernel for small page KV cache transfer and DMA for large pages
|
||||
if not io_backend:
|
||||
IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
|
||||
self.io_backend = (
|
||||
"direct"
|
||||
if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
|
||||
else "kernel"
|
||||
)
|
||||
else:
|
||||
self.io_backend = io_backend
|
||||
self.io_backend = io_backend
|
||||
|
||||
self.enable_storage = False
|
||||
# todo: move backend initialization to storage backend module
|
||||
@@ -447,11 +438,8 @@ class HiCacheController:
|
||||
host_indices, device_indices = self.move_indices(
|
||||
operation.host_indices, operation.device_indices
|
||||
)
|
||||
self.mem_pool_device.backup_to_host_all_layer(
|
||||
self.mem_pool_host,
|
||||
host_indices,
|
||||
device_indices,
|
||||
self.io_backend,
|
||||
self.mem_pool_host.backup_from_device_all_layer(
|
||||
self.mem_pool_device, host_indices, device_indices, self.io_backend
|
||||
)
|
||||
self.write_stream.synchronize()
|
||||
self.mem_pool_host.complete_io(operation.host_indices)
|
||||
@@ -491,8 +479,8 @@ class HiCacheController:
|
||||
batch_operation.host_indices, batch_operation.device_indices
|
||||
)
|
||||
for i in range(self.mem_pool_host.layer_num):
|
||||
self.mem_pool_device.load_from_host_per_layer(
|
||||
self.mem_pool_host,
|
||||
self.mem_pool_host.load_to_device_per_layer(
|
||||
self.mem_pool_device,
|
||||
host_indices,
|
||||
device_indices,
|
||||
i,
|
||||
|
||||
@@ -588,6 +588,7 @@ class Scheduler(
|
||||
== "fa3" # hot fix for incompatibility
|
||||
else server_args.hicache_io_backend
|
||||
),
|
||||
hicache_mem_layout=server_args.hicache_mem_layout,
|
||||
hicache_storage_backend=server_args.hicache_storage_backend,
|
||||
)
|
||||
self.tp_worker.register_hicache_layer_transfer_counter(
|
||||
|
||||
@@ -35,16 +35,33 @@ class HiRadixCache(RadixCache):
|
||||
hicache_size: int,
|
||||
hicache_write_policy: str,
|
||||
hicache_io_backend: str,
|
||||
hicache_mem_layout: str,
|
||||
hicache_storage_backend: Optional[str] = None,
|
||||
):
|
||||
|
||||
if hicache_io_backend == "direct":
|
||||
if hicache_mem_layout == "page_first":
|
||||
hicache_mem_layout = "layer_first"
|
||||
logger.warning(
|
||||
"Page first layout is not supported with direct IO backend, switching to layer first layout"
|
||||
)
|
||||
|
||||
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
||||
if isinstance(self.kv_cache, MHATokenToKVPool):
|
||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
||||
self.kv_cache, hicache_ratio, hicache_size, page_size
|
||||
self.kv_cache,
|
||||
hicache_ratio,
|
||||
hicache_size,
|
||||
page_size,
|
||||
hicache_mem_layout,
|
||||
)
|
||||
elif isinstance(self.kv_cache, MLATokenToKVPool):
|
||||
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
|
||||
self.kv_cache, hicache_ratio, hicache_size, page_size
|
||||
self.kv_cache,
|
||||
hicache_ratio,
|
||||
hicache_size,
|
||||
page_size,
|
||||
hicache_mem_layout,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
|
||||
|
||||
@@ -31,21 +31,17 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
|
||||
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GB = 1024 * 1024 * 1024
|
||||
_is_cuda = is_cuda()
|
||||
_is_npu = is_npu()
|
||||
if not _is_npu:
|
||||
from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
|
||||
|
||||
|
||||
class ReqToTokenPool:
|
||||
@@ -153,18 +149,6 @@ class KVCache(abc.ABC):
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_from_host_per_layer(
|
||||
self, host_pool, host_indices, device_indices, layer_id, io_backend
|
||||
):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def backup_to_host_all_layer(
|
||||
self, host_pool, host_indices, device_indices, io_backend
|
||||
):
|
||||
raise NotImplementedError()
|
||||
|
||||
def register_layer_transfer_counter(self, layer_transfer_counter):
|
||||
self.layer_transfer_counter = layer_transfer_counter
|
||||
|
||||
@@ -253,12 +237,18 @@ class MHATokenToKVPool(KVCache):
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
self.token_stride = self.head_num * self.head_dim
|
||||
self.data_ptrs = torch.tensor(
|
||||
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
|
||||
|
||||
self.k_data_ptrs = torch.tensor(
|
||||
[x.data_ptr() for x in self.k_buffer],
|
||||
dtype=torch.uint64,
|
||||
device=self.device,
|
||||
)
|
||||
self.v_data_ptrs = torch.tensor(
|
||||
[x.data_ptr() for x in self.v_buffer],
|
||||
dtype=torch.uint64,
|
||||
device=self.device,
|
||||
)
|
||||
self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
|
||||
self.data_strides = torch.tensor(
|
||||
[
|
||||
np.prod(x.shape[1:]) * x.dtype.itemsize
|
||||
@@ -347,47 +337,6 @@ class MHATokenToKVPool(KVCache):
|
||||
self.v_buffer[layer_id][chunk_indices] = v_chunk
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def load_from_host_per_layer(
|
||||
self,
|
||||
host_pool,
|
||||
host_indices,
|
||||
device_indices,
|
||||
layer_id,
|
||||
io_backend,
|
||||
):
|
||||
transfer_kv_per_layer(
|
||||
src_k=host_pool.k_buffer[layer_id],
|
||||
dst_k=self.k_buffer[layer_id],
|
||||
src_v=host_pool.v_buffer[layer_id],
|
||||
dst_v=self.v_buffer[layer_id],
|
||||
src_indices=host_indices,
|
||||
dst_indices=device_indices,
|
||||
io_backend=io_backend,
|
||||
page_size=self.page_size,
|
||||
item_size=self.token_stride,
|
||||
)
|
||||
|
||||
def backup_to_host_all_layer(
|
||||
self, host_pool, host_indices, device_indices, io_backend
|
||||
):
|
||||
# todo: specialized all layer kernels for the layer-non-contiguous memory pool
|
||||
for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
|
||||
if layer_id - self.start_layer >= len(host_pool.k_buffer):
|
||||
raise ValueError(
|
||||
f"Layer ID {layer_id} exceeds the number of layers in host pool."
|
||||
)
|
||||
transfer_kv_per_layer(
|
||||
src_k=self.k_buffer[layer_id],
|
||||
dst_k=host_pool.k_buffer[layer_id],
|
||||
src_v=self.v_buffer[layer_id],
|
||||
dst_v=host_pool.v_buffer[layer_id],
|
||||
src_indices=device_indices,
|
||||
dst_indices=host_indices,
|
||||
io_backend=io_backend,
|
||||
page_size=self.page_size,
|
||||
item_size=self.token_stride,
|
||||
)
|
||||
|
||||
def _get_key_buffer(self, layer_id: int):
|
||||
# for internal use of referencing
|
||||
if self.store_dtype != self.dtype:
|
||||
@@ -602,16 +551,6 @@ class SWAKVPool(KVCache):
|
||||
layer_id_override=layer_id_pool,
|
||||
)
|
||||
|
||||
def load_from_host_per_layer(
|
||||
self, host_pool, host_indices, device_indices, layer_id, io_backend
|
||||
):
|
||||
raise NotImplementedError("HiCache not supported for SWAKVPool.")
|
||||
|
||||
def backup_to_host_all_layer(
|
||||
self, host_pool, host_indices, device_indices, io_backend
|
||||
):
|
||||
raise NotImplementedError("HiCache not supported for SWAKVPool.")
|
||||
|
||||
|
||||
class AscendTokenToKVPool(MHATokenToKVPool):
|
||||
|
||||
@@ -823,7 +762,11 @@ class MLATokenToKVPool(KVCache):
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
self.token_stride = kv_lora_rank + qk_rope_head_dim
|
||||
self.data_ptrs = torch.tensor(
|
||||
[x.data_ptr() for x in self.kv_buffer],
|
||||
dtype=torch.uint64,
|
||||
device=self.device,
|
||||
)
|
||||
self.layer_transfer_counter = None
|
||||
|
||||
kv_size = self.get_kv_size_bytes()
|
||||
@@ -909,38 +852,6 @@ class MLATokenToKVPool(KVCache):
|
||||
self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
|
||||
)
|
||||
|
||||
def load_from_host_per_layer(
|
||||
self, host_pool, host_indices, device_indices, layer_id, io_backend
|
||||
):
|
||||
transfer_kv_per_layer_mla(
|
||||
src=host_pool.kv_buffer[layer_id],
|
||||
dst=self.kv_buffer[layer_id],
|
||||
src_indices=host_indices,
|
||||
dst_indices=device_indices,
|
||||
io_backend=io_backend,
|
||||
page_size=self.page_size,
|
||||
item_size=self.token_stride,
|
||||
)
|
||||
|
||||
def backup_to_host_all_layer(
|
||||
self, host_pool, host_indices, device_indices, io_backend
|
||||
):
|
||||
# todo: specialized all layer kernels for the layer-non-contiguous memory pool
|
||||
for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
|
||||
if layer_id - self.start_layer >= len(host_pool.kv_buffer):
|
||||
raise ValueError(
|
||||
f"Layer ID {layer_id} exceeds the number of layers in host pool."
|
||||
)
|
||||
transfer_kv_per_layer_mla(
|
||||
src=self.kv_buffer[layer_id],
|
||||
dst=host_pool.kv_buffer[layer_id],
|
||||
src_indices=device_indices,
|
||||
dst_indices=host_indices,
|
||||
io_backend=io_backend,
|
||||
page_size=self.page_size,
|
||||
item_size=self.token_stride,
|
||||
)
|
||||
|
||||
def get_cpu_copy(self, indices):
|
||||
torch.cuda.synchronize()
|
||||
kv_cache_cpu = []
|
||||
@@ -1131,20 +1042,6 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
||||
self.label_buffer[layer_id - self.start_layer][loc] = cache_label
|
||||
|
||||
def load_from_host_per_layer(
|
||||
self, host_pool, host_indices, device_indices, layer_id, io_backend
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"HiCache not supported for DoubleSparseTokenToKVPool."
|
||||
)
|
||||
|
||||
def backup_to_host_all_layer(
|
||||
self, host_pool, host_indices, device_indices, io_backend
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"HiCache not supported for DoubleSparseTokenToKVPool."
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def copy_all_layer_kv_cache(
|
||||
|
||||
@@ -8,6 +8,21 @@ import psutil
|
||||
import torch
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
|
||||
from sglang.srt.utils import is_npu
|
||||
|
||||
_is_npu = is_npu()
|
||||
if not _is_npu:
|
||||
from sgl_kernel.kvcacheio import (
|
||||
transfer_kv_all_layer,
|
||||
transfer_kv_all_layer_lf_pf,
|
||||
transfer_kv_all_layer_mla,
|
||||
transfer_kv_all_layer_mla_lf_pf,
|
||||
transfer_kv_direct,
|
||||
transfer_kv_per_layer,
|
||||
transfer_kv_per_layer_mla,
|
||||
transfer_kv_per_layer_mla_pf_lf,
|
||||
transfer_kv_per_layer_pf_lf,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -42,15 +57,18 @@ class HostKVCache(abc.ABC):
|
||||
device_pool: KVCache,
|
||||
host_to_device_ratio: float,
|
||||
host_size: int,
|
||||
page_size: int,
|
||||
layout: str,
|
||||
pin_memory: bool,
|
||||
device: str,
|
||||
page_size: int,
|
||||
):
|
||||
self.device_pool = device_pool
|
||||
self.dtype = device_pool.store_dtype
|
||||
self.page_size = page_size
|
||||
self.layout = layout
|
||||
self.pin_memory = pin_memory
|
||||
self.device = device
|
||||
self.page_size = page_size
|
||||
|
||||
self.dtype = device_pool.store_dtype
|
||||
self.size_per_token = self.get_size_per_token()
|
||||
if host_size > 0:
|
||||
self.size = int(host_size * 1e9 // self.size_per_token)
|
||||
@@ -98,6 +116,24 @@ class HostKVCache(abc.ABC):
|
||||
def init_kv_buffer(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_to_device_per_layer(
|
||||
self, device_pool, host_indices, device_indices, layer_id, io_backend
|
||||
) -> None:
|
||||
"""
|
||||
Load KV data from the host memory pool to the device memory pool for a specific layer.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def backup_from_device_all_layer(
|
||||
self, device_pool, host_indices, device_indices, io_backend
|
||||
) -> None:
|
||||
"""
|
||||
Backup KV data from the device memory pool to the host memory pool for all layers.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_flat_data_page(self, index) -> torch.Tensor:
|
||||
"""
|
||||
@@ -238,11 +274,30 @@ class MHATokenToKVPoolHost(HostKVCache):
|
||||
host_to_device_ratio: float,
|
||||
host_size: int,
|
||||
page_size: int,
|
||||
layout: str,
|
||||
pin_memory: bool = True,
|
||||
device: str = "cpu",
|
||||
):
|
||||
super().__init__(
|
||||
device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
|
||||
device_pool,
|
||||
host_to_device_ratio,
|
||||
host_size,
|
||||
page_size,
|
||||
layout,
|
||||
pin_memory,
|
||||
device,
|
||||
)
|
||||
self.k_data_refs = [self.k_buffer[i] for i in range(self.layer_num)]
|
||||
self.v_data_refs = [self.v_buffer[i] for i in range(self.layer_num)]
|
||||
self.k_data_ptrs = torch.tensor(
|
||||
[x.data_ptr() for x in self.k_data_refs],
|
||||
dtype=torch.uint64,
|
||||
device=self.device_pool.device,
|
||||
)
|
||||
self.v_data_ptrs = torch.tensor(
|
||||
[x.data_ptr() for x in self.v_data_refs],
|
||||
dtype=torch.uint64,
|
||||
device=self.device_pool.device,
|
||||
)
|
||||
|
||||
def get_size_per_token(self):
|
||||
@@ -253,16 +308,128 @@ class MHATokenToKVPoolHost(HostKVCache):
|
||||
return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
|
||||
|
||||
def init_kv_buffer(self):
|
||||
if self.layout == "layer_first":
|
||||
dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
|
||||
elif self.layout == "page_first":
|
||||
dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
|
||||
else:
|
||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||
self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
|
||||
self.layout_dim = self.token_stride_size * self.layer_num
|
||||
return torch.empty(
|
||||
(2, self.layer_num, self.size, self.head_num, self.head_dim),
|
||||
dims,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
# todo, page first memory layout
|
||||
@property
|
||||
def k_buffer(self):
|
||||
return self.kv_buffer[0]
|
||||
|
||||
@property
|
||||
def v_buffer(self):
|
||||
return self.kv_buffer[1]
|
||||
|
||||
def load_to_device_per_layer(
|
||||
self,
|
||||
device_pool,
|
||||
host_indices,
|
||||
device_indices,
|
||||
layer_id,
|
||||
io_backend,
|
||||
):
|
||||
if io_backend == "kernel":
|
||||
if self.layout == "layer_first":
|
||||
transfer_kv_per_layer(
|
||||
src_k=self.k_buffer[layer_id],
|
||||
dst_k=device_pool.k_buffer[layer_id],
|
||||
src_v=self.v_buffer[layer_id],
|
||||
dst_v=device_pool.v_buffer[layer_id],
|
||||
src_indices=host_indices,
|
||||
dst_indices=device_indices,
|
||||
item_size=self.token_stride_size,
|
||||
)
|
||||
elif self.layout == "page_first":
|
||||
transfer_kv_per_layer_pf_lf(
|
||||
src_k=self.k_buffer,
|
||||
dst_k=device_pool.k_buffer[layer_id],
|
||||
src_v=self.v_buffer,
|
||||
dst_v=device_pool.v_buffer[layer_id],
|
||||
src_indices=host_indices,
|
||||
dst_indices=device_indices,
|
||||
item_size=self.token_stride_size,
|
||||
src_layout_dim=self.layout_dim,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||
elif io_backend == "direct":
|
||||
assert (
|
||||
self.layout == "layer_first"
|
||||
), f"Direct IO backend only supports layer_first layout."
|
||||
transfer_kv_direct(
|
||||
src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
|
||||
dst_layers=[
|
||||
device_pool.k_buffer[layer_id],
|
||||
device_pool.v_buffer[layer_id],
|
||||
],
|
||||
src_indices=host_indices,
|
||||
dst_indices=device_indices,
|
||||
page_size=self.page_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
||||
|
||||
def backup_from_device_all_layer(
|
||||
self, device_pool, host_indices, device_indices, io_backend
|
||||
):
|
||||
if io_backend == "kernel":
|
||||
if self.layout == "layer_first":
|
||||
transfer_kv_all_layer(
|
||||
src_k_layers=device_pool.k_data_ptrs,
|
||||
dst_k_layers=self.k_data_ptrs,
|
||||
src_v_layers=device_pool.v_data_ptrs,
|
||||
dst_v_layers=self.v_data_ptrs,
|
||||
src_indices=device_indices,
|
||||
dst_indices=host_indices,
|
||||
item_size=self.token_stride_size,
|
||||
num_layers=self.layer_num,
|
||||
)
|
||||
elif self.layout == "page_first":
|
||||
transfer_kv_all_layer_lf_pf(
|
||||
src_k_layers=device_pool.k_data_ptrs,
|
||||
dst_k=self.k_buffer,
|
||||
src_v_layers=device_pool.v_data_ptrs,
|
||||
dst_v=self.v_buffer,
|
||||
src_indices=device_indices,
|
||||
dst_indices=host_indices,
|
||||
item_size=self.token_stride_size,
|
||||
dst_layout_dim=self.layout_dim,
|
||||
num_layers=self.layer_num,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||
elif io_backend == "direct":
|
||||
assert (
|
||||
self.layout == "layer_first"
|
||||
), f"Direct IO backend only supports layer_first layout."
|
||||
transfer_kv_direct(
|
||||
src_layers=device_pool.k_buffer + device_pool.v_buffer,
|
||||
dst_layers=self.k_data_refs + self.v_data_refs,
|
||||
src_indices=device_indices,
|
||||
dst_indices=host_indices,
|
||||
page_size=self.page_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
||||
|
||||
def get_flat_data_page(self, index) -> torch.Tensor:
|
||||
return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
|
||||
if self.layout == "layer_first":
|
||||
return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
|
||||
elif self.layout == "page_first":
|
||||
return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
|
||||
else:
|
||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||
|
||||
def get_dummy_flat_data_page(self) -> torch.Tensor:
|
||||
return torch.zeros(
|
||||
@@ -273,13 +440,24 @@ class MHATokenToKVPoolHost(HostKVCache):
|
||||
).flatten()
|
||||
|
||||
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
||||
self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
|
||||
2,
|
||||
self.layer_num,
|
||||
self.page_size,
|
||||
self.head_num,
|
||||
self.head_dim,
|
||||
)
|
||||
if self.layout == "layer_first":
|
||||
self.kv_buffer[:, :, index : index + self.page_size, :, :] = (
|
||||
data_page.reshape(
|
||||
2,
|
||||
self.layer_num,
|
||||
self.page_size,
|
||||
self.head_num,
|
||||
self.head_dim,
|
||||
)
|
||||
)
|
||||
elif self.layout == "page_first":
|
||||
self.kv_buffer[:, index : index + self.page_size, :, :, :] = (
|
||||
data_page.reshape(
|
||||
2, self.page_size, self.layer_num, self.head_num, self.head_dim
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||
|
||||
def get_buffer_meta(self, keys, indices):
|
||||
ptr_list = []
|
||||
@@ -318,14 +496,6 @@ class MHATokenToKVPoolHost(HostKVCache):
|
||||
element_size_list = [element_size] * len(key_list)
|
||||
return key_list, ptr_list, element_size_list
|
||||
|
||||
@property
|
||||
def k_buffer(self):
|
||||
return self.kv_buffer[0]
|
||||
|
||||
@property
|
||||
def v_buffer(self):
|
||||
return self.kv_buffer[1]
|
||||
|
||||
|
||||
class MLATokenToKVPoolHost(HostKVCache):
|
||||
device_pool: MLATokenToKVPool
|
||||
@@ -336,11 +506,24 @@ class MLATokenToKVPoolHost(HostKVCache):
|
||||
host_to_device_ratio: float,
|
||||
host_size: int,
|
||||
page_size: int,
|
||||
layout: str,
|
||||
pin_memory: bool = True,
|
||||
device: str = "cpu",
|
||||
):
|
||||
super().__init__(
|
||||
device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
|
||||
device_pool,
|
||||
host_to_device_ratio,
|
||||
host_size,
|
||||
page_size,
|
||||
layout,
|
||||
pin_memory,
|
||||
device,
|
||||
)
|
||||
self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)]
|
||||
self.data_ptrs = torch.tensor(
|
||||
[x.data_ptr() for x in self.data_refs],
|
||||
dtype=torch.uint64,
|
||||
device=self.device_pool.device,
|
||||
)
|
||||
|
||||
def get_size_per_token(self):
|
||||
@@ -356,20 +539,115 @@ class MLATokenToKVPoolHost(HostKVCache):
|
||||
)
|
||||
|
||||
def init_kv_buffer(self):
|
||||
return torch.empty(
|
||||
(
|
||||
if self.layout == "layer_first":
|
||||
dims = (
|
||||
self.layer_num,
|
||||
self.size,
|
||||
1,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
),
|
||||
)
|
||||
elif self.layout == "page_first":
|
||||
dims = (
|
||||
self.size,
|
||||
self.layer_num,
|
||||
1,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||
self.token_stride_size = (
|
||||
self.kv_lora_rank + self.qk_rope_head_dim
|
||||
) * self.dtype.itemsize
|
||||
self.layout_dim = self.token_stride_size * self.layer_num
|
||||
|
||||
return torch.empty(
|
||||
dims,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
)
|
||||
|
||||
def load_to_device_per_layer(
|
||||
self, device_pool, host_indices, device_indices, layer_id, io_backend
|
||||
):
|
||||
if io_backend == "kernel":
|
||||
if self.layout == "layer_first":
|
||||
transfer_kv_per_layer_mla(
|
||||
src=self.kv_buffer[layer_id],
|
||||
dst=device_pool.kv_buffer[layer_id],
|
||||
src_indices=host_indices,
|
||||
dst_indices=device_indices,
|
||||
item_size=self.token_stride_size,
|
||||
)
|
||||
elif self.layout == "page_first":
|
||||
transfer_kv_per_layer_mla_pf_lf(
|
||||
src=self.kv_buffer,
|
||||
dst=device_pool.kv_buffer[layer_id],
|
||||
src_indices=host_indices,
|
||||
dst_indices=device_indices,
|
||||
item_size=self.token_stride_size,
|
||||
src_layout_dim=self.layout_dim,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||
elif io_backend == "direct":
|
||||
assert (
|
||||
self.layout == "layer_first"
|
||||
), f"Direct IO backend only supports layer_first layout."
|
||||
transfer_kv_direct(
|
||||
src_layers=[self.kv_buffer[layer_id]],
|
||||
dst_layers=[device_pool.kv_buffer[layer_id]],
|
||||
src_indices=host_indices,
|
||||
dst_indices=device_indices,
|
||||
page_size=self.page_size,
|
||||
)
|
||||
|
||||
def backup_from_device_all_layer(
|
||||
self, device_pool, host_indices, device_indices, io_backend
|
||||
):
|
||||
if io_backend == "kernel":
|
||||
if self.layout == "layer_first":
|
||||
transfer_kv_all_layer_mla(
|
||||
src_layers=device_pool.data_ptrs,
|
||||
dst_layers=self.data_ptrs,
|
||||
src_indices=device_indices,
|
||||
dst_indices=host_indices,
|
||||
item_size=self.token_stride_size,
|
||||
num_layers=self.layer_num,
|
||||
)
|
||||
elif self.layout == "page_first":
|
||||
transfer_kv_all_layer_mla_lf_pf(
|
||||
src_layers=device_pool.data_ptrs,
|
||||
dst_k=self.kv_buffer,
|
||||
src_indices=device_indices,
|
||||
dst_indices=host_indices,
|
||||
item_size=self.token_stride_size,
|
||||
dst_layout_dim=self.layout_dim,
|
||||
num_layers=self.layer_num,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||
elif io_backend == "direct":
|
||||
assert (
|
||||
self.layout == "layer_first"
|
||||
), f"Direct IO backend only supports layer_first layout."
|
||||
transfer_kv_direct(
|
||||
src_layers=device_pool.kv_buffer,
|
||||
dst_layers=self.data_refs,
|
||||
src_indices=device_indices,
|
||||
dst_indices=host_indices,
|
||||
page_size=self.page_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
||||
|
||||
def get_flat_data_page(self, index) -> torch.Tensor:
|
||||
return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
|
||||
if self.layout == "layer_first":
|
||||
return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
|
||||
elif self.layout == "page_first":
|
||||
return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
|
||||
else:
|
||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||
|
||||
def get_dummy_flat_data_page(self) -> torch.Tensor:
|
||||
return torch.zeros(
|
||||
@@ -385,12 +663,22 @@ class MLATokenToKVPoolHost(HostKVCache):
|
||||
).flatten()
|
||||
|
||||
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
||||
self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
|
||||
self.layer_num,
|
||||
self.page_size,
|
||||
1,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
)
|
||||
if self.layout == "layer_first":
|
||||
self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
|
||||
self.layer_num,
|
||||
self.page_size,
|
||||
1,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
)
|
||||
elif self.layout == "page_first":
|
||||
self.kv_buffer[index : index + self.page_size, :, :, :] = data_page.reshape(
|
||||
self.page_size,
|
||||
self.layer_num,
|
||||
1,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||
|
||||
def get_buffer_meta(self, keys, indices):
|
||||
ptr_list = []
|
||||
|
||||
@@ -198,7 +198,8 @@ class ServerArgs:
|
||||
hicache_ratio: float = 2.0
|
||||
hicache_size: int = 0
|
||||
hicache_write_policy: str = "write_through_selective"
|
||||
hicache_io_backend: str = ""
|
||||
hicache_io_backend: str = "kernel"
|
||||
hicache_mem_layout: str = "layer_first"
|
||||
hicache_storage_backend: Optional[str] = None
|
||||
|
||||
# Double Sparsity
|
||||
@@ -1487,6 +1488,14 @@ class ServerArgs:
|
||||
default=ServerArgs.hicache_io_backend,
|
||||
help="The IO backend for KV cache transfer between CPU and GPU",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hicache-mem-layout",
|
||||
type=str,
|
||||
choices=["layer_first", "page_first"],
|
||||
default=ServerArgs.hicache_mem_layout,
|
||||
help="The layout of host memory pool for hierarchical cache.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hicache-storage-backend",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user