Kernels for efficient KV cache IO (#7313)

This commit is contained in:
Zhiqiang Xie
2025-07-06 22:53:36 -07:00
committed by GitHub
parent 253454de9b
commit 2fc824b84c
7 changed files with 184 additions and 371 deletions

View File

@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
""" """
import concurrent.futures
import logging import logging
import math import math
import threading import threading
@@ -169,12 +168,23 @@ class HiCacheController:
page_size: int, page_size: int,
load_cache_event: threading.Event = None, load_cache_event: threading.Event = None,
write_policy: str = "write_through_selective", write_policy: str = "write_through_selective",
io_backend: str = "",
): ):
self.mem_pool_device_allocator = token_to_kv_pool_allocator self.mem_pool_device_allocator = token_to_kv_pool_allocator
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache() self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
self.mem_pool_host = mem_pool_host self.mem_pool_host = mem_pool_host
self.write_policy = write_policy self.write_policy = write_policy
self.page_size = page_size self.page_size = page_size
# using kernel for small page KV cache transfer and DMA for large pages
if not io_backend:
IO_BACKEND_PAGE_SIZE_THRESHOLD = 64
self.io_backend = (
"direct"
if self.page_size >= IO_BACKEND_PAGE_SIZE_THRESHOLD
else "kernel"
)
else:
self.io_backend = io_backend
self.load_cache_event = load_cache_event self.load_cache_event = load_cache_event
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num) self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -203,12 +213,7 @@ class HiCacheController:
self.load_stream = torch.cuda.Stream() self.load_stream = torch.cuda.Stream()
self.write_thread = threading.Thread( self.write_thread = threading.Thread(
target=( target=self.write_thread_func_direct, daemon=True
self.write_thread_func_buffer
if self.page_size == 1
else self.write_thread_func_direct
),
daemon=True,
) )
self.load_thread = threading.Thread( self.load_thread = threading.Thread(
target=self.load_thread_func_layer_by_layer, daemon=True target=self.load_thread_func_layer_by_layer, daemon=True
@@ -229,12 +234,7 @@ class HiCacheController:
self.ack_load_queue.queue.clear() self.ack_load_queue.queue.clear()
self.write_thread = threading.Thread( self.write_thread = threading.Thread(
target=( target=self.write_thread_func_direct, daemon=True
self.write_thread_func_buffer
if self.page_size == 1
else self.write_thread_func_direct
),
daemon=True,
) )
self.load_thread = threading.Thread( self.load_thread = threading.Thread(
target=self.load_thread_func_layer_by_layer, daemon=True target=self.load_thread_func_layer_by_layer, daemon=True
@@ -281,6 +281,15 @@ class HiCacheController:
) )
return device_indices return device_indices
def move_indices(self, host_indices, device_indices):
# move indices to GPU if using kernels, to host if using direct indexing
if self.io_backend == "kernel":
return host_indices.to(self.mem_pool_device.device), device_indices
elif self.io_backend == "direct":
return host_indices, device_indices.cpu()
else:
raise ValueError(f"Unsupported io backend")
def write_thread_func_direct(self): def write_thread_func_direct(self):
""" """
Directly write through KV caches to host memory without buffering. Directly write through KV caches to host memory without buffering.
@@ -289,10 +298,14 @@ class HiCacheController:
while not self.stop_event.is_set(): while not self.stop_event.is_set():
try: try:
operation = self.write_queue.get(block=True, timeout=1) operation = self.write_queue.get(block=True, timeout=1)
self.mem_pool_host.write_page_all_layers( host_indices, device_indices = self.move_indices(
operation.host_indices, operation.host_indices, operation.device_indices
operation.device_indices, )
self.mem_pool_device, self.mem_pool_device.backup_to_host_all_layer(
self.mem_pool_host,
host_indices,
device_indices,
self.io_backend,
) )
self.write_stream.synchronize() self.write_stream.synchronize()
self.mem_pool_host.complete_io(operation.host_indices) self.mem_pool_host.complete_io(operation.host_indices)
@@ -304,27 +317,6 @@ class HiCacheController:
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
def load_thread_func_direct(self):
"""
Directly load KV caches from host memory to device memory without buffering.
"""
torch.cuda.set_stream(self.load_stream)
while not self.stop_event.is_set():
try:
operation = self.load_queue.get(block=True, timeout=1)
operation.data = self.mem_pool_host.get_flat_data(
operation.host_indices
)
self.mem_pool_device.transfer(operation.device_indices, operation.data)
self.mem_pool_host.complete_io(operation.host_indices)
for node_id in operation.node_ids:
if node_id != 0:
self.ack_load_queue.put(node_id)
except Empty:
continue
except Exception as e:
logger.error(e)
def load_thread_func_layer_by_layer(self): def load_thread_func_layer_by_layer(self):
""" """
Load KV caches from host memory to device memory layer by layer. Load KV caches from host memory to device memory layer by layer.
@@ -349,22 +341,18 @@ class HiCacheController:
# start layer-wise KV cache transfer from CPU to GPU # start layer-wise KV cache transfer from CPU to GPU
self.layer_done_counter.reset() self.layer_done_counter.reset()
host_indices, device_indices = self.move_indices(
batch_operation.host_indices, batch_operation.device_indices
)
for i in range(self.mem_pool_host.layer_num): for i in range(self.mem_pool_host.layer_num):
if self.page_size == 1: self.mem_pool_device.load_from_host_per_layer(
flat_data = self.mem_pool_host.get_flat_data_by_layer( self.mem_pool_host,
batch_operation.host_indices, i host_indices,
) device_indices,
self.mem_pool_device.transfer_per_layer( i,
batch_operation.device_indices, flat_data, i self.io_backend,
) )
else: self.load_stream.synchronize()
self.mem_pool_host.load_page_per_layer(
batch_operation.host_indices,
batch_operation.device_indices,
self.mem_pool_device,
i,
)
self.load_stream.synchronize()
self.layer_done_counter.increment() self.layer_done_counter.increment()
self.mem_pool_host.complete_io(batch_operation.host_indices) self.mem_pool_host.complete_io(batch_operation.host_indices)
@@ -372,148 +360,6 @@ class HiCacheController:
if node_id != 0: if node_id != 0:
self.ack_load_queue.put(node_id) self.ack_load_queue.put(node_id)
def write_aux_func(self, no_wait=False):
"""
Auxiliary function to prepare the buffer for write operations.
"""
torch.cuda.set_stream(self.write_stream)
def _to_op(op_):
assert op_.device_indices.is_cuda, "Device indices should be on GPU"
op_.data = self.mem_pool_device.get_flat_data(op_.device_indices).to(
self.mem_pool_host.device
)
self.write_buffer.put(op_)
return op_
buffer = None
while not self.stop_event.is_set():
try:
operation = self.write_queue.get(block=True, timeout=1)
factor = (
len(operation.device_indices) // self.write_buffer.max_buffer_size
)
if factor >= 1:
if buffer is not None:
_to_op(buffer)
buffer = None
if factor < 2:
_to_op(operation)
else:
split_ops = operation.split(factor)
for op_ in split_ops:
_to_op(op_)
continue
if buffer is None:
buffer = operation
else:
buffer.merge(operation)
if (
no_wait
or len(buffer.host_indices) >= self.write_buffer.max_buffer_size
or self.write_queue.empty()
or self.write_buffer.empty()
):
_to_op(buffer)
buffer = None
except Empty:
continue
except Exception as e:
logger.error(e)
def load_aux_func(self):
"""
Auxiliary function to prepare the buffer for load operations.
"""
def _pin_op(op_, put=True):
op_.data = (
self.mem_pool_host.get_flat_data(op_.host_indices)
.contiguous()
.pin_memory()
)
if put:
self.load_buffer.put(op_)
return op_
buffer = None
while not self.stop_event.is_set():
try:
operation = self.load_queue.get(block=True, timeout=1)
factor = len(operation.host_indices) // self.load_buffer.max_buffer_size
if factor >= 1:
if buffer is not None:
_pin_op(buffer)
buffer = None
if factor < 2:
_pin_op(operation)
else:
split_ops = operation.split(factor)
split_args = [(op_, True) for op_ in split_ops[:-1]]
split_args.append((split_ops[-1], False))
# Spawn threads to pin each op concurrently
with concurrent.futures.ThreadPoolExecutor() as executor:
pinned_ops = list(
executor.map(
lambda x: _pin_op(x[0], put=x[1]), split_args
)
)
# preserve the order of last op to ensure correct ack
self.load_buffer.put(pinned_ops[-1])
continue
if buffer is None:
buffer = operation
else:
buffer.merge(operation)
if (
len(buffer.host_indices) >= self.load_buffer.max_buffer_size
or self.load_queue.empty()
or self.load_buffer.empty()
):
_pin_op(buffer)
buffer = None
except Empty:
continue
except Exception as e:
logger.error(e)
# todo (zhiqiang): double buffering to be deprecated
def write_thread_func_buffer(self):
aux_thread = threading.Thread(target=self.write_aux_func, daemon=True)
aux_thread.start()
while not self.stop_event.is_set():
operation = self.write_buffer.get()
if operation is None:
continue
self.mem_pool_host.assign_flat_data(operation.host_indices, operation.data)
self.mem_pool_host.complete_io(operation.host_indices)
for node_id in operation.node_ids:
if node_id != 0:
self.ack_write_queue.put(node_id)
aux_thread.join()
def load_thread_func_buffer(self):
torch.cuda.set_stream(self.load_stream)
aux_thread = threading.Thread(target=self.load_aux_func, daemon=True)
aux_thread.start()
while not self.stop_event.is_set():
operation = self.load_buffer.get()
if operation is None:
continue
self.mem_pool_device.transfer(operation.device_indices, operation.data)
self.mem_pool_host.complete_io(operation.host_indices)
for node_id in operation.node_ids:
if node_id != 0:
self.ack_load_queue.put(node_id)
aux_thread.join()
def evict_device( def evict_device(
self, device_indices: torch.Tensor, host_indices: torch.Tensor self, device_indices: torch.Tensor, host_indices: torch.Tensor
) -> int: ) -> int:

View File

@@ -591,6 +591,12 @@ class Scheduler(
hicache_ratio=server_args.hicache_ratio, hicache_ratio=server_args.hicache_ratio,
hicache_size=server_args.hicache_size, hicache_size=server_args.hicache_size,
hicache_write_policy=server_args.hicache_write_policy, hicache_write_policy=server_args.hicache_write_policy,
hicache_io_backend=(
"direct"
if server_args.attention_backend
== "fa3" # hot fix for incompatibility
else server_args.hicache_io_backend
),
) )
self.tp_worker.register_hicache_layer_transfer_counter( self.tp_worker.register_hicache_layer_transfer_counter(
self.tree_cache.cache_controller.layer_done_counter self.tree_cache.cache_controller.layer_done_counter

View File

@@ -34,6 +34,7 @@ class HiRadixCache(RadixCache):
hicache_ratio: float, hicache_ratio: float,
hicache_size: int, hicache_size: int,
hicache_write_policy: str, hicache_write_policy: str,
hicache_io_backend: str,
): ):
self.kv_cache = token_to_kv_pool_allocator.get_kvcache() self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
if isinstance(self.kv_cache, MHATokenToKVPool): if isinstance(self.kv_cache, MHATokenToKVPool):
@@ -56,6 +57,7 @@ class HiRadixCache(RadixCache):
page_size, page_size,
load_cache_event=self.load_cache_event, load_cache_event=self.load_cache_event,
write_policy=hicache_write_policy, write_policy=hicache_write_policy,
io_backend=hicache_io_backend,
) )
# record the nodes with ongoing write through # record the nodes with ongoing write through

View File

@@ -34,10 +34,11 @@ import torch
import torch.distributed as dist import torch.distributed as dist
import triton import triton
import triton.language as tl import triton.language as tl
from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2 from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -150,13 +151,16 @@ class KVCache(abc.ABC):
) -> None: ) -> None:
raise NotImplementedError() raise NotImplementedError()
def get_flat_data(self, indices): @abc.abstractmethod
def load_from_host_per_layer(
self, host_pool, host_indices, device_indices, layer_id, io_backend
):
raise NotImplementedError() raise NotImplementedError()
def transfer(self, indices, flat_data): @abc.abstractmethod
raise NotImplementedError() def backup_to_host_all_layer(
self, host_pool, host_indices, device_indices, io_backend
def transfer_per_layer(self, indices, flat_data, layer_id): ):
raise NotImplementedError() raise NotImplementedError()
def register_layer_transfer_counter(self, layer_transfer_counter): def register_layer_transfer_counter(self, layer_transfer_counter):
@@ -247,7 +251,7 @@ class MHATokenToKVPool(KVCache):
) )
for _ in range(self.layer_num) for _ in range(self.layer_num)
] ]
self.token_stride = self.head_num * self.head_dim
self.data_ptrs = torch.tensor( self.data_ptrs = torch.tensor(
[x.data_ptr() for x in self.k_buffer + self.v_buffer], [x.data_ptr() for x in self.k_buffer + self.v_buffer],
dtype=torch.uint64, dtype=torch.uint64,
@@ -281,24 +285,24 @@ class MHATokenToKVPool(KVCache):
# layer_num x [seq_len, head_num, head_dim] # layer_num x [seq_len, head_num, head_dim]
# layer_num x [page_num, page_size, head_num, head_dim] # layer_num x [page_num, page_size, head_num, head_dim]
kv_data_ptrs = [ kv_data_ptrs = [
self.get_key_buffer(i).data_ptr() self._get_key_buffer(i).data_ptr()
for i in range(self.start_layer, self.start_layer + self.layer_num) for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [ ] + [
self.get_value_buffer(i).data_ptr() self._get_value_buffer(i).data_ptr()
for i in range(self.start_layer, self.start_layer + self.layer_num) for i in range(self.start_layer, self.start_layer + self.layer_num)
] ]
kv_data_lens = [ kv_data_lens = [
self.get_key_buffer(i).nbytes self._get_key_buffer(i).nbytes
for i in range(self.start_layer, self.start_layer + self.layer_num) for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [ ] + [
self.get_value_buffer(i).nbytes self._get_value_buffer(i).nbytes
for i in range(self.start_layer, self.start_layer + self.layer_num) for i in range(self.start_layer, self.start_layer + self.layer_num)
] ]
kv_item_lens = [ kv_item_lens = [
self.get_key_buffer(i)[0].nbytes * self.page_size self._get_key_buffer(i)[0].nbytes * self.page_size
for i in range(self.start_layer, self.start_layer + self.layer_num) for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [ ] + [
self.get_value_buffer(i)[0].nbytes * self.page_size self._get_value_buffer(i)[0].nbytes * self.page_size
for i in range(self.start_layer, self.start_layer + self.layer_num) for i in range(self.start_layer, self.start_layer + self.layer_num)
] ]
return kv_data_ptrs, kv_data_lens, kv_item_lens return kv_data_ptrs, kv_data_lens, kv_item_lens
@@ -341,49 +345,73 @@ class MHATokenToKVPool(KVCache):
self.v_buffer[layer_id][chunk_indices] = v_chunk self.v_buffer[layer_id][chunk_indices] = v_chunk
torch.cuda.synchronize() torch.cuda.synchronize()
# Todo: different memory layout def load_from_host_per_layer(
def get_flat_data(self, indices): self,
# prepare a large chunk of contiguous data for efficient transfer host_pool,
flatten = torch.stack( host_indices,
[ device_indices,
torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]), layer_id,
torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]), io_backend,
] ):
transfer_kv_per_layer(
src_k=host_pool.k_buffer[layer_id],
dst_k=self.k_buffer[layer_id],
src_v=host_pool.v_buffer[layer_id],
dst_v=self.v_buffer[layer_id],
src_indices=host_indices,
dst_indices=device_indices,
io_backend=io_backend,
page_size=self.page_size,
item_size=self.token_stride,
) )
return flatten
@debug_timing def backup_to_host_all_layer(
def transfer(self, indices, flat_data): self, host_pool, host_indices, device_indices, io_backend
# transfer prepared data from host to device ):
flat_data = flat_data.to(device=self.device, non_blocking=False) # todo: specialized all layer kernels for the layer-non-contiguous memory pool
k_data, v_data = flat_data[0], flat_data[1] for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
for i in range(self.layer_num): if layer_id - self.start_layer >= len(host_pool.k_buffer):
self.k_buffer[i][indices] = k_data[i] raise ValueError(
self.v_buffer[i][indices] = v_data[i] f"Layer ID {layer_id} exceeds the number of layers in host pool."
)
def transfer_per_layer(self, indices, flat_data, layer_id): transfer_kv_per_layer(
# transfer prepared data from host to device src_k=self.k_buffer[layer_id],
flat_data = flat_data.to(device=self.device, non_blocking=False) dst_k=host_pool.k_buffer[layer_id],
k_data, v_data = flat_data[0], flat_data[1] src_v=self.v_buffer[layer_id],
self.k_buffer[layer_id - self.start_layer][indices] = k_data dst_v=host_pool.v_buffer[layer_id],
self.v_buffer[layer_id - self.start_layer][indices] = v_data src_indices=device_indices,
dst_indices=host_indices,
def get_key_buffer(self, layer_id: int): io_backend=io_backend,
if self.layer_transfer_counter is not None: page_size=self.page_size,
self.layer_transfer_counter.wait_until(layer_id - self.start_layer) item_size=self.token_stride,
)
def _get_key_buffer(self, layer_id: int):
# for internal use of referencing
if self.store_dtype != self.dtype: if self.store_dtype != self.dtype:
return self.k_buffer[layer_id - self.start_layer].view(self.dtype) return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
return self.k_buffer[layer_id - self.start_layer] return self.k_buffer[layer_id - self.start_layer]
def get_value_buffer(self, layer_id: int): def get_key_buffer(self, layer_id: int):
# note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading
# it is supposed to be used only by attention backend not for information purpose
# same applies to get_value_buffer and get_kv_buffer
if self.layer_transfer_counter is not None: if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer) self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
return self._get_key_buffer(layer_id)
def _get_value_buffer(self, layer_id: int):
# for internal use of referencing
if self.store_dtype != self.dtype: if self.store_dtype != self.dtype:
return self.v_buffer[layer_id - self.start_layer].view(self.dtype) return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
return self.v_buffer[layer_id - self.start_layer] return self.v_buffer[layer_id - self.start_layer]
def get_value_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
return self._get_value_buffer(layer_id)
def get_kv_buffer(self, layer_id: int): def get_kv_buffer(self, layer_id: int):
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
@@ -761,6 +789,7 @@ class MLATokenToKVPool(KVCache):
for _ in range(layer_num) for _ in range(layer_num)
] ]
self.token_stride = kv_lora_rank + qk_rope_head_dim
self.layer_transfer_counter = None self.layer_transfer_counter = None
kv_size = self.get_kv_size_bytes() kv_size = self.get_kv_size_bytes()
@@ -846,21 +875,37 @@ class MLATokenToKVPool(KVCache):
self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
) )
def get_flat_data(self, indices): def load_from_host_per_layer(
# prepare a large chunk of contiguous data for efficient transfer self, host_pool, host_indices, device_indices, layer_id, io_backend
return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)]) ):
transfer_kv_per_layer_mla(
src=host_pool.kv_buffer[layer_id],
dst=self.kv_buffer[layer_id],
src_indices=host_indices,
dst_indices=device_indices,
io_backend=io_backend,
page_size=self.page_size,
item_size=self.token_stride,
)
@debug_timing def backup_to_host_all_layer(
def transfer(self, indices, flat_data): self, host_pool, host_indices, device_indices, io_backend
# transfer prepared data from host to device ):
flat_data = flat_data.to(device=self.device, non_blocking=False) # todo: specialized all layer kernels for the layer-non-contiguous memory pool
for i in range(self.layer_num): for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
self.kv_buffer[i][indices] = flat_data[i] if layer_id - self.start_layer >= len(host_pool.kv_buffer):
raise ValueError(
def transfer_per_layer(self, indices, flat_data, layer_id): f"Layer ID {layer_id} exceeds the number of layers in host pool."
# transfer prepared data from host to device )
flat_data = flat_data.to(device=self.device, non_blocking=False) transfer_kv_per_layer_mla(
self.kv_buffer[layer_id - self.start_layer][indices] = flat_data src=self.kv_buffer[layer_id],
dst=host_pool.kv_buffer[layer_id],
src_indices=device_indices,
dst_indices=host_indices,
io_backend=io_backend,
page_size=self.page_size,
item_size=self.token_stride,
)
def get_cpu_copy(self, indices): def get_cpu_copy(self, indices):
torch.cuda.synchronize() torch.cuda.synchronize()
@@ -1046,14 +1091,19 @@ class DoubleSparseTokenToKVPool(KVCache):
self.v_buffer[layer_id - self.start_layer][loc] = cache_v self.v_buffer[layer_id - self.start_layer][loc] = cache_v
self.label_buffer[layer_id - self.start_layer][loc] = cache_label self.label_buffer[layer_id - self.start_layer][loc] = cache_label
def get_flat_data(self, indices): def load_from_host_per_layer(
pass self, host_pool, host_indices, device_indices, layer_id, io_backend
):
raise NotImplementedError(
"HiCache not supported for DoubleSparseTokenToKVPool."
)
def transfer(self, indices, flat_data): def backup_to_host_all_layer(
pass self, host_pool, host_indices, device_indices, io_backend
):
def transfer_per_layer(self, indices, flat_data, layer_id): raise NotImplementedError(
pass "HiCache not supported for DoubleSparseTokenToKVPool."
)
@triton.jit @triton.jit

View File

@@ -8,7 +8,6 @@ import psutil
import torch import torch
from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
from sglang.srt.utils import debug_timing
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -99,22 +98,6 @@ class HostKVCache(abc.ABC):
def init_kv_buffer(self): def init_kv_buffer(self):
raise NotImplementedError() raise NotImplementedError()
@abc.abstractmethod
def transfer(self, indices, flat_data):
raise NotImplementedError()
@abc.abstractmethod
def get_flat_data(self, indices):
raise NotImplementedError()
@abc.abstractmethod
def get_flat_data_by_layer(self, indices, layer_id):
raise NotImplementedError()
@abc.abstractmethod
def assign_flat_data(self, indices, flat_data):
raise NotImplementedError()
@synchronized() @synchronized()
def clear(self): def clear(self):
# Initialize memory states and tracking structures. # Initialize memory states and tracking structures.
@@ -243,58 +226,13 @@ class MHATokenToKVPoolHost(HostKVCache):
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
) )
@debug_timing @property
def transfer(self, indices, flat_data): def k_buffer(self):
# backup prepared data from device to host return self.kv_buffer[0]
self.kv_buffer[:, :, indices] = flat_data.to(
device=self.device, non_blocking=False
)
def get_flat_data(self, indices): @property
return self.kv_buffer[:, :, indices] def v_buffer(self):
return self.kv_buffer[1]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[:, layer_id - self.start_layer, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, :, indices] = flat_data
def write_page_all_layers(self, host_indices, device_indices, device_pool):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
for j in range(self.layer_num):
self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
device_pool.k_buffer[j][d_index : d_index + self.page_size],
non_blocking=True,
)
self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
device_pool.v_buffer[j][d_index : d_index + self.page_size],
non_blocking=True,
)
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
device_pool.k_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
0, layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)
device_pool.v_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
1, layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)
class MLATokenToKVPoolHost(HostKVCache): class MLATokenToKVPoolHost(HostKVCache):
@@ -337,44 +275,3 @@ class MLATokenToKVPoolHost(HostKVCache):
device=self.device, device=self.device,
pin_memory=self.pin_memory, pin_memory=self.pin_memory,
) )
@debug_timing
def transfer(self, indices, flat_data):
# backup prepared data from device to host
self.kv_buffer[:, indices] = flat_data.to(
device=self.device, non_blocking=False
)
def get_flat_data(self, indices):
return self.kv_buffer[:, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[layer_id - self.start_layer, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, indices] = flat_data
def write_page_all_layers(self, host_indices, device_indices, device_pool):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
for j in range(self.layer_num):
self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
device_pool.kv_buffer[j][d_index : d_index + self.page_size],
non_blocking=True,
)
def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
device_indices_cpu = device_indices[:: self.page_size].cpu()
for i in range(len(device_indices_cpu)):
h_index = host_indices[i * self.page_size]
d_index = device_indices_cpu[i]
device_pool.kv_buffer[layer_id - self.start_layer][
d_index : d_index + self.page_size
].copy_(
self.kv_buffer[
layer_id - self.start_layer, h_index : h_index + self.page_size
],
non_blocking=True,
)

View File

@@ -196,11 +196,13 @@ class RadixCache(BasePrefixCache):
if self.page_size != 1: if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone() page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
else: else:
page_aligned_len = len(kv_indices) page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.clone() page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
# Radix Cache takes one ref in memory pool # Radix Cache takes one ref in memory pool
new_prefix_len = self.insert( new_prefix_len = self.insert(
@@ -226,10 +228,12 @@ class RadixCache(BasePrefixCache):
if self.page_size != 1: if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone() page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
dtype=torch.int64, copy=True
)
else: else:
page_aligned_len = len(kv_indices) page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.clone() page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
page_aligned_token_ids = token_ids[:page_aligned_len] page_aligned_token_ids = token_ids[:page_aligned_len]
# Radix Cache takes one ref in memory pool # Radix Cache takes one ref in memory pool

View File

@@ -217,6 +217,7 @@ class ServerArgs:
hicache_ratio: float = 2.0 hicache_ratio: float = 2.0
hicache_size: int = 0 hicache_size: int = 0
hicache_write_policy: str = "write_through_selective" hicache_write_policy: str = "write_through_selective"
hicache_io_backend: str = ""
flashinfer_mla_disable_ragged: bool = False flashinfer_mla_disable_ragged: bool = False
disable_shared_experts_fusion: bool = False disable_shared_experts_fusion: bool = False
disable_chunked_prefix_cache: bool = False disable_chunked_prefix_cache: bool = False
@@ -1530,6 +1531,13 @@ class ServerArgs:
default=ServerArgs.hicache_write_policy, default=ServerArgs.hicache_write_policy,
help="The write policy of hierarchical cache.", help="The write policy of hierarchical cache.",
) )
parser.add_argument(
"--hicache-io-backend",
type=str,
choices=["direct", "kernel"],
default=ServerArgs.hicache_io_backend,
help="The IO backend for KV cache transfer between CPU and GPU",
)
parser.add_argument( parser.add_argument(
"--flashinfer-mla-disable-ragged", "--flashinfer-mla-disable-ragged",
action="store_true", action="store_true",