Interface change for kvcache io to support page first layout (#8318)
This commit is contained in:
@@ -231,16 +231,7 @@ class HiCacheController:
|
|||||||
self.mem_pool_host = mem_pool_host
|
self.mem_pool_host = mem_pool_host
|
||||||
self.write_policy = write_policy
|
self.write_policy = write_policy
|
||||||
self.page_size = page_size
|
self.page_size = page_size
|
||||||
# using kernel for small page KV cache transfer and DMA for large pages
|
self.io_backend = io_backend
|
||||||
if not io_backend:
|
|
||||||
IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
|
|
||||||
self.io_backend = (
|
|
||||||
"direct"
|
|
||||||
if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
|
|
||||||
else "kernel"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.io_backend = io_backend
|
|
||||||
|
|
||||||
self.enable_storage = False
|
self.enable_storage = False
|
||||||
# todo: move backend initialization to storage backend module
|
# todo: move backend initialization to storage backend module
|
||||||
@@ -447,11 +438,8 @@ class HiCacheController:
|
|||||||
host_indices, device_indices = self.move_indices(
|
host_indices, device_indices = self.move_indices(
|
||||||
operation.host_indices, operation.device_indices
|
operation.host_indices, operation.device_indices
|
||||||
)
|
)
|
||||||
self.mem_pool_device.backup_to_host_all_layer(
|
self.mem_pool_host.backup_from_device_all_layer(
|
||||||
self.mem_pool_host,
|
self.mem_pool_device, host_indices, device_indices, self.io_backend
|
||||||
host_indices,
|
|
||||||
device_indices,
|
|
||||||
self.io_backend,
|
|
||||||
)
|
)
|
||||||
self.write_stream.synchronize()
|
self.write_stream.synchronize()
|
||||||
self.mem_pool_host.complete_io(operation.host_indices)
|
self.mem_pool_host.complete_io(operation.host_indices)
|
||||||
@@ -491,8 +479,8 @@ class HiCacheController:
|
|||||||
batch_operation.host_indices, batch_operation.device_indices
|
batch_operation.host_indices, batch_operation.device_indices
|
||||||
)
|
)
|
||||||
for i in range(self.mem_pool_host.layer_num):
|
for i in range(self.mem_pool_host.layer_num):
|
||||||
self.mem_pool_device.load_from_host_per_layer(
|
self.mem_pool_host.load_to_device_per_layer(
|
||||||
self.mem_pool_host,
|
self.mem_pool_device,
|
||||||
host_indices,
|
host_indices,
|
||||||
device_indices,
|
device_indices,
|
||||||
i,
|
i,
|
||||||
|
|||||||
@@ -588,6 +588,7 @@ class Scheduler(
|
|||||||
== "fa3" # hot fix for incompatibility
|
== "fa3" # hot fix for incompatibility
|
||||||
else server_args.hicache_io_backend
|
else server_args.hicache_io_backend
|
||||||
),
|
),
|
||||||
|
hicache_mem_layout=server_args.hicache_mem_layout,
|
||||||
hicache_storage_backend=server_args.hicache_storage_backend,
|
hicache_storage_backend=server_args.hicache_storage_backend,
|
||||||
)
|
)
|
||||||
self.tp_worker.register_hicache_layer_transfer_counter(
|
self.tp_worker.register_hicache_layer_transfer_counter(
|
||||||
|
|||||||
@@ -35,16 +35,33 @@ class HiRadixCache(RadixCache):
|
|||||||
hicache_size: int,
|
hicache_size: int,
|
||||||
hicache_write_policy: str,
|
hicache_write_policy: str,
|
||||||
hicache_io_backend: str,
|
hicache_io_backend: str,
|
||||||
|
hicache_mem_layout: str,
|
||||||
hicache_storage_backend: Optional[str] = None,
|
hicache_storage_backend: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
if hicache_io_backend == "direct":
|
||||||
|
if hicache_mem_layout == "page_first":
|
||||||
|
hicache_mem_layout = "layer_first"
|
||||||
|
logger.warning(
|
||||||
|
"Page first layout is not supported with direct IO backend, switching to layer first layout"
|
||||||
|
)
|
||||||
|
|
||||||
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
||||||
if isinstance(self.kv_cache, MHATokenToKVPool):
|
if isinstance(self.kv_cache, MHATokenToKVPool):
|
||||||
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
||||||
self.kv_cache, hicache_ratio, hicache_size, page_size
|
self.kv_cache,
|
||||||
|
hicache_ratio,
|
||||||
|
hicache_size,
|
||||||
|
page_size,
|
||||||
|
hicache_mem_layout,
|
||||||
)
|
)
|
||||||
elif isinstance(self.kv_cache, MLATokenToKVPool):
|
elif isinstance(self.kv_cache, MLATokenToKVPool):
|
||||||
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
|
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
|
||||||
self.kv_cache, hicache_ratio, hicache_size, page_size
|
self.kv_cache,
|
||||||
|
hicache_ratio,
|
||||||
|
hicache_size,
|
||||||
|
page_size,
|
||||||
|
hicache_mem_layout,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
|
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
|
||||||
|
|||||||
@@ -31,21 +31,17 @@ from typing import Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
|
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
GB = 1024 * 1024 * 1024
|
GB = 1024 * 1024 * 1024
|
||||||
_is_cuda = is_cuda()
|
_is_cuda = is_cuda()
|
||||||
_is_npu = is_npu()
|
|
||||||
if not _is_npu:
|
|
||||||
from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
|
|
||||||
|
|
||||||
|
|
||||||
class ReqToTokenPool:
|
class ReqToTokenPool:
|
||||||
@@ -153,18 +149,6 @@ class KVCache(abc.ABC):
|
|||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def load_from_host_per_layer(
|
|
||||||
self, host_pool, host_indices, device_indices, layer_id, io_backend
|
|
||||||
):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def backup_to_host_all_layer(
|
|
||||||
self, host_pool, host_indices, device_indices, io_backend
|
|
||||||
):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
def register_layer_transfer_counter(self, layer_transfer_counter):
|
def register_layer_transfer_counter(self, layer_transfer_counter):
|
||||||
self.layer_transfer_counter = layer_transfer_counter
|
self.layer_transfer_counter = layer_transfer_counter
|
||||||
|
|
||||||
@@ -253,12 +237,18 @@ class MHATokenToKVPool(KVCache):
|
|||||||
)
|
)
|
||||||
for _ in range(self.layer_num)
|
for _ in range(self.layer_num)
|
||||||
]
|
]
|
||||||
self.token_stride = self.head_num * self.head_dim
|
|
||||||
self.data_ptrs = torch.tensor(
|
self.k_data_ptrs = torch.tensor(
|
||||||
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
|
[x.data_ptr() for x in self.k_buffer],
|
||||||
dtype=torch.uint64,
|
dtype=torch.uint64,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
self.v_data_ptrs = torch.tensor(
|
||||||
|
[x.data_ptr() for x in self.v_buffer],
|
||||||
|
dtype=torch.uint64,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
|
||||||
self.data_strides = torch.tensor(
|
self.data_strides = torch.tensor(
|
||||||
[
|
[
|
||||||
np.prod(x.shape[1:]) * x.dtype.itemsize
|
np.prod(x.shape[1:]) * x.dtype.itemsize
|
||||||
@@ -347,47 +337,6 @@ class MHATokenToKVPool(KVCache):
|
|||||||
self.v_buffer[layer_id][chunk_indices] = v_chunk
|
self.v_buffer[layer_id][chunk_indices] = v_chunk
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
def load_from_host_per_layer(
|
|
||||||
self,
|
|
||||||
host_pool,
|
|
||||||
host_indices,
|
|
||||||
device_indices,
|
|
||||||
layer_id,
|
|
||||||
io_backend,
|
|
||||||
):
|
|
||||||
transfer_kv_per_layer(
|
|
||||||
src_k=host_pool.k_buffer[layer_id],
|
|
||||||
dst_k=self.k_buffer[layer_id],
|
|
||||||
src_v=host_pool.v_buffer[layer_id],
|
|
||||||
dst_v=self.v_buffer[layer_id],
|
|
||||||
src_indices=host_indices,
|
|
||||||
dst_indices=device_indices,
|
|
||||||
io_backend=io_backend,
|
|
||||||
page_size=self.page_size,
|
|
||||||
item_size=self.token_stride,
|
|
||||||
)
|
|
||||||
|
|
||||||
def backup_to_host_all_layer(
|
|
||||||
self, host_pool, host_indices, device_indices, io_backend
|
|
||||||
):
|
|
||||||
# todo: specialized all layer kernels for the layer-non-contiguous memory pool
|
|
||||||
for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
|
|
||||||
if layer_id - self.start_layer >= len(host_pool.k_buffer):
|
|
||||||
raise ValueError(
|
|
||||||
f"Layer ID {layer_id} exceeds the number of layers in host pool."
|
|
||||||
)
|
|
||||||
transfer_kv_per_layer(
|
|
||||||
src_k=self.k_buffer[layer_id],
|
|
||||||
dst_k=host_pool.k_buffer[layer_id],
|
|
||||||
src_v=self.v_buffer[layer_id],
|
|
||||||
dst_v=host_pool.v_buffer[layer_id],
|
|
||||||
src_indices=device_indices,
|
|
||||||
dst_indices=host_indices,
|
|
||||||
io_backend=io_backend,
|
|
||||||
page_size=self.page_size,
|
|
||||||
item_size=self.token_stride,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_key_buffer(self, layer_id: int):
|
def _get_key_buffer(self, layer_id: int):
|
||||||
# for internal use of referencing
|
# for internal use of referencing
|
||||||
if self.store_dtype != self.dtype:
|
if self.store_dtype != self.dtype:
|
||||||
@@ -602,16 +551,6 @@ class SWAKVPool(KVCache):
|
|||||||
layer_id_override=layer_id_pool,
|
layer_id_override=layer_id_pool,
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_from_host_per_layer(
|
|
||||||
self, host_pool, host_indices, device_indices, layer_id, io_backend
|
|
||||||
):
|
|
||||||
raise NotImplementedError("HiCache not supported for SWAKVPool.")
|
|
||||||
|
|
||||||
def backup_to_host_all_layer(
|
|
||||||
self, host_pool, host_indices, device_indices, io_backend
|
|
||||||
):
|
|
||||||
raise NotImplementedError("HiCache not supported for SWAKVPool.")
|
|
||||||
|
|
||||||
|
|
||||||
class AscendTokenToKVPool(MHATokenToKVPool):
|
class AscendTokenToKVPool(MHATokenToKVPool):
|
||||||
|
|
||||||
@@ -823,7 +762,11 @@ class MLATokenToKVPool(KVCache):
|
|||||||
for _ in range(layer_num)
|
for _ in range(layer_num)
|
||||||
]
|
]
|
||||||
|
|
||||||
self.token_stride = kv_lora_rank + qk_rope_head_dim
|
self.data_ptrs = torch.tensor(
|
||||||
|
[x.data_ptr() for x in self.kv_buffer],
|
||||||
|
dtype=torch.uint64,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
self.layer_transfer_counter = None
|
self.layer_transfer_counter = None
|
||||||
|
|
||||||
kv_size = self.get_kv_size_bytes()
|
kv_size = self.get_kv_size_bytes()
|
||||||
@@ -909,38 +852,6 @@ class MLATokenToKVPool(KVCache):
|
|||||||
self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
|
self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_from_host_per_layer(
|
|
||||||
self, host_pool, host_indices, device_indices, layer_id, io_backend
|
|
||||||
):
|
|
||||||
transfer_kv_per_layer_mla(
|
|
||||||
src=host_pool.kv_buffer[layer_id],
|
|
||||||
dst=self.kv_buffer[layer_id],
|
|
||||||
src_indices=host_indices,
|
|
||||||
dst_indices=device_indices,
|
|
||||||
io_backend=io_backend,
|
|
||||||
page_size=self.page_size,
|
|
||||||
item_size=self.token_stride,
|
|
||||||
)
|
|
||||||
|
|
||||||
def backup_to_host_all_layer(
|
|
||||||
self, host_pool, host_indices, device_indices, io_backend
|
|
||||||
):
|
|
||||||
# todo: specialized all layer kernels for the layer-non-contiguous memory pool
|
|
||||||
for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
|
|
||||||
if layer_id - self.start_layer >= len(host_pool.kv_buffer):
|
|
||||||
raise ValueError(
|
|
||||||
f"Layer ID {layer_id} exceeds the number of layers in host pool."
|
|
||||||
)
|
|
||||||
transfer_kv_per_layer_mla(
|
|
||||||
src=self.kv_buffer[layer_id],
|
|
||||||
dst=host_pool.kv_buffer[layer_id],
|
|
||||||
src_indices=device_indices,
|
|
||||||
dst_indices=host_indices,
|
|
||||||
io_backend=io_backend,
|
|
||||||
page_size=self.page_size,
|
|
||||||
item_size=self.token_stride,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_cpu_copy(self, indices):
|
def get_cpu_copy(self, indices):
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
kv_cache_cpu = []
|
kv_cache_cpu = []
|
||||||
@@ -1131,20 +1042,6 @@ class DoubleSparseTokenToKVPool(KVCache):
|
|||||||
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
||||||
self.label_buffer[layer_id - self.start_layer][loc] = cache_label
|
self.label_buffer[layer_id - self.start_layer][loc] = cache_label
|
||||||
|
|
||||||
def load_from_host_per_layer(
|
|
||||||
self, host_pool, host_indices, device_indices, layer_id, io_backend
|
|
||||||
):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"HiCache not supported for DoubleSparseTokenToKVPool."
|
|
||||||
)
|
|
||||||
|
|
||||||
def backup_to_host_all_layer(
|
|
||||||
self, host_pool, host_indices, device_indices, io_backend
|
|
||||||
):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"HiCache not supported for DoubleSparseTokenToKVPool."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def copy_all_layer_kv_cache(
|
def copy_all_layer_kv_cache(
|
||||||
|
|||||||
@@ -8,6 +8,21 @@ import psutil
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
|
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
|
||||||
|
from sglang.srt.utils import is_npu
|
||||||
|
|
||||||
|
_is_npu = is_npu()
|
||||||
|
if not _is_npu:
|
||||||
|
from sgl_kernel.kvcacheio import (
|
||||||
|
transfer_kv_all_layer,
|
||||||
|
transfer_kv_all_layer_lf_pf,
|
||||||
|
transfer_kv_all_layer_mla,
|
||||||
|
transfer_kv_all_layer_mla_lf_pf,
|
||||||
|
transfer_kv_direct,
|
||||||
|
transfer_kv_per_layer,
|
||||||
|
transfer_kv_per_layer_mla,
|
||||||
|
transfer_kv_per_layer_mla_pf_lf,
|
||||||
|
transfer_kv_per_layer_pf_lf,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -42,15 +57,18 @@ class HostKVCache(abc.ABC):
|
|||||||
device_pool: KVCache,
|
device_pool: KVCache,
|
||||||
host_to_device_ratio: float,
|
host_to_device_ratio: float,
|
||||||
host_size: int,
|
host_size: int,
|
||||||
|
page_size: int,
|
||||||
|
layout: str,
|
||||||
pin_memory: bool,
|
pin_memory: bool,
|
||||||
device: str,
|
device: str,
|
||||||
page_size: int,
|
|
||||||
):
|
):
|
||||||
self.device_pool = device_pool
|
self.device_pool = device_pool
|
||||||
self.dtype = device_pool.store_dtype
|
self.page_size = page_size
|
||||||
|
self.layout = layout
|
||||||
self.pin_memory = pin_memory
|
self.pin_memory = pin_memory
|
||||||
self.device = device
|
self.device = device
|
||||||
self.page_size = page_size
|
|
||||||
|
self.dtype = device_pool.store_dtype
|
||||||
self.size_per_token = self.get_size_per_token()
|
self.size_per_token = self.get_size_per_token()
|
||||||
if host_size > 0:
|
if host_size > 0:
|
||||||
self.size = int(host_size * 1e9 // self.size_per_token)
|
self.size = int(host_size * 1e9 // self.size_per_token)
|
||||||
@@ -98,6 +116,24 @@ class HostKVCache(abc.ABC):
|
|||||||
def init_kv_buffer(self):
|
def init_kv_buffer(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def load_to_device_per_layer(
|
||||||
|
self, device_pool, host_indices, device_indices, layer_id, io_backend
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Load KV data from the host memory pool to the device memory pool for a specific layer.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
def backup_from_device_all_layer(
|
||||||
|
self, device_pool, host_indices, device_indices, io_backend
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Backup KV data from the device memory pool to the host memory pool for all layers.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def get_flat_data_page(self, index) -> torch.Tensor:
|
def get_flat_data_page(self, index) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
@@ -238,11 +274,30 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
host_to_device_ratio: float,
|
host_to_device_ratio: float,
|
||||||
host_size: int,
|
host_size: int,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
|
layout: str,
|
||||||
pin_memory: bool = True,
|
pin_memory: bool = True,
|
||||||
device: str = "cpu",
|
device: str = "cpu",
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
|
device_pool,
|
||||||
|
host_to_device_ratio,
|
||||||
|
host_size,
|
||||||
|
page_size,
|
||||||
|
layout,
|
||||||
|
pin_memory,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
self.k_data_refs = [self.k_buffer[i] for i in range(self.layer_num)]
|
||||||
|
self.v_data_refs = [self.v_buffer[i] for i in range(self.layer_num)]
|
||||||
|
self.k_data_ptrs = torch.tensor(
|
||||||
|
[x.data_ptr() for x in self.k_data_refs],
|
||||||
|
dtype=torch.uint64,
|
||||||
|
device=self.device_pool.device,
|
||||||
|
)
|
||||||
|
self.v_data_ptrs = torch.tensor(
|
||||||
|
[x.data_ptr() for x in self.v_data_refs],
|
||||||
|
dtype=torch.uint64,
|
||||||
|
device=self.device_pool.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_size_per_token(self):
|
def get_size_per_token(self):
|
||||||
@@ -253,16 +308,128 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
|
return self.head_dim * self.head_num * self.layer_num * self.dtype.itemsize * 2
|
||||||
|
|
||||||
def init_kv_buffer(self):
|
def init_kv_buffer(self):
|
||||||
|
if self.layout == "layer_first":
|
||||||
|
dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
|
||||||
|
elif self.layout == "page_first":
|
||||||
|
dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
|
self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
|
||||||
|
self.layout_dim = self.token_stride_size * self.layer_num
|
||||||
return torch.empty(
|
return torch.empty(
|
||||||
(2, self.layer_num, self.size, self.head_num, self.head_dim),
|
dims,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
# todo, page first memory layout
|
@property
|
||||||
|
def k_buffer(self):
|
||||||
|
return self.kv_buffer[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def v_buffer(self):
|
||||||
|
return self.kv_buffer[1]
|
||||||
|
|
||||||
|
def load_to_device_per_layer(
|
||||||
|
self,
|
||||||
|
device_pool,
|
||||||
|
host_indices,
|
||||||
|
device_indices,
|
||||||
|
layer_id,
|
||||||
|
io_backend,
|
||||||
|
):
|
||||||
|
if io_backend == "kernel":
|
||||||
|
if self.layout == "layer_first":
|
||||||
|
transfer_kv_per_layer(
|
||||||
|
src_k=self.k_buffer[layer_id],
|
||||||
|
dst_k=device_pool.k_buffer[layer_id],
|
||||||
|
src_v=self.v_buffer[layer_id],
|
||||||
|
dst_v=device_pool.v_buffer[layer_id],
|
||||||
|
src_indices=host_indices,
|
||||||
|
dst_indices=device_indices,
|
||||||
|
item_size=self.token_stride_size,
|
||||||
|
)
|
||||||
|
elif self.layout == "page_first":
|
||||||
|
transfer_kv_per_layer_pf_lf(
|
||||||
|
src_k=self.k_buffer,
|
||||||
|
dst_k=device_pool.k_buffer[layer_id],
|
||||||
|
src_v=self.v_buffer,
|
||||||
|
dst_v=device_pool.v_buffer[layer_id],
|
||||||
|
src_indices=host_indices,
|
||||||
|
dst_indices=device_indices,
|
||||||
|
item_size=self.token_stride_size,
|
||||||
|
src_layout_dim=self.layout_dim,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
|
elif io_backend == "direct":
|
||||||
|
assert (
|
||||||
|
self.layout == "layer_first"
|
||||||
|
), f"Direct IO backend only supports layer_first layout."
|
||||||
|
transfer_kv_direct(
|
||||||
|
src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
|
||||||
|
dst_layers=[
|
||||||
|
device_pool.k_buffer[layer_id],
|
||||||
|
device_pool.v_buffer[layer_id],
|
||||||
|
],
|
||||||
|
src_indices=host_indices,
|
||||||
|
dst_indices=device_indices,
|
||||||
|
page_size=self.page_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
||||||
|
|
||||||
|
def backup_from_device_all_layer(
|
||||||
|
self, device_pool, host_indices, device_indices, io_backend
|
||||||
|
):
|
||||||
|
if io_backend == "kernel":
|
||||||
|
if self.layout == "layer_first":
|
||||||
|
transfer_kv_all_layer(
|
||||||
|
src_k_layers=device_pool.k_data_ptrs,
|
||||||
|
dst_k_layers=self.k_data_ptrs,
|
||||||
|
src_v_layers=device_pool.v_data_ptrs,
|
||||||
|
dst_v_layers=self.v_data_ptrs,
|
||||||
|
src_indices=device_indices,
|
||||||
|
dst_indices=host_indices,
|
||||||
|
item_size=self.token_stride_size,
|
||||||
|
num_layers=self.layer_num,
|
||||||
|
)
|
||||||
|
elif self.layout == "page_first":
|
||||||
|
transfer_kv_all_layer_lf_pf(
|
||||||
|
src_k_layers=device_pool.k_data_ptrs,
|
||||||
|
dst_k=self.k_buffer,
|
||||||
|
src_v_layers=device_pool.v_data_ptrs,
|
||||||
|
dst_v=self.v_buffer,
|
||||||
|
src_indices=device_indices,
|
||||||
|
dst_indices=host_indices,
|
||||||
|
item_size=self.token_stride_size,
|
||||||
|
dst_layout_dim=self.layout_dim,
|
||||||
|
num_layers=self.layer_num,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
|
elif io_backend == "direct":
|
||||||
|
assert (
|
||||||
|
self.layout == "layer_first"
|
||||||
|
), f"Direct IO backend only supports layer_first layout."
|
||||||
|
transfer_kv_direct(
|
||||||
|
src_layers=device_pool.k_buffer + device_pool.v_buffer,
|
||||||
|
dst_layers=self.k_data_refs + self.v_data_refs,
|
||||||
|
src_indices=device_indices,
|
||||||
|
dst_indices=host_indices,
|
||||||
|
page_size=self.page_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
||||||
|
|
||||||
def get_flat_data_page(self, index) -> torch.Tensor:
|
def get_flat_data_page(self, index) -> torch.Tensor:
|
||||||
return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
|
if self.layout == "layer_first":
|
||||||
|
return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
|
||||||
|
elif self.layout == "page_first":
|
||||||
|
return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
|
|
||||||
def get_dummy_flat_data_page(self) -> torch.Tensor:
|
def get_dummy_flat_data_page(self) -> torch.Tensor:
|
||||||
return torch.zeros(
|
return torch.zeros(
|
||||||
@@ -273,13 +440,24 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
).flatten()
|
).flatten()
|
||||||
|
|
||||||
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
||||||
self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
|
if self.layout == "layer_first":
|
||||||
2,
|
self.kv_buffer[:, :, index : index + self.page_size, :, :] = (
|
||||||
self.layer_num,
|
data_page.reshape(
|
||||||
self.page_size,
|
2,
|
||||||
self.head_num,
|
self.layer_num,
|
||||||
self.head_dim,
|
self.page_size,
|
||||||
)
|
self.head_num,
|
||||||
|
self.head_dim,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif self.layout == "page_first":
|
||||||
|
self.kv_buffer[:, index : index + self.page_size, :, :, :] = (
|
||||||
|
data_page.reshape(
|
||||||
|
2, self.page_size, self.layer_num, self.head_num, self.head_dim
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
|
|
||||||
def get_buffer_meta(self, keys, indices):
|
def get_buffer_meta(self, keys, indices):
|
||||||
ptr_list = []
|
ptr_list = []
|
||||||
@@ -318,14 +496,6 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
element_size_list = [element_size] * len(key_list)
|
element_size_list = [element_size] * len(key_list)
|
||||||
return key_list, ptr_list, element_size_list
|
return key_list, ptr_list, element_size_list
|
||||||
|
|
||||||
@property
|
|
||||||
def k_buffer(self):
|
|
||||||
return self.kv_buffer[0]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def v_buffer(self):
|
|
||||||
return self.kv_buffer[1]
|
|
||||||
|
|
||||||
|
|
||||||
class MLATokenToKVPoolHost(HostKVCache):
|
class MLATokenToKVPoolHost(HostKVCache):
|
||||||
device_pool: MLATokenToKVPool
|
device_pool: MLATokenToKVPool
|
||||||
@@ -336,11 +506,24 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
host_to_device_ratio: float,
|
host_to_device_ratio: float,
|
||||||
host_size: int,
|
host_size: int,
|
||||||
page_size: int,
|
page_size: int,
|
||||||
|
layout: str,
|
||||||
pin_memory: bool = True,
|
pin_memory: bool = True,
|
||||||
device: str = "cpu",
|
device: str = "cpu",
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
|
device_pool,
|
||||||
|
host_to_device_ratio,
|
||||||
|
host_size,
|
||||||
|
page_size,
|
||||||
|
layout,
|
||||||
|
pin_memory,
|
||||||
|
device,
|
||||||
|
)
|
||||||
|
self.data_refs = [self.kv_buffer[i] for i in range(self.layer_num)]
|
||||||
|
self.data_ptrs = torch.tensor(
|
||||||
|
[x.data_ptr() for x in self.data_refs],
|
||||||
|
dtype=torch.uint64,
|
||||||
|
device=self.device_pool.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_size_per_token(self):
|
def get_size_per_token(self):
|
||||||
@@ -356,20 +539,115 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def init_kv_buffer(self):
|
def init_kv_buffer(self):
|
||||||
return torch.empty(
|
if self.layout == "layer_first":
|
||||||
(
|
dims = (
|
||||||
self.layer_num,
|
self.layer_num,
|
||||||
self.size,
|
self.size,
|
||||||
1,
|
1,
|
||||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
),
|
)
|
||||||
|
elif self.layout == "page_first":
|
||||||
|
dims = (
|
||||||
|
self.size,
|
||||||
|
self.layer_num,
|
||||||
|
1,
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
|
self.token_stride_size = (
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim
|
||||||
|
) * self.dtype.itemsize
|
||||||
|
self.layout_dim = self.token_stride_size * self.layer_num
|
||||||
|
|
||||||
|
return torch.empty(
|
||||||
|
dims,
|
||||||
dtype=self.dtype,
|
dtype=self.dtype,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
pin_memory=self.pin_memory,
|
pin_memory=self.pin_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def load_to_device_per_layer(
|
||||||
|
self, device_pool, host_indices, device_indices, layer_id, io_backend
|
||||||
|
):
|
||||||
|
if io_backend == "kernel":
|
||||||
|
if self.layout == "layer_first":
|
||||||
|
transfer_kv_per_layer_mla(
|
||||||
|
src=self.kv_buffer[layer_id],
|
||||||
|
dst=device_pool.kv_buffer[layer_id],
|
||||||
|
src_indices=host_indices,
|
||||||
|
dst_indices=device_indices,
|
||||||
|
item_size=self.token_stride_size,
|
||||||
|
)
|
||||||
|
elif self.layout == "page_first":
|
||||||
|
transfer_kv_per_layer_mla_pf_lf(
|
||||||
|
src=self.kv_buffer,
|
||||||
|
dst=device_pool.kv_buffer[layer_id],
|
||||||
|
src_indices=host_indices,
|
||||||
|
dst_indices=device_indices,
|
||||||
|
item_size=self.token_stride_size,
|
||||||
|
src_layout_dim=self.layout_dim,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
|
elif io_backend == "direct":
|
||||||
|
assert (
|
||||||
|
self.layout == "layer_first"
|
||||||
|
), f"Direct IO backend only supports layer_first layout."
|
||||||
|
transfer_kv_direct(
|
||||||
|
src_layers=[self.kv_buffer[layer_id]],
|
||||||
|
dst_layers=[device_pool.kv_buffer[layer_id]],
|
||||||
|
src_indices=host_indices,
|
||||||
|
dst_indices=device_indices,
|
||||||
|
page_size=self.page_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def backup_from_device_all_layer(
|
||||||
|
self, device_pool, host_indices, device_indices, io_backend
|
||||||
|
):
|
||||||
|
if io_backend == "kernel":
|
||||||
|
if self.layout == "layer_first":
|
||||||
|
transfer_kv_all_layer_mla(
|
||||||
|
src_layers=device_pool.data_ptrs,
|
||||||
|
dst_layers=self.data_ptrs,
|
||||||
|
src_indices=device_indices,
|
||||||
|
dst_indices=host_indices,
|
||||||
|
item_size=self.token_stride_size,
|
||||||
|
num_layers=self.layer_num,
|
||||||
|
)
|
||||||
|
elif self.layout == "page_first":
|
||||||
|
transfer_kv_all_layer_mla_lf_pf(
|
||||||
|
src_layers=device_pool.data_ptrs,
|
||||||
|
dst_k=self.kv_buffer,
|
||||||
|
src_indices=device_indices,
|
||||||
|
dst_indices=host_indices,
|
||||||
|
item_size=self.token_stride_size,
|
||||||
|
dst_layout_dim=self.layout_dim,
|
||||||
|
num_layers=self.layer_num,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
|
elif io_backend == "direct":
|
||||||
|
assert (
|
||||||
|
self.layout == "layer_first"
|
||||||
|
), f"Direct IO backend only supports layer_first layout."
|
||||||
|
transfer_kv_direct(
|
||||||
|
src_layers=device_pool.kv_buffer,
|
||||||
|
dst_layers=self.data_refs,
|
||||||
|
src_indices=device_indices,
|
||||||
|
dst_indices=host_indices,
|
||||||
|
page_size=self.page_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
||||||
|
|
||||||
def get_flat_data_page(self, index) -> torch.Tensor:
|
def get_flat_data_page(self, index) -> torch.Tensor:
|
||||||
return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
|
if self.layout == "layer_first":
|
||||||
|
return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
|
||||||
|
elif self.layout == "page_first":
|
||||||
|
return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
|
|
||||||
def get_dummy_flat_data_page(self) -> torch.Tensor:
|
def get_dummy_flat_data_page(self) -> torch.Tensor:
|
||||||
return torch.zeros(
|
return torch.zeros(
|
||||||
@@ -385,12 +663,22 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
).flatten()
|
).flatten()
|
||||||
|
|
||||||
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
|
||||||
self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
|
if self.layout == "layer_first":
|
||||||
self.layer_num,
|
self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
|
||||||
self.page_size,
|
self.layer_num,
|
||||||
1,
|
self.page_size,
|
||||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
1,
|
||||||
)
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
|
)
|
||||||
|
elif self.layout == "page_first":
|
||||||
|
self.kv_buffer[index : index + self.page_size, :, :, :] = data_page.reshape(
|
||||||
|
self.page_size,
|
||||||
|
self.layer_num,
|
||||||
|
1,
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
|
|
||||||
def get_buffer_meta(self, keys, indices):
|
def get_buffer_meta(self, keys, indices):
|
||||||
ptr_list = []
|
ptr_list = []
|
||||||
|
|||||||
@@ -198,7 +198,8 @@ class ServerArgs:
|
|||||||
hicache_ratio: float = 2.0
|
hicache_ratio: float = 2.0
|
||||||
hicache_size: int = 0
|
hicache_size: int = 0
|
||||||
hicache_write_policy: str = "write_through_selective"
|
hicache_write_policy: str = "write_through_selective"
|
||||||
hicache_io_backend: str = ""
|
hicache_io_backend: str = "kernel"
|
||||||
|
hicache_mem_layout: str = "layer_first"
|
||||||
hicache_storage_backend: Optional[str] = None
|
hicache_storage_backend: Optional[str] = None
|
||||||
|
|
||||||
# Double Sparsity
|
# Double Sparsity
|
||||||
@@ -1487,6 +1488,14 @@ class ServerArgs:
|
|||||||
default=ServerArgs.hicache_io_backend,
|
default=ServerArgs.hicache_io_backend,
|
||||||
help="The IO backend for KV cache transfer between CPU and GPU",
|
help="The IO backend for KV cache transfer between CPU and GPU",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hicache-mem-layout",
|
||||||
|
type=str,
|
||||||
|
choices=["layer_first", "page_first"],
|
||||||
|
default=ServerArgs.hicache_mem_layout,
|
||||||
|
help="The layout of host memory pool for hierarchical cache.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hicache-storage-backend",
|
"--hicache-storage-backend",
|
||||||
type=str,
|
type=str,
|
||||||
|
|||||||
Reference in New Issue
Block a user