Kernels for efficient KV cache IO (#7313)
@@ -34,6 +34,7 @@ class HiRadixCache(RadixCache):
         hicache_ratio: float,
         hicache_size: int,
         hicache_write_policy: str,
+        hicache_io_backend: str,
     ):
         self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
         if isinstance(self.kv_cache, MHATokenToKVPool):
@@ -56,6 +57,7 @@ class HiRadixCache(RadixCache):
             page_size,
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
+            io_backend=hicache_io_backend,
         )

        # record the nodes with ongoing write through
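The two hunks above (the HiRadixCache constructor) only thread the new `hicache_io_backend` string through to the host memory pool; the radix cache itself never interprets the value. A minimal sketch of that pass-through wiring, with all class and parameter names illustrative stand-ins rather than SGLang classes:

```python
# Sketch of the constructor plumbing added above; HostCacheSketch and
# HiRadixCacheSketch are stand-ins, not SGLang classes.
class HostCacheSketch:
    def __init__(self, write_policy: str, io_backend: str):
        self.write_policy = write_policy
        self.io_backend = io_backend  # consulted later, once per transfer call


class HiRadixCacheSketch:
    def __init__(self, hicache_write_policy: str, hicache_io_backend: str):
        # the radix cache never branches on the backend string; it simply
        # forwards it to the component that issues the copies
        self.host_cache = HostCacheSketch(
            write_policy=hicache_write_policy,
            io_backend=hicache_io_backend,
        )


cache = HiRadixCacheSketch("write_through", "kernel")
assert cache.host_cache.io_backend == "kernel"
```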
@@ -34,10 +34,11 @@ import torch
 import torch.distributed as dist
 import triton
 import triton.language as tl
+from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla

 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
+from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2

 logger = logging.getLogger(__name__)
@@ -150,13 +151,16 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()

-    def get_flat_data(self, indices):
+    @abc.abstractmethod
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
         raise NotImplementedError()

-    def transfer(self, indices, flat_data):
-        raise NotImplementedError()
-
-    def transfer_per_layer(self, indices, flat_data, layer_id):
+    @abc.abstractmethod
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
         raise NotImplementedError()

     def register_layer_transfer_counter(self, layer_transfer_counter):
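This hunk replaces the flat-data staging methods on the KVCache ABC with two direction-explicit IO hooks. The asymmetry is deliberate: loading is per layer so a layer's KV can arrive while earlier layers are still computing, while backup covers all layers in one call since write-back has no ordering constraint. A self-contained sketch of the new contract (the real class carries many more members):

```python
import abc


class KVCacheSketch(abc.ABC):
    """Sketch of the reworked interface; only the two new hooks are shown."""

    @abc.abstractmethod
    def load_from_host_per_layer(
        self, host_pool, host_indices, device_indices, layer_id, io_backend
    ):
        # host -> device, one layer at a time, so layer i can be fetched
        # while layer i-1 is still being computed
        raise NotImplementedError()

    @abc.abstractmethod
    def backup_to_host_all_layer(
        self, host_pool, host_indices, device_indices, io_backend
    ):
        # device -> host, all layers in one call
        raise NotImplementedError()
```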
@@ -247,7 +251,7 @@ class MHATokenToKVPool(KVCache):
             )
             for _ in range(self.layer_num)
         ]

+        self.token_stride = self.head_num * self.head_dim
         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer + self.v_buffer],
             dtype=torch.uint64,
@@ -281,24 +285,24 @@ class MHATokenToKVPool(KVCache):
-        # layer_num x [seq_len, head_num, head_dim]
+        # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
-            self.get_key_buffer(i).data_ptr()
+            self._get_key_buffer(i).data_ptr()
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i).data_ptr()
+            self._get_value_buffer(i).data_ptr()
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         kv_data_lens = [
-            self.get_key_buffer(i).nbytes
+            self._get_key_buffer(i).nbytes
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i).nbytes
+            self._get_value_buffer(i).nbytes
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         kv_item_lens = [
-            self.get_key_buffer(i)[0].nbytes * self.page_size
+            self._get_key_buffer(i)[0].nbytes * self.page_size
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.get_value_buffer(i)[0].nbytes * self.page_size
+            self._get_value_buffer(i)[0].nbytes * self.page_size
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
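This metadata path switches to the underscored buffer accessors (introduced in the next hunk) because computing pointers and lengths must not block on in-flight layer loads. Each `kv_item_lens` entry is the IO granularity: one page of one layer's K (or V). Illustrative arithmetic under assumed shapes:

```python
import torch

# Assumed per-layer pool shape [size, head_num, head_dim], fp16, page_size 16.
head_num, head_dim, page_size = 8, 128, 16
k0 = torch.empty((head_num, head_dim), dtype=torch.float16)  # one token's K

bytes_per_token = k0.nbytes           # 8 * 128 * 2 = 2048, i.e. buffer[0].nbytes
bytes_per_item = bytes_per_token * page_size
print(bytes_per_item)                 # 32768: one kv_item_lens entry
```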
@@ -341,49 +345,73 @@ class MHATokenToKVPool(KVCache):
             self.v_buffer[layer_id][chunk_indices] = v_chunk
         torch.cuda.synchronize()

-    # Todo: different memory layout
-    def get_flat_data(self, indices):
-        # prepare a large chunk of contiguous data for efficient transfer
-        flatten = torch.stack(
-            [
-                torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
-                torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
-            ]
+    def load_from_host_per_layer(
+        self,
+        host_pool,
+        host_indices,
+        device_indices,
+        layer_id,
+        io_backend,
+    ):
+        transfer_kv_per_layer(
+            src_k=host_pool.k_buffer[layer_id],
+            dst_k=self.k_buffer[layer_id],
+            src_v=host_pool.v_buffer[layer_id],
+            dst_v=self.v_buffer[layer_id],
+            src_indices=host_indices,
+            dst_indices=device_indices,
+            io_backend=io_backend,
+            page_size=self.page_size,
+            item_size=self.token_stride,
         )
-        return flatten

-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        k_data, v_data = flat_data[0], flat_data[1]
-        for i in range(self.layer_num):
-            self.k_buffer[i][indices] = k_data[i]
-            self.v_buffer[i][indices] = v_data[i]
-
-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        k_data, v_data = flat_data[0], flat_data[1]
-        self.k_buffer[layer_id - self.start_layer][indices] = k_data
-        self.v_buffer[layer_id - self.start_layer][indices] = v_data
-
-    def get_key_buffer(self, layer_id: int):
-        if self.layer_transfer_counter is not None:
-            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        # todo: specialized all-layer kernels for the layer-non-contiguous memory pool
+        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+            if layer_id - self.start_layer >= len(host_pool.k_buffer):
+                raise ValueError(
+                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                )
+            transfer_kv_per_layer(
+                src_k=self.k_buffer[layer_id],
+                dst_k=host_pool.k_buffer[layer_id],
+                src_v=self.v_buffer[layer_id],
+                dst_v=host_pool.v_buffer[layer_id],
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                io_backend=io_backend,
+                page_size=self.page_size,
+                item_size=self.token_stride,
+            )

+    def _get_key_buffer(self, layer_id: int):
+        # for internal referencing only
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.k_buffer[layer_id - self.start_layer]

-    def get_value_buffer(self, layer_id: int):
+    def get_key_buffer(self, layer_id: int):
+        # note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading;
+        # it is supposed to be used only by attention backends, not for informational purposes
+        # (the same applies to get_value_buffer and get_kv_buffer)
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

+        return self._get_key_buffer(layer_id)
+
+    def _get_value_buffer(self, layer_id: int):
+        # for internal referencing only
         if self.store_dtype != self.dtype:
             return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.v_buffer[layer_id - self.start_layer]

+    def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return self._get_value_buffer(layer_id)
+
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
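The public `get_key_buffer`/`get_value_buffer` block on `layer_transfer_counter.wait_until(...)` before handing a buffer to the attention backend, which is what makes the layer-wise host-to-device loading above safe to overlap with the forward pass. A minimal sketch of the contract those calls rely on; the real counter is owned by the hierarchical cache machinery, and this stand-in only shows the shape of the API:

```python
import threading


class LayerDoneCounterSketch:
    """Stand-in for the object passed to register_layer_transfer_counter."""

    def __init__(self, num_layers: int):
        self._events = [threading.Event() for _ in range(num_layers)]

    def complete(self, layer: int):
        # called from the IO side once the transfer for `layer` has finished
        self._events[layer].set()

    def wait_until(self, layer: int):
        # called from get_key_buffer/get_value_buffer before attention reads
        self._events[layer].wait()
```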
@@ -761,6 +789,7 @@ class MLATokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]

+        self.token_stride = kv_lora_rank + qk_rope_head_dim
         self.layer_transfer_counter = None

         kv_size = self.get_kv_size_bytes()
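MLA stores a single fused latent vector per token rather than separate K/V heads, so `token_stride` is the latent width plus the rope width. Illustrative arithmetic with DeepSeek-style dimensions (values assumed for the example):

```python
# DeepSeek-style MLA dimensions, used here only for illustration:
kv_lora_rank, qk_rope_head_dim = 512, 64
token_stride = kv_lora_rank + qk_rope_head_dim
print(token_stride)  # 576 elements per token per layer; the item_size below
```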
@@ -846,21 +875,37 @@ class MLATokenToKVPool(KVCache):
             self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
         )

-    def get_flat_data(self, indices):
-        # prepare a large chunk of contiguous data for efficient transfer
-        return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        transfer_kv_per_layer_mla(
+            src=host_pool.kv_buffer[layer_id],
+            dst=self.kv_buffer[layer_id],
+            src_indices=host_indices,
+            dst_indices=device_indices,
+            io_backend=io_backend,
+            page_size=self.page_size,
+            item_size=self.token_stride,
+        )

-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        for i in range(self.layer_num):
-            self.kv_buffer[i][indices] = flat_data[i]
-
-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        # todo: specialized all-layer kernels for the layer-non-contiguous memory pool
+        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+            if layer_id - self.start_layer >= len(host_pool.kv_buffer):
+                raise ValueError(
+                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                )
+            transfer_kv_per_layer_mla(
+                src=self.kv_buffer[layer_id],
+                dst=host_pool.kv_buffer[layer_id],
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                io_backend=io_backend,
+                page_size=self.page_size,
+                item_size=self.token_stride,
+            )

     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
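In both the MHA and MLA pools, the kernel wrappers receive the `io_backend` string unchanged and dispatch on it. A sketch of what that dispatch plausibly looks like; the backend names ("direct" for plain indexed copies, "kernel" for the fused sgl_kernel path) are assumptions, and the real entry points are compiled kernels, not this Python:

```python
import torch


def transfer_rows_sketch(src, dst, src_indices, dst_indices, io_backend):
    # Sketch only: transfer_kv_per_layer* live in sgl_kernel; the backend
    # names here are assumptions, not confirmed by this diff.
    if io_backend == "direct":
        # staged indexed copy through PyTorch; simple, but materializes an
        # intermediate gather result before the H2D/D2H copy
        dst[dst_indices] = src[src_indices].to(dst.device)
    elif io_backend == "kernel":
        raise NotImplementedError("stand-in for the fused gather-scatter kernel")
    else:
        raise ValueError(f"unknown io_backend: {io_backend!r}")
```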
@@ -1046,14 +1091,19 @@ class DoubleSparseTokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
         self.label_buffer[layer_id - self.start_layer][loc] = cache_label

-    def get_flat_data(self, indices):
-        pass
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        raise NotImplementedError(
+            "HiCache not supported for DoubleSparseTokenToKVPool."
+        )

-    def transfer(self, indices, flat_data):
-        pass
-
-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        pass
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        raise NotImplementedError(
+            "HiCache not supported for DoubleSparseTokenToKVPool."
+        )


 @triton.jit
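The `@triton.jit` context line above marks the start of memory_pool.py's remaining kernel section. For intuition, a page-granular transfer kernel of the kind this PR moves into sgl_kernel is essentially a gather-scatter copy: each program looks up one source and one destination page index, then copies a tile of elements between them. A minimal, self-contained Triton sketch of that pattern (not the sgl_kernel implementation; all names are illustrative):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def copy_pages_kernel(
    src_ptr, dst_ptr, src_page_ids, dst_page_ids,
    page_numel, BLOCK: tl.constexpr,
):
    page = tl.program_id(0)   # which (src, dst) page pair to copy
    blk = tl.program_id(1)    # which tile within the page
    src_page = tl.load(src_page_ids + page)
    dst_page = tl.load(dst_page_ids + page)
    offs = blk * BLOCK + tl.arange(0, BLOCK)
    mask = offs < page_numel
    vals = tl.load(src_ptr + src_page * page_numel + offs, mask=mask)
    tl.store(dst_ptr + dst_page * page_numel + offs, vals, mask=mask)


def copy_pages(src, dst, src_pages, dst_pages, page_numel, BLOCK=1024):
    # one program per page per BLOCK-sized tile
    grid = (len(src_pages), triton.cdiv(page_numel, BLOCK))
    copy_pages_kernel[grid](src, dst, src_pages, dst_pages, page_numel, BLOCK=BLOCK)
```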
@@ -8,7 +8,6 @@ import psutil
 import torch

 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import debug_timing

 logger = logging.getLogger(__name__)
@@ -99,22 +98,6 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()

-    @abc.abstractmethod
-    def transfer(self, indices, flat_data):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data(self, indices):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data_by_layer(self, indices, layer_id):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def assign_flat_data(self, indices, flat_data):
-        raise NotImplementedError()
-
     @synchronized()
     def clear(self):
         # Initialize memory states and tracking structures.
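With the flat-data staging API deleted from HostKVCache, host pools no longer implement transfer/get_flat_data/assign_flat_data at all: the IO kernels read and write their buffers directly by index. What remains essential is that the host buffer is pinned, so asynchronous DMA is possible. A sketch of the assumed MHA host layout:

```python
import torch

# Assumed MHA host layout: dim 0 separates K (0) and V (1); pinned memory is
# what lets the per-layer kernels copy asynchronously across PCIe.
layer_num, size, head_num, head_dim = 4, 1024, 8, 128
kv_buffer = torch.empty(
    (2, layer_num, size, head_num, head_dim),
    dtype=torch.float16,
    pin_memory=True,
)
```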
@@ -243,58 +226,13 @@ class MHATokenToKVPoolHost(HostKVCache):
             pin_memory=self.pin_memory,
         )

-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, :, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
+    @property
+    def k_buffer(self):
+        return self.kv_buffer[0]

-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, :, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, :, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.k_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
-            device_pool.v_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
+    @property
+    def v_buffer(self):
+        return self.kv_buffer[1]


 class MLATokenToKVPoolHost(HostKVCache):
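The hand-rolled page-copy loops above are deleted in favor of the kernel path, and the host pool now exposes `k_buffer`/`v_buffer` as plain views, so `MHATokenToKVPool.load_from_host_per_layer` can index `host_pool.k_buffer[layer_id]` exactly like its own per-layer buffers. The properties are zero-copy, as this small check shows (shapes assumed for illustration):

```python
import torch

kv = torch.empty((2, 4, 16, 8, 128), dtype=torch.float16)
k_buffer, v_buffer = kv[0], kv[1]              # what the new properties return
assert k_buffer.data_ptr() == kv.data_ptr()    # a view, not a copy
assert v_buffer.shape == (4, 16, 8, 128)       # [layer, token, head, head_dim]
```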
@@ -337,44 +275,3 @@ class MLATokenToKVPoolHost(HostKVCache):
             device=self.device,
             pin_memory=self.pin_memory,
         )
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
-                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.kv_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
@@ -196,11 +196,13 @@ class RadixCache(BasePrefixCache):

         if self.page_size != 1:
             page_aligned_len = len(kv_indices) // self.page_size * self.page_size
-            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                dtype=torch.int64, copy=True
+            )
             self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
         else:
             page_aligned_len = len(kv_indices)
-            page_aligned_kv_indices = kv_indices.clone()
+            page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)

         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(
@@ -226,10 +228,12 @@ class RadixCache(BasePrefixCache):

         if self.page_size != 1:
             page_aligned_len = len(kv_indices) // self.page_size * self.page_size
-            page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                dtype=torch.int64, copy=True
+            )
         else:
             page_aligned_len = len(kv_indices)
-            page_aligned_kv_indices = kv_indices.clone()
+            page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
         page_aligned_token_ids = token_ids[:page_aligned_len]

         # Radix Cache takes one ref in memory pool
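Both RadixCache hunks swap `.clone()` for `.to(dtype=torch.int64, copy=True)`: the indices now feed kernels that compute raw-pointer offsets from them, so widening to 64-bit integers is required (presumably to avoid dtype juggling inside the kernels), while `copy=True` preserves the defensive copy the old `clone()` provided. A small demonstration of the idiom:

```python
import torch

idx = torch.arange(10, dtype=torch.int32)
old = idx.clone()                              # copy, but stays int32
new = idx.to(dtype=torch.int64, copy=True)     # copy AND widen to int64
assert new.dtype == torch.int64
assert new.data_ptr() != idx.data_ptr()        # still a fresh tensor
```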