From 7ec5b4e89c16dc91bfee64099fa8d7ad7bb538a2 Mon Sep 17 00:00:00 2001
From: hzh0425
Date: Fri, 26 Sep 2025 14:20:49 +0800
Subject: [PATCH] [PD-HiCache]: Support Async Offloading KVCache In Decode Side (#10192)

Signed-off-by: Shangming Cai
Co-authored-by: Shangming Cai
---
 python/sglang/srt/disaggregation/decode.py    |  23 +-
 .../decode_kvcache_offload_manager.py         | 185 ++++++++++++
 python/sglang/srt/disaggregation/utils.py     |   9 +-
 python/sglang/srt/managers/scheduler.py       |  21 ++
 .../scheduler_output_processor_mixin.py       |   8 +-
 python/sglang/srt/server_args.py              |  14 +
 .../hicache/test_disaggregation_hicache.py    | 271 ++++++++++++++++++
 7 files changed, 523 insertions(+), 8 deletions(-)
 create mode 100644 python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py
 create mode 100644 test/srt/hicache/test_disaggregation_hicache.py

diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py
index f4d7e8f7f..32128f480 100644
--- a/python/sglang/srt/disaggregation/decode.py
+++ b/python/sglang/srt/disaggregation/decode.py
@@ -609,6 +609,7 @@ class DecodeTransferQueue:
                 idx = decode_req.metadata_buffer_index
                 (
                     output_id,
+                    cached_tokens,
                     output_token_logprobs_val,
                     output_token_logprobs_idx,
                     output_top_logprobs_val,
@@ -617,6 +618,7 @@
                 ) = self.metadata_buffers.get_buf(idx)

                 decode_req.req.output_ids.append(output_id[0].item())
+                decode_req.req.cached_tokens = cached_tokens[0].item()
                 if not self.spec_algorithm.is_none():
                     decode_req.req.hidden_states_tensor = output_hidden_states
                 if decode_req.req.return_logprob:
@@ -707,12 +709,15 @@ class SchedulerDisaggregationDecodeMixin:
             elif prepare_mlp_sync_flag:
                 batch, _ = self._prepare_idle_batch_and_run(None)

-            if batch is None and (
+            queue_size = (
                 len(self.waiting_queue)
                 + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
-                == 0
-            ):
+            )
+            if self.server_args.disaggregation_decode_enable_offload_kvcache:
+                queue_size += len(self.decode_offload_manager.ongoing_offload)
+
+            if batch is None and queue_size == 0:
                 self.self_check_during_idle()

             self.last_batch = batch
@@ -781,12 +786,15 @@ class SchedulerDisaggregationDecodeMixin:
                 )
                 self.process_batch_result(tmp_batch, tmp_result)

-            if batch is None and (
+            queue_size = (
                 len(self.waiting_queue)
                 + len(self.disagg_decode_transfer_queue.queue)
                 + len(self.disagg_decode_prealloc_queue.queue)
-                == 0
-            ):
+            )
+            if self.server_args.disaggregation_decode_enable_offload_kvcache:
+                queue_size += len(self.decode_offload_manager.ongoing_offload)
+
+            if batch is None and queue_size == 0:
                 self.self_check_during_idle()

             self.last_batch = batch
@@ -905,3 +913,6 @@ class SchedulerDisaggregationDecodeMixin:
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
         self.waiting_queue.extend(alloc_reqs)
+
+        if self.server_args.disaggregation_decode_enable_offload_kvcache:
+            self.decode_offload_manager.check_offload_progress()
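The idle-check change above exists because a request whose KV cache is still being copied to host memory has already left all three scheduler queues, yet it still holds device memory and a pending completion callback. The following self-contained sketch (toy names, not SGLang APIs) illustrates why in-flight offloads must be counted before the scheduler declares itself idle:

    # One request is mid-offload: the three queues are empty, but the
    # scheduler must not run its idle self-check yet.
    waiting_queue, transfer_queue, prealloc_queue = [], [], []
    ongoing_offload = {1: "req-a"}  # ack_id -> request still copying KV to host

    queue_size = len(waiting_queue) + len(transfer_queue) + len(prealloc_queue)
    offload_enabled = True
    if offload_enabled:
        queue_size += len(ongoing_offload)

    batch = None
    if batch is None and queue_size == 0:
        print("idle self-check")  # skipped while an offload is in flight
    else:
        print("still busy: an in-flight offload holds device memory")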
diff --git a/python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py b/python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py
new file mode 100644
index 000000000..c74b4938e
--- /dev/null
+++ b/python/sglang/srt/disaggregation/decode_kvcache_offload_manager.py
@@ -0,0 +1,185 @@
+import logging
+import threading
+import time
+
+import torch
+
+from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MLATokenToKVPool,
+    ReqToTokenPool,
+)
+from sglang.srt.mem_cache.memory_pool_host import (
+    MHATokenToKVPoolHost,
+    MLATokenToKVPoolHost,
+)
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+
+class DecodeKVCacheOffloadManager:
+    """Manage decode-side KV cache offloading lifecycle and operations."""
+
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
+        tp_group: torch.distributed.ProcessGroup,
+        tree_cache: BasePrefixCache,
+        server_args: ServerArgs,
+    ) -> None:
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.page_size = server_args.page_size
+        self.server_args = server_args
+        self.request_counter = 0
+        self.tree_cache = tree_cache
+        kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
+        if isinstance(kv_cache, MHATokenToKVPool):
+            self.decode_host_mem_pool = MHATokenToKVPoolHost(
+                kv_cache,
+                server_args.hicache_ratio,
+                server_args.hicache_size,
+                self.page_size,
+                server_args.hicache_mem_layout,
+            )
+        elif isinstance(kv_cache, MLATokenToKVPool):
+            self.decode_host_mem_pool = MLATokenToKVPoolHost(
+                kv_cache,
+                server_args.hicache_ratio,
+                server_args.hicache_size,
+                self.page_size,
+                server_args.hicache_mem_layout,
+            )
+        else:
+            raise ValueError("Unsupported KV cache type for decode offload")
+
+        self.tp_group = tp_group
+        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
+        self.cache_controller = HiCacheController(
+            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+            mem_pool_host=self.decode_host_mem_pool,
+            page_size=self.page_size,
+            tp_group=tp_group,
+            io_backend=server_args.hicache_io_backend,
+            load_cache_event=threading.Event(),
+            storage_backend=server_args.hicache_storage_backend,
+            model_name=server_args.served_model_name,
+            storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
+        )
+
+        self.ongoing_offload = {}
+        self.ongoing_backup = {}
+        logger.info("Enabled KV cache offloading for the decode side")
+
+    def offload_kv_cache(self, req) -> bool:
+        """Offload a finished request's KV cache to storage."""
+
+        if self.cache_controller is None or self.decode_host_mem_pool is None:
+            return False
+
+        if req.req_pool_idx == -1:
+            return False
+
+        token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
+        if token_indices.dim() == 0 or token_indices.numel() == 0:
+            logger.debug(
+                f"Request {req.rid} has invalid token_indices: {token_indices}"
+            )
+            return False
+
+        tokens = req.origin_input_ids + req.output_ids
+        aligned_len = (len(tokens) // self.page_size) * self.page_size
+        if aligned_len == 0:
+            return False
+
+        token_indices = token_indices[:aligned_len]
+        tokens = tokens[:aligned_len]
+
+        # Asynchronously offload KV cache from device to host via the cache controller
+        self.request_counter += 1
+        ack_id = self.request_counter
+        host_indices = self.cache_controller.write(
+            device_indices=token_indices.long(),
+            node_id=ack_id,
+        )
+        if host_indices is None:
+            logger.error(f"Not enough host memory for request {req.rid}")
+            return False
+
+        self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
+        return True
+
+    def check_offload_progress(self):
+        """Check the progress of offload from device to host and backup from host to storage."""
+        cc = self.cache_controller
+
+        qsizes = torch.tensor(
+            [
+                len(cc.ack_write_queue),
+                cc.ack_backup_queue.qsize(),
+            ],
+            dtype=torch.int,
+        )
+        if self.tp_world_size > 1:
+            torch.distributed.all_reduce(
+                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
+            )
+
+        n_write, n_backup = map(int, qsizes.tolist())
+        self._check_offload_progress(n_write)
+        self._check_backup_progress(n_backup)
+
+    def _check_offload_progress(self, finish_count):
+        """Check the progress of offload from device to host."""
+        while finish_count > 0:
+            _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
+            finish_event.synchronize()
+            for ack_id in ack_list:
+                req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)
+
+                # Release device
+                self.tree_cache.cache_finished_req(req)
+
+                # Trigger async backup from host to storage by cache controller
+                self._trigger_backup(req.rid, host_indices, tokens, start_time)
+            finish_count -= 1
+
+    def _check_backup_progress(self, finish_count):
+        """Check the progress of backup from host to storage."""
+        for _ in range(finish_count):
+            storage_operation = self.cache_controller.ack_backup_queue.get()
+            ack_id = storage_operation.id
+            req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)
+
+            # Release host memory
+            self.decode_host_mem_pool.free(host_indices)
+
+            logger.debug(
+                f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
+            )
+
+    def _trigger_backup(self, req_id, host_indices, tokens, start_time):
+        """Trigger async backup from host to storage by cache controller."""
+
+        # Generate page hashes and write to storage
+        page_hashes = self._compute_prefix_hash(tokens)
+        ack_id = self.cache_controller.write_storage(
+            host_indices,
+            tokens,
+            hash_value=page_hashes,
+        )
+        self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)
+
+    def _compute_prefix_hash(self, tokens):
+        last_hash = ""
+        page_hashes = []
+        for offset in range(0, len(tokens), self.page_size):
+            page_tokens = tokens[offset : offset + self.page_size]
+            last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
+            page_hashes.append(last_hash)
+        return page_hashes
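offload_kv_cache above truncates a finished request to whole pages before writing, so a trailing partial page is never offloaded. An illustrative check of that alignment arithmetic (plain Python, no SGLang imports):

    page_size = 64  # matches --page-size used in the tests below
    for num_tokens in (0, 63, 64, 130):
        aligned_len = (num_tokens // page_size) * page_size
        print(num_tokens, "->", aligned_len)
    # 0 -> 0 and 63 -> 0: requests shorter than one page are skipped outright
    # (offload_kv_cache returns False); 130 -> 128: the trailing 2-token
    # partial page is dropped from the offload.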
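_compute_prefix_hash chains each page's hash to the previous one, so a page's storage key encodes its whole prefix and identical pages at different positions get distinct keys. A sketch of that chaining, with hashlib.sha256 standing in for cache_controller.get_hash_str (whose actual hashing scheme may differ):

    import hashlib

    def get_hash_str(page_tokens, prior_hash=""):
        h = hashlib.sha256()
        h.update(prior_hash.encode())
        h.update(",".join(map(str, page_tokens)).encode())
        return h.hexdigest()

    page_size = 4
    tokens = [1, 2, 3, 4] * 2  # two identical pages
    last_hash, page_hashes = "", []
    for offset in range(0, len(tokens), page_size):
        last_hash = get_hash_str(tokens[offset : offset + page_size], last_hash)
        page_hashes.append(last_hash)

    print(page_hashes[0] != page_hashes[1])  # True: keys are position-dependent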
diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py
index 43770e3e2..1ea1cc6c6 100644
--- a/python/sglang/srt/disaggregation/utils.py
+++ b/python/sglang/srt/disaggregation/utils.py
@@ -107,7 +107,9 @@ class MetadataBuffers:
         # We transfer the metadata of first output token to decode
         # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
         self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
-
+        self.cached_tokens = torch.zeros(
+            (size, 16), dtype=torch.int32, device=device
+        )
         self.output_token_logprobs_val = torch.zeros(
             (size, 16), dtype=torch.float32, device=device
         )
@@ -127,6 +129,7 @@ class MetadataBuffers:
     def get_buf_infos(self):
         ptrs = [
             self.output_ids.data_ptr(),
+            self.cached_tokens.data_ptr(),
             self.output_token_logprobs_val.data_ptr(),
             self.output_token_logprobs_idx.data_ptr(),
             self.output_top_logprobs_val.data_ptr(),
@@ -135,6 +138,7 @@ class MetadataBuffers:
         ]
         data_lens = [
             self.output_ids.nbytes,
+            self.cached_tokens.nbytes,
             self.output_token_logprobs_val.nbytes,
             self.output_token_logprobs_idx.nbytes,
             self.output_top_logprobs_val.nbytes,
@@ -143,6 +147,7 @@ class MetadataBuffers:
         ]
         item_lens = [
             self.output_ids[0].nbytes,
+            self.cached_tokens[0].nbytes,
             self.output_token_logprobs_val[0].nbytes,
             self.output_token_logprobs_idx[0].nbytes,
             self.output_top_logprobs_val[0].nbytes,
@@ -154,6 +159,7 @@ class MetadataBuffers:
     def get_buf(self, idx: int):
         return (
             self.output_ids[idx],
+            self.cached_tokens[idx],
             self.output_token_logprobs_val[idx],
             self.output_token_logprobs_idx[idx],
             self.output_top_logprobs_val[idx],
@@ -164,6 +170,7 @@ class MetadataBuffers:
     def set_buf(self, req: Req):
         self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
         if req.return_logprob:
             if req.output_token_logprobs_val:  # not none or empty list
                 self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
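The new cached_tokens buffer copies the (size, 16) int32 shape of output_ids, following the padding comment above: the minimal RDMA transfer is 64 bytes, so each per-request slice holds 16 elements even though only element [0] carries data. A quick size check (requires only torch):

    import torch

    cached_tokens = torch.zeros((8, 16), dtype=torch.int32)
    print(cached_tokens[0].nbytes)  # 64: 16 int32 values = 64 bytes per request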
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 31294749d..4404e1fc6 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -44,6 +44,9 @@ from sglang.srt.disaggregation.decode import (
     DecodeTransferQueue,
     SchedulerDisaggregationDecodeMixin,
 )
+from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
+    DecodeKVCacheOffloadManager,
+)
 from sglang.srt.disaggregation.prefill import (
     PrefillBootstrapQueue,
     SchedulerDisaggregationPrefillMixin,
@@ -755,6 +758,24 @@ class Scheduler(
             eviction_policy=server_args.radix_eviction_policy,
         )

+        if (
+            server_args.disaggregation_mode == "decode"
+            and server_args.disaggregation_decode_enable_offload_kvcache
+        ):
+            self.decode_offload_manager = DecodeKVCacheOffloadManager(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                tp_group=(
+                    self.attn_tp_cpu_group
+                    if self.server_args.enable_dp_attention
+                    else self.tp_cpu_group
+                ),
+                tree_cache=self.tree_cache,
+                server_args=self.server_args,
+            )
+        else:
+            self.decode_offload_manager = None
+
         self.decode_mem_cache_buf_multiplier = (
             1
             if self.spec_algorithm.is_none()
diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
index aa060af8a..5d8545dac 100644
--- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py
+++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py
@@ -250,7 +250,13 @@ class SchedulerOutputProcessorMixin:
             req.check_finished()

             if req.finished():
-                self.tree_cache.cache_finished_req(req)
+                if self.server_args.disaggregation_decode_enable_offload_kvcache:
+                    # Asynchronously offload the KV cache; cache_finished_req is called once the device-to-host transfer completes
+                    if not self.decode_offload_manager.offload_kv_cache(req):
+                        self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_finished_req(req)
+
                 req.time_stats.completion_time = time.time()

             if req.return_logprob and batch.spec_algorithm.is_none():
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 020f71f11..6af1845c1 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -421,6 +421,7 @@ class ServerArgs:
     disaggregation_decode_dp: Optional[int] = None
     disaggregation_prefill_pp: Optional[int] = 1
     disaggregation_ib_device: Optional[str] = None
+    disaggregation_decode_enable_offload_kvcache: bool = False
     num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
     # FIXME: hack to reduce ITL when decode bs is small
     disaggregation_decode_polling_interval: int = 1
@@ -1074,6 +1075,14 @@ class ServerArgs:
                 "and cannot be used at the same time. Please use only one of them."
             )

+        if (
+            self.disaggregation_decode_enable_offload_kvcache
+            and self.disaggregation_mode != "decode"
+        ):
+            raise ValueError(
+                "The argument disaggregation-decode-enable-offload-kvcache is only supported on the decode side."
+            )
+
     def _handle_metrics_labels(self):
         if (
             not self.tokenizer_metrics_custom_labels_header
@@ -2556,6 +2565,11 @@ class ServerArgs:
             "or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
             "Default is None, which triggers automatic device detection when mooncake backend is enabled.",
         )
+        parser.add_argument(
+            "--disaggregation-decode-enable-offload-kvcache",
+            action="store_true",
+            help="Enable async KV cache offloading on the decode server (PD disaggregation mode).",
+        )
         parser.add_argument(
             "--num-reserved-decode-tokens",
             type=int,
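The new flag is validated against the disaggregation mode so that a misconfiguration fails at startup rather than at runtime. A distilled, dependency-free version of that check (not the real ServerArgs class) behaves like this:

    def validate(disaggregation_mode: str, enable_offload: bool) -> None:
        if enable_offload and disaggregation_mode != "decode":
            raise ValueError(
                "The argument disaggregation-decode-enable-offload-kvcache is "
                "only supported on the decode side."
            )

    validate("decode", True)  # accepted on a decode server
    try:
        validate("prefill", True)  # rejected: the offload flag is decode-only
    except ValueError as exc:
        print(exc)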
diff --git a/test/srt/hicache/test_disaggregation_hicache.py b/test/srt/hicache/test_disaggregation_hicache.py
new file mode 100644
index 000000000..1b4015054
--- /dev/null
+++ b/test/srt/hicache/test_disaggregation_hicache.py
@@ -0,0 +1,271 @@
+import os
+import random
+import tempfile
+import time
+import unittest
+from typing import Dict
+from urllib.parse import urlparse
+
+import requests
+
+from sglang.bench_serving import get_tokenizer
+from sglang.test.test_disaggregation_utils import TestDisaggregationBase
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_pd_server,
+)
+
+
+class DisaggregationHiCacheBase(TestDisaggregationBase):
+    """Base class for disaggregation tests with HiCache."""
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.base_host = parsed_url.hostname
+        base_port = str(parsed_url.port)
+        cls.lb_port = base_port
+        cls.prefill_port = f"{int(base_port) + 100}"
+        cls.decode_port = f"{int(base_port) + 200}"
+        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
+        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
+
+        cls.tokenizer = get_tokenizer(cls.model)
+        cls.temp_dir = tempfile.mkdtemp()
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both servers are ready
+        cls.wait_server_ready(cls.prefill_url + "/health")
+        cls.wait_server_ready(cls.decode_url + "/health")
+
+        cls.launch_lb()
+
+    @classmethod
+    def start_prefill(cls):
+        # Prefill with HiCache enabled
+        prefill_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "prefill",
+            "--tp-size",
+            "1",
+            "--page-size",
+            "64",
+            "--enable-hierarchical-cache",
+            "--hicache-ratio",
+            "1.2",
+            "--hicache-size",
+            "0",
+            "--hicache-write-policy",
+            "write_through",
+            "--hicache-storage-backend",
+            "file",
+            "--hicache-storage-prefetch-policy",
+            "wait_complete",
+            "--mem-fraction-static",
+            "0.8",
+            "--disaggregation-ib-device",
+            "mlx5_roce0",
+            "--disaggregation-transfer-backend",
+            "mooncake",
+        ]
+        env = {
+            **os.environ,
+            "SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR": cls.temp_dir,
+        }
+        cls.process_prefill = popen_launch_pd_server(
+            cls.model,
+            cls.prefill_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=prefill_args,
+            env=env,
+        )
+
+    @classmethod
+    def start_decode(cls):
+        pass
+
+    def gen_prompt(self, token_num: int) -> str:
+        all_available_tokens = list(self.tokenizer.get_vocab().values())
+        selected_tokens = random.choices(all_available_tokens, k=token_num)
+        return self.tokenizer.decode(selected_tokens)
+
+    def send_request(
+        self, prompt: str, max_tokens: int = 100, temperature: float = 0.0
+    ) -> Dict:
+        """Send a generate request and return the response."""
+        response = requests.post(
+            f"{self.lb_url}/generate",
+            json={
+                "text": prompt,
+                "sampling_params": {
+                    "temperature": temperature,
+                    "max_new_tokens": max_tokens,
+                    "ignore_eos": True,
+                },
+            },
+            timeout=60,
+        )
+
+        self.assertEqual(
+            response.status_code,
+            200,
+            f"Request failed: {response.status_code} - {response.text}",
+        )
+        return response.json()
+
+    def trigger_offloading_and_flush(self):
+        """Helper method to trigger offloading and flush the device cache."""
+        # Trigger offloading
+        self.send_request(self.gen_prompt(1), max_tokens=150)
+
+        # Flush device cache to force remote storage access
+        time.sleep(2)
+        requests.post(self.prefill_url + "/flush_cache")
+
+
+class TestDisaggregationPrefillWithHiCache(DisaggregationHiCacheBase):
+    """Test disaggregation with HiCache enabled only on the prefill side."""
+
+    @classmethod
+    def start_decode(cls):
+        # Decode without HiCache offload
+        decode_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "decode",
+            "--tp-size",
+            "1",
+            "--page-size",
+            "64",
+            "--mem-fraction-static",
+            "0.8",
+            "--base-gpu-id",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce0",
+            "--disaggregation-transfer-backend",
+            "mooncake",
+        ]
+        env = {
+            **os.environ,
+            "SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR": cls.temp_dir,
+        }
+        cls.process_decode = popen_launch_pd_server(
+            cls.model,
+            cls.decode_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=decode_args,
+            env=env,
+        )
+
+    def test_prefill_cache_hit(self):
+        """Test that the prefill cache works with repeated queries."""
+
+        repeated_prompt = self.gen_prompt(800)
+
+        # First request - should miss the cache
+        self.send_request(repeated_prompt, max_tokens=100)
+
+        # Offload, then flush the device cache
+        self.trigger_offloading_and_flush()
+
+        # Second request - should hit the cache (faster)
+        response2 = self.send_request(repeated_prompt, max_tokens=100)
+
+        # Assert cached token count
+        self.assertGreater(response2["meta_info"]["cached_tokens"], 700)
+
+
+class TestDisaggregationDecodeWithHiCache(DisaggregationHiCacheBase):
+    """Test disaggregation with HiCache enabled on both prefill and decode sides."""
+
+    @classmethod
+    def start_decode(cls):
+        # Decode with HiCache offload enabled
+        decode_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "decode",
+            "--tp-size",
+            "1",
+            "--page-size",
+            "64",
+            "--mem-fraction-static",
+            "0.8",
+            "--base-gpu-id",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce0",
+            "--disaggregation-transfer-backend",
+            "mooncake",
+            "--disaggregation-decode-enable-offload-kvcache",
+            "--hicache-ratio",
+            "1.2",
+            "--hicache-size",
+            "0",
+            "--hicache-storage-backend",
+            "file",
+            "--hicache-storage-prefetch-policy",
+            "wait_complete",
+        ]
+        env = {
+            **os.environ,
+            "SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR": cls.temp_dir,
+        }
+        cls.process_decode = popen_launch_pd_server(
+            cls.model,
+            cls.decode_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=decode_args,
+            env=env,
+        )
+
+    def test_multi_turn_conversation_cache(self):
+        """Test a multi-turn conversation scenario with cache hit improvement."""
+
+        print("=== Multi-turn Conversation Cache Test ===")
+
+        # Turn 1
+        initial_prompt = self.gen_prompt(300)
+
+        response1 = self.send_request(initial_prompt, max_tokens=200, temperature=0.1)
+        current_context = initial_prompt + response1["text"]
+
+        # Turns 2-4: continue generation based on the previous context
+        previous_cached_tokens = 0
+
+        for turn in range(2, 5):
+            print(f"\nTurn {turn}: Continuing from previous context")
+
+            response = self.send_request(
+                current_context, max_tokens=200, temperature=0.1
+            )
+            cached_tokens = response["meta_info"]["cached_tokens"]
+
+            print(f"Turn {turn} cached tokens: {cached_tokens}")
+            print(f"Improvement: {cached_tokens - previous_cached_tokens} tokens")
+
+            # Assert cache improvement
+            self.assertGreater(
+                cached_tokens,
+                previous_cached_tokens,
+                f"Turn {turn} should have more cached tokens than turn {turn - 1}",
+            )
+
+            # Update context and cached tokens for the next iteration
+            current_context += response["text"]
+            previous_cached_tokens = cached_tokens
+
+            # Flush prefill cache
+            self.trigger_offloading_and_flush()
+
+
+if __name__ == "__main__":
+    unittest.main()