[PD-HiCache]: Support Async Offloading KVCache In Decode Side (#10192)

Signed-off-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
hzh0425
2025-09-26 14:20:49 +08:00
committed by GitHub
parent 6088548216
commit 7ec5b4e89c
7 changed files with 523 additions and 8 deletions

View File

@@ -609,6 +609,7 @@ class DecodeTransferQueue:
idx = decode_req.metadata_buffer_index
(
output_id,
cached_tokens,
output_token_logprobs_val,
output_token_logprobs_idx,
output_top_logprobs_val,
@@ -617,6 +618,7 @@ class DecodeTransferQueue:
) = self.metadata_buffers.get_buf(idx)
decode_req.req.output_ids.append(output_id[0].item())
decode_req.req.cached_tokens = cached_tokens[0].item()
if not self.spec_algorithm.is_none():
decode_req.req.hidden_states_tensor = output_hidden_states
if decode_req.req.return_logprob:
@@ -707,12 +709,15 @@ class SchedulerDisaggregationDecodeMixin:
elif prepare_mlp_sync_flag:
batch, _ = self._prepare_idle_batch_and_run(None)
if batch is None and (
queue_size = (
len(self.waiting_queue)
+ len(self.disagg_decode_transfer_queue.queue)
+ len(self.disagg_decode_prealloc_queue.queue)
== 0
):
)
if self.server_args.disaggregation_decode_enable_offload_kvcache:
queue_size += len(self.decode_offload_manager.ongoing_offload)
if batch is None and queue_size == 0:
self.self_check_during_idle()
self.last_batch = batch
@@ -781,12 +786,15 @@ class SchedulerDisaggregationDecodeMixin:
)
self.process_batch_result(tmp_batch, tmp_result)
if batch is None and (
queue_size = (
len(self.waiting_queue)
+ len(self.disagg_decode_transfer_queue.queue)
+ len(self.disagg_decode_prealloc_queue.queue)
== 0
):
)
if self.server_args.disaggregation_decode_enable_offload_kvcache:
queue_size += len(self.decode_offload_manager.ongoing_offload)
if batch is None and queue_size == 0:
self.self_check_during_idle()
self.last_batch = batch
@@ -905,3 +913,6 @@ class SchedulerDisaggregationDecodeMixin:
self.disagg_decode_transfer_queue.pop_transferred()
) # the requests which kv has arrived
self.waiting_queue.extend(alloc_reqs)
if self.server_args.disaggregation_decode_enable_offload_kvcache:
self.decode_offload_manager.check_offload_progress()

View File

@@ -0,0 +1,185 @@
import logging
import threading
import time
import torch
from sglang import ServerArgs
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
ReqToTokenPool,
)
from sglang.srt.mem_cache.memory_pool_host import (
MHATokenToKVPoolHost,
MLATokenToKVPoolHost,
)
logger = logging.getLogger(__name__)
class DecodeKVCacheOffloadManager:
    """Manage decode-side KV cache offloading lifecycle and operations.

    For each finished decode request, the KV cache is moved in two async
    stages, both driven by a HiCacheController:
      1. Device -> host (pinned host pool), tracked in ``ongoing_offload``.
      2. Host -> storage backend, tracked in ``ongoing_backup``.
    Progress of both stages is polled via ``check_offload_progress``.
    """

    def __init__(
        self,
        req_to_token_pool: ReqToTokenPool,
        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
        tp_group: torch.distributed.ProcessGroup,
        tree_cache: BasePrefixCache,
        server_args: ServerArgs,
    ) -> None:
        self.req_to_token_pool = req_to_token_pool
        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
        self.page_size = server_args.page_size
        self.server_args = server_args
        # Monotonic counter used to mint unique ack ids for offload writes.
        self.request_counter = 0
        self.tree_cache = tree_cache

        # Build a host-side mirror pool matching the device KV layout
        # (MHA vs. MLA); only these two pool types are supported.
        kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
        if isinstance(kv_cache, MHATokenToKVPool):
            self.decode_host_mem_pool = MHATokenToKVPoolHost(
                kv_cache,
                server_args.hicache_ratio,
                server_args.hicache_size,
                self.page_size,
                server_args.hicache_mem_layout,
            )
        elif isinstance(kv_cache, MLATokenToKVPool):
            self.decode_host_mem_pool = MLATokenToKVPoolHost(
                kv_cache,
                server_args.hicache_ratio,
                server_args.hicache_size,
                self.page_size,
                server_args.hicache_mem_layout,
            )
        else:
            raise ValueError("Unsupported KV cache type for decode offload")

        self.tp_group = tp_group
        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)

        # The controller performs the actual async D2H writes and storage
        # backups; load_cache_event is unused on the decode side (no loads),
        # so a fresh Event is passed — NOTE(review): confirm the controller
        # tolerates an event that is never set.
        self.cache_controller = HiCacheController(
            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            mem_pool_host=self.decode_host_mem_pool,
            page_size=self.page_size,
            tp_group=tp_group,
            io_backend=server_args.hicache_io_backend,
            load_cache_event=threading.Event(),
            storage_backend=server_args.hicache_storage_backend,
            model_name=server_args.served_model_name,
            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
        )

        # ack_id -> (req, host_indices, tokens, start_time): D2H in flight.
        self.ongoing_offload = {}
        # ack_id -> (req_id, host_indices, start_time): host->storage in flight.
        self.ongoing_backup = {}
        logger.info("Enable offload kv cache for decode side")

    def offload_kv_cache(self, req) -> bool:
        """Offload a finished request's KV cache to storage.

        Returns True if an async device->host write was enqueued (the caller
        must NOT free the request's device KV yet — it is released later in
        ``_check_offload_progress``); False if offload is not possible and
        the caller should free the KV cache itself.
        """
        if self.cache_controller is None or self.decode_host_mem_pool is None:
            return False

        # Request was never allocated a slot (or already released).
        if req.req_pool_idx == -1:
            return False

        token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
        if token_indices.dim() == 0 or token_indices.numel() == 0:
            logger.debug(
                f"Request {req.rid} has invalid token_indices: {token_indices}"
            )
            return False

        # Only whole pages can be offloaded; drop the trailing partial page.
        tokens = req.origin_input_ids + req.output_ids
        aligned_len = (len(tokens) // self.page_size) * self.page_size
        if aligned_len == 0:
            return False

        token_indices = token_indices[:aligned_len]
        tokens = tokens[:aligned_len]

        # Asynchronously offload KV cache from device to host by cache controller
        self.request_counter += 1
        ack_id = self.request_counter
        host_indices = self.cache_controller.write(
            device_indices=token_indices.long(),
            node_id=ack_id,
        )
        if host_indices is None:
            # Host pool exhausted; fall back to normal (non-offloaded) release.
            logger.error(f"Not enough host memory for request {req.rid}")
            return False

        self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
        return True

    def check_offload_progress(self):
        """Check the progress of offload from device to host and backup from host to storage."""
        cc = self.cache_controller

        # NOTE(review): qsize() is an approximate count for cross-thread
        # queues — presumed acceptable here since we only ever consume up to
        # the reported number of completed items.
        qsizes = torch.tensor(
            [
                len(cc.ack_write_queue),
                cc.ack_backup_queue.qsize(),
            ],
            dtype=torch.int,
        )
        if self.tp_world_size > 1:
            # Take the MIN across TP ranks so every rank processes the same
            # number of completions and stays in lockstep.
            torch.distributed.all_reduce(
                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
            )

        n_write, n_backup = map(int, qsizes.tolist())
        self._check_offload_progress(n_write)
        self._check_backup_progress(n_backup)

    def _check_offload_progress(self, finish_count):
        """Check the progress of offload from device to host."""
        while finish_count > 0:
            # Queue entries are consumed FIFO; finish_count <= len(queue) by
            # construction (MIN-reduced above), so pop(0) cannot underflow.
            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
            # Block until the D2H copy tied to this entry has completed
            # (presumably a device event — TODO confirm event type).
            finish_event.synchronize()
            for ack_id in ack_list:
                req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)

                # Release device
                self.tree_cache.cache_finished_req(req)

                # Trigger async backup from host to storage by cache controller
                self._trigger_backup(req.rid, host_indices, tokens, start_time)
            finish_count -= 1

    def _check_backup_progress(self, finish_count):
        """Check the progress of backup from host to storage."""
        for _ in range(finish_count):
            storage_operation = self.cache_controller.ack_backup_queue.get()
            ack_id = storage_operation.id
            req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)

            # Release host memory
            self.decode_host_mem_pool.free(host_indices)

            logger.debug(
                f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
            )

    def _trigger_backup(self, req_id, host_indices, tokens, start_time):
        """Trigger async backup from host to storage by cache controller."""

        # Generate page hashes and write to storage
        page_hashes = self._compute_prefix_hash(tokens)
        ack_id = self.cache_controller.write_storage(
            host_indices,
            tokens,
            hash_value=page_hashes,
        )
        self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)

    def _compute_prefix_hash(self, tokens):
        # Rolling per-page hash: each page's hash is chained on the previous
        # page's hash, so a page hash identifies the full token prefix up to
        # and including that page.
        last_hash = ""
        page_hashes = []
        for offset in range(0, len(tokens), self.page_size):
            page_tokens = tokens[offset : offset + self.page_size]
            last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
            page_hashes.append(last_hash)
        return page_hashes

View File

@@ -107,7 +107,9 @@ class MetadataBuffers:
# We transfer the metadata of first output token to decode
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
self.cached_tokens = torch.zeros(
(size, 16), dtype=torch.int32, device=device
)
self.output_token_logprobs_val = torch.zeros(
(size, 16), dtype=torch.float32, device=device
)
@@ -127,6 +129,7 @@ class MetadataBuffers:
def get_buf_infos(self):
ptrs = [
self.output_ids.data_ptr(),
self.cached_tokens.data_ptr(),
self.output_token_logprobs_val.data_ptr(),
self.output_token_logprobs_idx.data_ptr(),
self.output_top_logprobs_val.data_ptr(),
@@ -135,6 +138,7 @@ class MetadataBuffers:
]
data_lens = [
self.output_ids.nbytes,
self.cached_tokens.nbytes,
self.output_token_logprobs_val.nbytes,
self.output_token_logprobs_idx.nbytes,
self.output_top_logprobs_val.nbytes,
@@ -143,6 +147,7 @@ class MetadataBuffers:
]
item_lens = [
self.output_ids[0].nbytes,
self.cached_tokens[0].nbytes,
self.output_token_logprobs_val[0].nbytes,
self.output_token_logprobs_idx[0].nbytes,
self.output_top_logprobs_val[0].nbytes,
@@ -154,6 +159,7 @@ class MetadataBuffers:
def get_buf(self, idx: int):
return (
self.output_ids[idx],
self.cached_tokens[idx],
self.output_token_logprobs_val[idx],
self.output_token_logprobs_idx[idx],
self.output_top_logprobs_val[idx],
@@ -164,6 +170,7 @@ class MetadataBuffers:
def set_buf(self, req: Req):
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
if req.return_logprob:
if req.output_token_logprobs_val: # not none or empty list
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (

View File

@@ -44,6 +44,9 @@ from sglang.srt.disaggregation.decode import (
DecodeTransferQueue,
SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
DecodeKVCacheOffloadManager,
)
from sglang.srt.disaggregation.prefill import (
PrefillBootstrapQueue,
SchedulerDisaggregationPrefillMixin,
@@ -755,6 +758,24 @@ class Scheduler(
eviction_policy=server_args.radix_eviction_policy,
)
if (
server_args.disaggregation_mode == "decode"
and server_args.disaggregation_decode_enable_offload_kvcache
):
self.decode_offload_manager = DecodeKVCacheOffloadManager(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_group=(
self.attn_tp_cpu_group
if self.server_args.enable_dp_attention
else self.tp_cpu_group
),
tree_cache=self.tree_cache,
server_args=self.server_args,
)
else:
self.decode_offload_manager = None
self.decode_mem_cache_buf_multiplier = (
1
if self.spec_algorithm.is_none()

View File

@@ -250,7 +250,13 @@ class SchedulerOutputProcessorMixin:
req.check_finished()
if req.finished():
self.tree_cache.cache_finished_req(req)
if self.server_args.disaggregation_decode_enable_offload_kvcache:
# Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
if not self.decode_offload_manager.offload_kv_cache(req):
self.tree_cache.cache_finished_req(req)
else:
self.tree_cache.cache_finished_req(req)
req.time_stats.completion_time = time.time()
if req.return_logprob and batch.spec_algorithm.is_none():

View File

@@ -421,6 +421,7 @@ class ServerArgs:
disaggregation_decode_dp: Optional[int] = None
disaggregation_prefill_pp: Optional[int] = 1
disaggregation_ib_device: Optional[str] = None
disaggregation_decode_enable_offload_kvcache: bool = False
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
# FIXME: hack to reduce ITL when decode bs is small
disaggregation_decode_polling_interval: int = 1
@@ -1074,6 +1075,14 @@ class ServerArgs:
"and cannot be used at the same time. Please use only one of them."
)
if (
self.disaggregation_decode_enable_offload_kvcache
and self.disaggregation_mode != "decode"
):
raise ValueError(
"The argument disaggregation-decode-enable-offload-kvcache is only supported for decode side."
)
def _handle_metrics_labels(self):
if (
not self.tokenizer_metrics_custom_labels_header
@@ -2556,6 +2565,11 @@ class ServerArgs:
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
"Default is None, which triggers automatic device detection when mooncake backend is enabled.",
)
parser.add_argument(
"--disaggregation-decode-enable-offload-kvcache",
action="store_true",
help="Enable async KV cache offloading on decode server (PD mode).",
)
parser.add_argument(
"--num-reserved-decode-tokens",
type=int,