Kernels for efficient KV cache IO (#7313)

This commit is contained in:
Zhiqiang Xie
2025-07-06 22:53:36 -07:00
committed by GitHub
parent 253454de9b
commit 2fc824b84c
7 changed files with 184 additions and 371 deletions

View File

@@ -34,6 +34,7 @@ class HiRadixCache(RadixCache):
hicache_ratio: float,
hicache_size: int,
hicache_write_policy: str,
hicache_io_backend: str,
):
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
if isinstance(self.kv_cache, MHATokenToKVPool):
@@ -56,6 +57,7 @@ class HiRadixCache(RadixCache):
page_size,
load_cache_event=self.load_cache_event,
write_policy=hicache_write_policy,
io_backend=hicache_io_backend,
)
# record the nodes with ongoing write through

View File

@@ -34,10 +34,11 @@ import torch
import torch.distributed as dist
import triton
import triton.language as tl
from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
logger = logging.getLogger(__name__)
@@ -150,13 +151,16 @@ class KVCache(abc.ABC):
) -> None:
raise NotImplementedError()
def get_flat_data(self, indices):
@abc.abstractmethod
def load_from_host_per_layer(
self, host_pool, host_indices, device_indices, layer_id, io_backend
):
raise NotImplementedError()
def transfer(self, indices, flat_data):
raise NotImplementedError()
def transfer_per_layer(self, indices, flat_data, layer_id):
@abc.abstractmethod
def backup_to_host_all_layer(
self, host_pool, host_indices, device_indices, io_backend
):
raise NotImplementedError()
def register_layer_transfer_counter(self, layer_transfer_counter):
@@ -247,7 +251,7 @@ class MHATokenToKVPool(KVCache):
)
for _ in range(self.layer_num)
]
self.token_stride = self.head_num * self.head_dim
self.data_ptrs = torch.tensor(
[x.data_ptr() for x in self.k_buffer + self.v_buffer],
dtype=torch.uint64,
@@ -281,24 +285,24 @@ class MHATokenToKVPool(KVCache):
# layer_num x [seq_len, head_num, head_dim]
# layer_num x [page_num, page_size, head_num, head_dim]
kv_data_ptrs = [
self.get_key_buffer(i).data_ptr()
self._get_key_buffer(i).data_ptr()
for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [
self.get_value_buffer(i).data_ptr()
self._get_value_buffer(i).data_ptr()
for i in range(self.start_layer, self.start_layer + self.layer_num)
]
kv_data_lens = [
self.get_key_buffer(i).nbytes
self._get_key_buffer(i).nbytes
for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [
self.get_value_buffer(i).nbytes
self._get_value_buffer(i).nbytes
for i in range(self.start_layer, self.start_layer + self.layer_num)
]
kv_item_lens = [
self.get_key_buffer(i)[0].nbytes * self.page_size
self._get_key_buffer(i)[0].nbytes * self.page_size
for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [
self.get_value_buffer(i)[0].nbytes * self.page_size
self._get_value_buffer(i)[0].nbytes * self.page_size
for i in range(self.start_layer, self.start_layer + self.layer_num)
]
return kv_data_ptrs, kv_data_lens, kv_item_lens
@@ -341,49 +345,73 @@ class MHATokenToKVPool(KVCache):
self.v_buffer[layer_id][chunk_indices] = v_chunk
torch.cuda.synchronize()
# Todo: different memory layout
def get_flat_data(self, indices):
# prepare a large chunk of contiguous data for efficient transfer
flatten = torch.stack(
[
torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
]
def load_from_host_per_layer(
self,
host_pool,
host_indices,
device_indices,
layer_id,
io_backend,
):
transfer_kv_per_layer(
src_k=host_pool.k_buffer[layer_id],
dst_k=self.k_buffer[layer_id],
src_v=host_pool.v_buffer[layer_id],
dst_v=self.v_buffer[layer_id],
src_indices=host_indices,
dst_indices=device_indices,
io_backend=io_backend,
page_size=self.page_size,
item_size=self.token_stride,
)
return flatten
@debug_timing
def transfer(self, indices, flat_data):
# transfer prepared data from host to device
flat_data = flat_data.to(device=self.device, non_blocking=False)
k_data, v_data = flat_data[0], flat_data[1]
for i in range(self.layer_num):
self.k_buffer[i][indices] = k_data[i]
self.v_buffer[i][indices] = v_data[i]
def transfer_per_layer(self, indices, flat_data, layer_id):
# transfer prepared data from host to device
flat_data = flat_data.to(device=self.device, non_blocking=False)
k_data, v_data = flat_data[0], flat_data[1]
self.k_buffer[layer_id - self.start_layer][indices] = k_data
self.v_buffer[layer_id - self.start_layer][indices] = v_data
def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
def backup_to_host_all_layer(
self, host_pool, host_indices, device_indices, io_backend
):
# todo: specialized all layer kernels for the layer-non-contiguous memory pool
for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
if layer_id - self.start_layer >= len(host_pool.k_buffer):
raise ValueError(
f"Layer ID {layer_id} exceeds the number of layers in host pool."
)
transfer_kv_per_layer(
src_k=self.k_buffer[layer_id],
dst_k=host_pool.k_buffer[layer_id],
src_v=self.v_buffer[layer_id],
dst_v=host_pool.v_buffer[layer_id],
src_indices=device_indices,
dst_indices=host_indices,
io_backend=io_backend,
page_size=self.page_size,
item_size=self.token_stride,
)
def _get_key_buffer(self, layer_id: int):
# for internal use of referencing
if self.store_dtype != self.dtype:
return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
return self.k_buffer[layer_id - self.start_layer]
def get_value_buffer(self, layer_id: int):
def get_key_buffer(self, layer_id: int):
# note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading
# it is supposed to be used only by attention backend not for information purpose
# same applies to get_value_buffer and get_kv_buffer
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
return self._get_key_buffer(layer_id)
def _get_value_buffer(self, layer_id: int):
# for internal use of referencing
if self.store_dtype != self.dtype:
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
return self.v_buffer[layer_id - self.start_layer]
def get_value_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
return self._get_value_buffer(layer_id)
def get_kv_buffer(self, layer_id: int):
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
@@ -761,6 +789,7 @@ class MLATokenToKVPool(KVCache):
for _ in range(layer_num)
]
self.token_stride = kv_lora_rank + qk_rope_head_dim
self.layer_transfer_counter = None
kv_size = self.get_kv_size_bytes()
@@ -846,21 +875,37 @@ class MLATokenToKVPool(KVCache):
self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
)
def get_flat_data(self, indices):
# prepare a large chunk of contiguous data for efficient transfer
return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
def load_from_host_per_layer(
self, host_pool, host_indices, device_indices, layer_id, io_backend
):
transfer_kv_per_layer_mla(
src=host_pool.kv_buffer[layer_id],
dst=self.kv_buffer[layer_id],
src_indices=host_indices,
dst_indices=device_indices,
io_backend=io_backend,
page_size=self.page_size,
item_size=self.token_stride,
)
@debug_timing
def transfer(self, indices, flat_data):
# transfer prepared data from host to device
flat_data = flat_data.to(device=self.device, non_blocking=False)
for i in range(self.layer_num):
self.kv_buffer[i][indices] = flat_data[i]
def transfer_per_layer(self, indices, flat_data, layer_id):
# transfer prepared data from host to device
flat_data = flat_data.to(device=self.device, non_blocking=False)
self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
def backup_to_host_all_layer(
self, host_pool, host_indices, device_indices, io_backend
):
# todo: specialized all layer kernels for the layer-non-contiguous memory pool
for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
if layer_id - self.start_layer >= len(host_pool.kv_buffer):
raise ValueError(
f"Layer ID {layer_id} exceeds the number of layers in host pool."
)
transfer_kv_per_layer_mla(
src=self.kv_buffer[layer_id],
dst=host_pool.kv_buffer[layer_id],
src_indices=device_indices,
dst_indices=host_indices,
io_backend=io_backend,
page_size=self.page_size,
item_size=self.token_stride,
)
def get_cpu_copy(self, indices):
torch.cuda.synchronize()
@@ -1046,14 +1091,19 @@ class DoubleSparseTokenToKVPool(KVCache):
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
self.label_buffer[layer_id - self.start_layer][loc] = cache_label
def get_flat_data(self, indices):
pass
def load_from_host_per_layer(
self, host_pool, host_indices, device_indices, layer_id, io_backend
):
raise NotImplementedError(
"HiCache not supported for DoubleSparseTokenToKVPool."
)
def transfer(self, indices, flat_data):
pass
def transfer_per_layer(self, indices, flat_data, layer_id):
pass
def backup_to_host_all_layer(
self, host_pool, host_indices, device_indices, io_backend
):
raise NotImplementedError(
"HiCache not supported for DoubleSparseTokenToKVPool."
)
@triton.jit

View File

@@ -8,7 +8,6 @@ import psutil
import torch
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
from sglang.srt.utils import debug_timing
logger = logging.getLogger(__name__)
@@ -99,22 +98,6 @@ class HostKVCache(abc.ABC):
def init_kv_buffer(self):
raise NotImplementedError()
@abc.abstractmethod
def transfer(self, indices, flat_data):
raise NotImplementedError()
@abc.abstractmethod
def get_flat_data(self, indices):
raise NotImplementedError()
@abc.abstractmethod
def get_flat_data_by_layer(self, indices, layer_id):
raise NotImplementedError()
@abc.abstractmethod
def assign_flat_data(self, indices, flat_data):
raise NotImplementedError()
@synchronized()
def clear(self):
# Initialize memory states and tracking structures.
@@ -243,58 +226,13 @@ class MHATokenToKVPoolHost(HostKVCache):
pin_memory=self.pin_memory,
)
@debug_timing
def transfer(self, indices, flat_data):
# backup prepared data from device to host
self.kv_buffer[:, :, indices] = flat_data.to(
device=self.device, non_blocking=False
)
@property
def k_buffer(self):
return self.kv_buffer[0]
def get_flat_data(self, indices):
return self.kv_buffer[:, :, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[:, layer_id - self.start_layer, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, :, indices] = flat_data
def write_page_all_layers(self, host_indices, device_indices, device_pool):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
for j in range(self.layer_num):
self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
device_pool.k_buffer[j][d_index : d_index + self.page_size],
non_blocking=True,
)
self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
device_pool.v_buffer[j][d_index : d_index + self.page_size],
non_blocking=True,
)
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
device_pool.k_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
0, layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)
device_pool.v_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
1, layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)
@property
def v_buffer(self):
return self.kv_buffer[1]
class MLATokenToKVPoolHost(HostKVCache):
@@ -337,44 +275,3 @@ class MLATokenToKVPoolHost(HostKVCache):
device=self.device,
pin_memory=self.pin_memory,
)
@debug_timing
def transfer(self, indices, flat_data):
# backup prepared data from device to host
self.kv_buffer[:, indices] = flat_data.to(
device=self.device, non_blocking=False
)
def get_flat_data(self, indices):
return self.kv_buffer[:, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[layer_id - self.start_layer, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, indices] = flat_data
def write_page_all_layers(self, host_indices, device_indices, device_pool):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
for j in range(self.layer_num):
self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
device_pool.kv_buffer[j][d_index : d_index + self.page_size],
non_blocking=True,
)
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
device_pool.kv_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)

View File

@@ -196,11 +196,13 @@ class RadixCache(BasePrefixCache):
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
else:
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.clone()
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(
@@ -226,10 +228,12 @@ class RadixCache(BasePrefixCache):
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
else:
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.clone()
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
page_aligned_token_ids = token_ids[:page_aligned_len]
# Radix Cache takes one ref in memory pool