[PD-HiCache]: Support Async Offloading KVCache In Decode Side (#10192)
Signed-off-by: Shangming Cai <csmthu@gmail.com> Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
@@ -609,6 +609,7 @@ class DecodeTransferQueue:
|
||||
idx = decode_req.metadata_buffer_index
|
||||
(
|
||||
output_id,
|
||||
cached_tokens,
|
||||
output_token_logprobs_val,
|
||||
output_token_logprobs_idx,
|
||||
output_top_logprobs_val,
|
||||
@@ -617,6 +618,7 @@ class DecodeTransferQueue:
|
||||
) = self.metadata_buffers.get_buf(idx)
|
||||
|
||||
decode_req.req.output_ids.append(output_id[0].item())
|
||||
decode_req.req.cached_tokens = cached_tokens[0].item()
|
||||
if not self.spec_algorithm.is_none():
|
||||
decode_req.req.hidden_states_tensor = output_hidden_states
|
||||
if decode_req.req.return_logprob:
|
||||
@@ -707,12 +709,15 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
elif prepare_mlp_sync_flag:
|
||||
batch, _ = self._prepare_idle_batch_and_run(None)
|
||||
|
||||
if batch is None and (
|
||||
queue_size = (
|
||||
len(self.waiting_queue)
|
||||
+ len(self.disagg_decode_transfer_queue.queue)
|
||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||
== 0
|
||||
):
|
||||
)
|
||||
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
||||
queue_size += len(self.decode_offload_manager.ongoing_offload)
|
||||
|
||||
if batch is None and queue_size == 0:
|
||||
self.self_check_during_idle()
|
||||
|
||||
self.last_batch = batch
|
||||
@@ -781,12 +786,15 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
)
|
||||
self.process_batch_result(tmp_batch, tmp_result)
|
||||
|
||||
if batch is None and (
|
||||
queue_size = (
|
||||
len(self.waiting_queue)
|
||||
+ len(self.disagg_decode_transfer_queue.queue)
|
||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||
== 0
|
||||
):
|
||||
)
|
||||
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
||||
queue_size += len(self.decode_offload_manager.ongoing_offload)
|
||||
|
||||
if batch is None and queue_size == 0:
|
||||
self.self_check_during_idle()
|
||||
|
||||
self.last_batch = batch
|
||||
@@ -905,3 +913,6 @@ class SchedulerDisaggregationDecodeMixin:
|
||||
self.disagg_decode_transfer_queue.pop_transferred()
|
||||
) # the requests which kv has arrived
|
||||
self.waiting_queue.extend(alloc_reqs)
|
||||
|
||||
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
||||
self.decode_offload_manager.check_offload_progress()
|
||||
|
||||
@@ -0,0 +1,185 @@
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
from sglang import ServerArgs
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.memory_pool import (
|
||||
MHATokenToKVPool,
|
||||
MLATokenToKVPool,
|
||||
ReqToTokenPool,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool_host import (
|
||||
MHATokenToKVPoolHost,
|
||||
MLATokenToKVPoolHost,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DecodeKVCacheOffloadManager:
|
||||
"""Manage decode-side KV cache offloading lifecycle and operations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||
tp_group: torch.distributed.ProcessGroup,
|
||||
tree_cache: BasePrefixCache,
|
||||
server_args: ServerArgs,
|
||||
) -> None:
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
self.page_size = server_args.page_size
|
||||
self.server_args = server_args
|
||||
self.request_counter = 0
|
||||
self.tree_cache = tree_cache
|
||||
kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
|
||||
if isinstance(kv_cache, MHATokenToKVPool):
|
||||
self.decode_host_mem_pool = MHATokenToKVPoolHost(
|
||||
kv_cache,
|
||||
server_args.hicache_ratio,
|
||||
server_args.hicache_size,
|
||||
self.page_size,
|
||||
server_args.hicache_mem_layout,
|
||||
)
|
||||
elif isinstance(kv_cache, MLATokenToKVPool):
|
||||
self.decode_host_mem_pool = MLATokenToKVPoolHost(
|
||||
kv_cache,
|
||||
server_args.hicache_ratio,
|
||||
server_args.hicache_size,
|
||||
self.page_size,
|
||||
server_args.hicache_mem_layout,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unsupported KV cache type for decode offload")
|
||||
|
||||
self.tp_group = tp_group
|
||||
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
||||
self.cache_controller = HiCacheController(
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
mem_pool_host=self.decode_host_mem_pool,
|
||||
page_size=self.page_size,
|
||||
tp_group=tp_group,
|
||||
io_backend=server_args.hicache_io_backend,
|
||||
load_cache_event=threading.Event(),
|
||||
storage_backend=server_args.hicache_storage_backend,
|
||||
model_name=server_args.served_model_name,
|
||||
storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
|
||||
)
|
||||
|
||||
self.ongoing_offload = {}
|
||||
self.ongoing_backup = {}
|
||||
logger.info("Enable offload kv cache for decode side")
|
||||
|
||||
def offload_kv_cache(self, req) -> bool:
|
||||
"""Offload a finished request's KV cache to storage."""
|
||||
|
||||
if self.cache_controller is None or self.decode_host_mem_pool is None:
|
||||
return False
|
||||
|
||||
if req.req_pool_idx == -1:
|
||||
return False
|
||||
|
||||
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
|
||||
if token_indices.dim() == 0 or token_indices.numel() == 0:
|
||||
logger.debug(
|
||||
f"Request {req.rid} has invalid token_indices: {token_indices}"
|
||||
)
|
||||
return False
|
||||
|
||||
tokens = req.origin_input_ids + req.output_ids
|
||||
aligned_len = (len(tokens) // self.page_size) * self.page_size
|
||||
if aligned_len == 0:
|
||||
return False
|
||||
|
||||
token_indices = token_indices[:aligned_len]
|
||||
tokens = tokens[:aligned_len]
|
||||
|
||||
# Asynchronously offload KV cache from device to host by cache controller
|
||||
self.request_counter += 1
|
||||
ack_id = self.request_counter
|
||||
host_indices = self.cache_controller.write(
|
||||
device_indices=token_indices.long(),
|
||||
node_id=ack_id,
|
||||
)
|
||||
if host_indices is None:
|
||||
logger.error(f"Not enough host memory for request {req.rid}")
|
||||
return False
|
||||
|
||||
self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
|
||||
return True
|
||||
|
||||
def check_offload_progress(self):
|
||||
"""Check the progress of offload from device to host and backup from host to storage."""
|
||||
cc = self.cache_controller
|
||||
|
||||
qsizes = torch.tensor(
|
||||
[
|
||||
len(cc.ack_write_queue),
|
||||
cc.ack_backup_queue.qsize(),
|
||||
],
|
||||
dtype=torch.int,
|
||||
)
|
||||
if self.tp_world_size > 1:
|
||||
torch.distributed.all_reduce(
|
||||
qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
|
||||
)
|
||||
|
||||
n_write, n_backup = map(int, qsizes.tolist())
|
||||
self._check_offload_progress(n_write)
|
||||
self._check_backup_progress(n_backup)
|
||||
|
||||
def _check_offload_progress(self, finish_count):
|
||||
"""Check the progress of offload from device to host."""
|
||||
while finish_count > 0:
|
||||
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
|
||||
finish_event.synchronize()
|
||||
for ack_id in ack_list:
|
||||
req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)
|
||||
|
||||
# Release device
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
|
||||
# Trigger async backup from host to storage by cache controller
|
||||
self._trigger_backup(req.rid, host_indices, tokens, start_time)
|
||||
finish_count -= 1
|
||||
|
||||
def _check_backup_progress(self, finish_count):
|
||||
"""Check the progress of backup from host to storage."""
|
||||
for _ in range(finish_count):
|
||||
storage_operation = self.cache_controller.ack_backup_queue.get()
|
||||
ack_id = storage_operation.id
|
||||
req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)
|
||||
|
||||
# Release host memory
|
||||
self.decode_host_mem_pool.free(host_indices)
|
||||
|
||||
logger.debug(
|
||||
f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
|
||||
)
|
||||
|
||||
def _trigger_backup(self, req_id, host_indices, tokens, start_time):
|
||||
"""Trigger async backup from host to storage by cache controller."""
|
||||
|
||||
# Generate page hashes and write to storage
|
||||
page_hashes = self._compute_prefix_hash(tokens)
|
||||
ack_id = self.cache_controller.write_storage(
|
||||
host_indices,
|
||||
tokens,
|
||||
hash_value=page_hashes,
|
||||
)
|
||||
self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)
|
||||
|
||||
def _compute_prefix_hash(self, tokens):
|
||||
last_hash = ""
|
||||
page_hashes = []
|
||||
for offset in range(0, len(tokens), self.page_size):
|
||||
page_tokens = tokens[offset : offset + self.page_size]
|
||||
last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
|
||||
page_hashes.append(last_hash)
|
||||
return page_hashes
|
||||
@@ -107,7 +107,9 @@ class MetadataBuffers:
|
||||
# We transfer the metadata of first output token to decode
|
||||
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
|
||||
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
|
||||
|
||||
self.cached_tokens = torch.zeros(
|
||||
(size, 16), dtype=torch.int32, device=device
|
||||
)
|
||||
self.output_token_logprobs_val = torch.zeros(
|
||||
(size, 16), dtype=torch.float32, device=device
|
||||
)
|
||||
@@ -127,6 +129,7 @@ class MetadataBuffers:
|
||||
def get_buf_infos(self):
|
||||
ptrs = [
|
||||
self.output_ids.data_ptr(),
|
||||
self.cached_tokens.data_ptr(),
|
||||
self.output_token_logprobs_val.data_ptr(),
|
||||
self.output_token_logprobs_idx.data_ptr(),
|
||||
self.output_top_logprobs_val.data_ptr(),
|
||||
@@ -135,6 +138,7 @@ class MetadataBuffers:
|
||||
]
|
||||
data_lens = [
|
||||
self.output_ids.nbytes,
|
||||
self.cached_tokens.nbytes,
|
||||
self.output_token_logprobs_val.nbytes,
|
||||
self.output_token_logprobs_idx.nbytes,
|
||||
self.output_top_logprobs_val.nbytes,
|
||||
@@ -143,6 +147,7 @@ class MetadataBuffers:
|
||||
]
|
||||
item_lens = [
|
||||
self.output_ids[0].nbytes,
|
||||
self.cached_tokens[0].nbytes,
|
||||
self.output_token_logprobs_val[0].nbytes,
|
||||
self.output_token_logprobs_idx[0].nbytes,
|
||||
self.output_top_logprobs_val[0].nbytes,
|
||||
@@ -154,6 +159,7 @@ class MetadataBuffers:
|
||||
def get_buf(self, idx: int):
|
||||
return (
|
||||
self.output_ids[idx],
|
||||
self.cached_tokens[idx],
|
||||
self.output_token_logprobs_val[idx],
|
||||
self.output_token_logprobs_idx[idx],
|
||||
self.output_top_logprobs_val[idx],
|
||||
@@ -164,6 +170,7 @@ class MetadataBuffers:
|
||||
def set_buf(self, req: Req):
|
||||
|
||||
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
|
||||
self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
|
||||
if req.return_logprob:
|
||||
if req.output_token_logprobs_val: # not none or empty list
|
||||
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
|
||||
|
||||
@@ -44,6 +44,9 @@ from sglang.srt.disaggregation.decode import (
|
||||
DecodeTransferQueue,
|
||||
SchedulerDisaggregationDecodeMixin,
|
||||
)
|
||||
from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
|
||||
DecodeKVCacheOffloadManager,
|
||||
)
|
||||
from sglang.srt.disaggregation.prefill import (
|
||||
PrefillBootstrapQueue,
|
||||
SchedulerDisaggregationPrefillMixin,
|
||||
@@ -755,6 +758,24 @@ class Scheduler(
|
||||
eviction_policy=server_args.radix_eviction_policy,
|
||||
)
|
||||
|
||||
if (
|
||||
server_args.disaggregation_mode == "decode"
|
||||
and server_args.disaggregation_decode_enable_offload_kvcache
|
||||
):
|
||||
self.decode_offload_manager = DecodeKVCacheOffloadManager(
|
||||
req_to_token_pool=self.req_to_token_pool,
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
tp_group=(
|
||||
self.attn_tp_cpu_group
|
||||
if self.server_args.enable_dp_attention
|
||||
else self.tp_cpu_group
|
||||
),
|
||||
tree_cache=self.tree_cache,
|
||||
server_args=self.server_args,
|
||||
)
|
||||
else:
|
||||
self.decode_offload_manager = None
|
||||
|
||||
self.decode_mem_cache_buf_multiplier = (
|
||||
1
|
||||
if self.spec_algorithm.is_none()
|
||||
|
||||
@@ -250,7 +250,13 @@ class SchedulerOutputProcessorMixin:
|
||||
|
||||
req.check_finished()
|
||||
if req.finished():
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
||||
# Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
|
||||
if not self.decode_offload_manager.offload_kv_cache(req):
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
else:
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
|
||||
req.time_stats.completion_time = time.time()
|
||||
|
||||
if req.return_logprob and batch.spec_algorithm.is_none():
|
||||
|
||||
@@ -421,6 +421,7 @@ class ServerArgs:
|
||||
disaggregation_decode_dp: Optional[int] = None
|
||||
disaggregation_prefill_pp: Optional[int] = 1
|
||||
disaggregation_ib_device: Optional[str] = None
|
||||
disaggregation_decode_enable_offload_kvcache: bool = False
|
||||
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
||||
# FIXME: hack to reduce ITL when decode bs is small
|
||||
disaggregation_decode_polling_interval: int = 1
|
||||
@@ -1074,6 +1075,14 @@ class ServerArgs:
|
||||
"and cannot be used at the same time. Please use only one of them."
|
||||
)
|
||||
|
||||
if (
|
||||
self.disaggregation_decode_enable_offload_kvcache
|
||||
and self.disaggregation_mode != "decode"
|
||||
):
|
||||
raise ValueError(
|
||||
"The argument disaggregation-decode-enable-offload-kvcache is only supported for decode side."
|
||||
)
|
||||
|
||||
def _handle_metrics_labels(self):
|
||||
if (
|
||||
not self.tokenizer_metrics_custom_labels_header
|
||||
@@ -2556,6 +2565,11 @@ class ServerArgs:
|
||||
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
|
||||
"Default is None, which triggers automatic device detection when mooncake backend is enabled.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disaggregation-decode-enable-offload-kvcache",
|
||||
action="store_true",
|
||||
help="Enable async KV cache offloading on decode server (PD mode).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-reserved-decode-tokens",
|
||||
type=int,
|
||||
|
||||
Reference in New Issue
Block a user