[PD-HiCache]: Support Async Offloading KVCache In Decode Side (#10192)
Signed-off-by: Shangming Cai <csmthu@gmail.com> Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
@@ -609,6 +609,7 @@ class DecodeTransferQueue:
|
|||||||
idx = decode_req.metadata_buffer_index
|
idx = decode_req.metadata_buffer_index
|
||||||
(
|
(
|
||||||
output_id,
|
output_id,
|
||||||
|
cached_tokens,
|
||||||
output_token_logprobs_val,
|
output_token_logprobs_val,
|
||||||
output_token_logprobs_idx,
|
output_token_logprobs_idx,
|
||||||
output_top_logprobs_val,
|
output_top_logprobs_val,
|
||||||
@@ -617,6 +618,7 @@ class DecodeTransferQueue:
|
|||||||
) = self.metadata_buffers.get_buf(idx)
|
) = self.metadata_buffers.get_buf(idx)
|
||||||
|
|
||||||
decode_req.req.output_ids.append(output_id[0].item())
|
decode_req.req.output_ids.append(output_id[0].item())
|
||||||
|
decode_req.req.cached_tokens = cached_tokens[0].item()
|
||||||
if not self.spec_algorithm.is_none():
|
if not self.spec_algorithm.is_none():
|
||||||
decode_req.req.hidden_states_tensor = output_hidden_states
|
decode_req.req.hidden_states_tensor = output_hidden_states
|
||||||
if decode_req.req.return_logprob:
|
if decode_req.req.return_logprob:
|
||||||
@@ -707,12 +709,15 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
elif prepare_mlp_sync_flag:
|
elif prepare_mlp_sync_flag:
|
||||||
batch, _ = self._prepare_idle_batch_and_run(None)
|
batch, _ = self._prepare_idle_batch_and_run(None)
|
||||||
|
|
||||||
if batch is None and (
|
queue_size = (
|
||||||
len(self.waiting_queue)
|
len(self.waiting_queue)
|
||||||
+ len(self.disagg_decode_transfer_queue.queue)
|
+ len(self.disagg_decode_transfer_queue.queue)
|
||||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||||
== 0
|
)
|
||||||
):
|
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
||||||
|
queue_size += len(self.decode_offload_manager.ongoing_offload)
|
||||||
|
|
||||||
|
if batch is None and queue_size == 0:
|
||||||
self.self_check_during_idle()
|
self.self_check_during_idle()
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
@@ -781,12 +786,15 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
)
|
)
|
||||||
self.process_batch_result(tmp_batch, tmp_result)
|
self.process_batch_result(tmp_batch, tmp_result)
|
||||||
|
|
||||||
if batch is None and (
|
queue_size = (
|
||||||
len(self.waiting_queue)
|
len(self.waiting_queue)
|
||||||
+ len(self.disagg_decode_transfer_queue.queue)
|
+ len(self.disagg_decode_transfer_queue.queue)
|
||||||
+ len(self.disagg_decode_prealloc_queue.queue)
|
+ len(self.disagg_decode_prealloc_queue.queue)
|
||||||
== 0
|
)
|
||||||
):
|
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
||||||
|
queue_size += len(self.decode_offload_manager.ongoing_offload)
|
||||||
|
|
||||||
|
if batch is None and queue_size == 0:
|
||||||
self.self_check_during_idle()
|
self.self_check_during_idle()
|
||||||
|
|
||||||
self.last_batch = batch
|
self.last_batch = batch
|
||||||
@@ -905,3 +913,6 @@ class SchedulerDisaggregationDecodeMixin:
|
|||||||
self.disagg_decode_transfer_queue.pop_transferred()
|
self.disagg_decode_transfer_queue.pop_transferred()
|
||||||
) # the requests which kv has arrived
|
) # the requests which kv has arrived
|
||||||
self.waiting_queue.extend(alloc_reqs)
|
self.waiting_queue.extend(alloc_reqs)
|
||||||
|
|
||||||
|
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
||||||
|
self.decode_offload_manager.check_offload_progress()
|
||||||
|
|||||||
@@ -0,0 +1,185 @@
|
|||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang import ServerArgs
|
||||||
|
from sglang.srt.managers.cache_controller import HiCacheController
|
||||||
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||||
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
|
MHATokenToKVPool,
|
||||||
|
MLATokenToKVPool,
|
||||||
|
ReqToTokenPool,
|
||||||
|
)
|
||||||
|
from sglang.srt.mem_cache.memory_pool_host import (
|
||||||
|
MHATokenToKVPoolHost,
|
||||||
|
MLATokenToKVPoolHost,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DecodeKVCacheOffloadManager:
|
||||||
|
"""Manage decode-side KV cache offloading lifecycle and operations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
req_to_token_pool: ReqToTokenPool,
|
||||||
|
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
|
||||||
|
tp_group: torch.distributed.ProcessGroup,
|
||||||
|
tree_cache: BasePrefixCache,
|
||||||
|
server_args: ServerArgs,
|
||||||
|
) -> None:
|
||||||
|
self.req_to_token_pool = req_to_token_pool
|
||||||
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||||
|
self.page_size = server_args.page_size
|
||||||
|
self.server_args = server_args
|
||||||
|
self.request_counter = 0
|
||||||
|
self.tree_cache = tree_cache
|
||||||
|
kv_cache = self.token_to_kv_pool_allocator.get_kvcache()
|
||||||
|
if isinstance(kv_cache, MHATokenToKVPool):
|
||||||
|
self.decode_host_mem_pool = MHATokenToKVPoolHost(
|
||||||
|
kv_cache,
|
||||||
|
server_args.hicache_ratio,
|
||||||
|
server_args.hicache_size,
|
||||||
|
self.page_size,
|
||||||
|
server_args.hicache_mem_layout,
|
||||||
|
)
|
||||||
|
elif isinstance(kv_cache, MLATokenToKVPool):
|
||||||
|
self.decode_host_mem_pool = MLATokenToKVPoolHost(
|
||||||
|
kv_cache,
|
||||||
|
server_args.hicache_ratio,
|
||||||
|
server_args.hicache_size,
|
||||||
|
self.page_size,
|
||||||
|
server_args.hicache_mem_layout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unsupported KV cache type for decode offload")
|
||||||
|
|
||||||
|
self.tp_group = tp_group
|
||||||
|
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
||||||
|
self.cache_controller = HiCacheController(
|
||||||
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
|
mem_pool_host=self.decode_host_mem_pool,
|
||||||
|
page_size=self.page_size,
|
||||||
|
tp_group=tp_group,
|
||||||
|
io_backend=server_args.hicache_io_backend,
|
||||||
|
load_cache_event=threading.Event(),
|
||||||
|
storage_backend=server_args.hicache_storage_backend,
|
||||||
|
model_name=server_args.served_model_name,
|
||||||
|
storage_backend_extra_config=server_args.hicache_storage_backend_extra_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ongoing_offload = {}
|
||||||
|
self.ongoing_backup = {}
|
||||||
|
logger.info("Enable offload kv cache for decode side")
|
||||||
|
|
||||||
|
def offload_kv_cache(self, req) -> bool:
|
||||||
|
"""Offload a finished request's KV cache to storage."""
|
||||||
|
|
||||||
|
if self.cache_controller is None or self.decode_host_mem_pool is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if req.req_pool_idx == -1:
|
||||||
|
return False
|
||||||
|
|
||||||
|
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
|
||||||
|
if token_indices.dim() == 0 or token_indices.numel() == 0:
|
||||||
|
logger.debug(
|
||||||
|
f"Request {req.rid} has invalid token_indices: {token_indices}"
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
tokens = req.origin_input_ids + req.output_ids
|
||||||
|
aligned_len = (len(tokens) // self.page_size) * self.page_size
|
||||||
|
if aligned_len == 0:
|
||||||
|
return False
|
||||||
|
|
||||||
|
token_indices = token_indices[:aligned_len]
|
||||||
|
tokens = tokens[:aligned_len]
|
||||||
|
|
||||||
|
# Asynchronously offload KV cache from device to host by cache controller
|
||||||
|
self.request_counter += 1
|
||||||
|
ack_id = self.request_counter
|
||||||
|
host_indices = self.cache_controller.write(
|
||||||
|
device_indices=token_indices.long(),
|
||||||
|
node_id=ack_id,
|
||||||
|
)
|
||||||
|
if host_indices is None:
|
||||||
|
logger.error(f"Not enough host memory for request {req.rid}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time())
|
||||||
|
return True
|
||||||
|
|
||||||
|
def check_offload_progress(self):
|
||||||
|
"""Check the progress of offload from device to host and backup from host to storage."""
|
||||||
|
cc = self.cache_controller
|
||||||
|
|
||||||
|
qsizes = torch.tensor(
|
||||||
|
[
|
||||||
|
len(cc.ack_write_queue),
|
||||||
|
cc.ack_backup_queue.qsize(),
|
||||||
|
],
|
||||||
|
dtype=torch.int,
|
||||||
|
)
|
||||||
|
if self.tp_world_size > 1:
|
||||||
|
torch.distributed.all_reduce(
|
||||||
|
qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
|
||||||
|
)
|
||||||
|
|
||||||
|
n_write, n_backup = map(int, qsizes.tolist())
|
||||||
|
self._check_offload_progress(n_write)
|
||||||
|
self._check_backup_progress(n_backup)
|
||||||
|
|
||||||
|
def _check_offload_progress(self, finish_count):
|
||||||
|
"""Check the progress of offload from device to host."""
|
||||||
|
while finish_count > 0:
|
||||||
|
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
|
||||||
|
finish_event.synchronize()
|
||||||
|
for ack_id in ack_list:
|
||||||
|
req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id)
|
||||||
|
|
||||||
|
# Release device
|
||||||
|
self.tree_cache.cache_finished_req(req)
|
||||||
|
|
||||||
|
# Trigger async backup from host to storage by cache controller
|
||||||
|
self._trigger_backup(req.rid, host_indices, tokens, start_time)
|
||||||
|
finish_count -= 1
|
||||||
|
|
||||||
|
def _check_backup_progress(self, finish_count):
|
||||||
|
"""Check the progress of backup from host to storage."""
|
||||||
|
for _ in range(finish_count):
|
||||||
|
storage_operation = self.cache_controller.ack_backup_queue.get()
|
||||||
|
ack_id = storage_operation.id
|
||||||
|
req_id, host_indices, start_time = self.ongoing_backup.pop(ack_id)
|
||||||
|
|
||||||
|
# Release host memory
|
||||||
|
self.decode_host_mem_pool.free(host_indices)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _trigger_backup(self, req_id, host_indices, tokens, start_time):
|
||||||
|
"""Trigger async backup from host to storage by cache controller."""
|
||||||
|
|
||||||
|
# Generate page hashes and write to storage
|
||||||
|
page_hashes = self._compute_prefix_hash(tokens)
|
||||||
|
ack_id = self.cache_controller.write_storage(
|
||||||
|
host_indices,
|
||||||
|
tokens,
|
||||||
|
hash_value=page_hashes,
|
||||||
|
)
|
||||||
|
self.ongoing_backup[ack_id] = (req_id, host_indices, start_time)
|
||||||
|
|
||||||
|
def _compute_prefix_hash(self, tokens):
|
||||||
|
last_hash = ""
|
||||||
|
page_hashes = []
|
||||||
|
for offset in range(0, len(tokens), self.page_size):
|
||||||
|
page_tokens = tokens[offset : offset + self.page_size]
|
||||||
|
last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
|
||||||
|
page_hashes.append(last_hash)
|
||||||
|
return page_hashes
|
||||||
@@ -107,7 +107,9 @@ class MetadataBuffers:
|
|||||||
# We transfer the metadata of first output token to decode
|
# We transfer the metadata of first output token to decode
|
||||||
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
|
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
|
||||||
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
|
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
|
||||||
|
self.cached_tokens = torch.zeros(
|
||||||
|
(size, 16), dtype=torch.int32, device=device
|
||||||
|
)
|
||||||
self.output_token_logprobs_val = torch.zeros(
|
self.output_token_logprobs_val = torch.zeros(
|
||||||
(size, 16), dtype=torch.float32, device=device
|
(size, 16), dtype=torch.float32, device=device
|
||||||
)
|
)
|
||||||
@@ -127,6 +129,7 @@ class MetadataBuffers:
|
|||||||
def get_buf_infos(self):
|
def get_buf_infos(self):
|
||||||
ptrs = [
|
ptrs = [
|
||||||
self.output_ids.data_ptr(),
|
self.output_ids.data_ptr(),
|
||||||
|
self.cached_tokens.data_ptr(),
|
||||||
self.output_token_logprobs_val.data_ptr(),
|
self.output_token_logprobs_val.data_ptr(),
|
||||||
self.output_token_logprobs_idx.data_ptr(),
|
self.output_token_logprobs_idx.data_ptr(),
|
||||||
self.output_top_logprobs_val.data_ptr(),
|
self.output_top_logprobs_val.data_ptr(),
|
||||||
@@ -135,6 +138,7 @@ class MetadataBuffers:
|
|||||||
]
|
]
|
||||||
data_lens = [
|
data_lens = [
|
||||||
self.output_ids.nbytes,
|
self.output_ids.nbytes,
|
||||||
|
self.cached_tokens.nbytes,
|
||||||
self.output_token_logprobs_val.nbytes,
|
self.output_token_logprobs_val.nbytes,
|
||||||
self.output_token_logprobs_idx.nbytes,
|
self.output_token_logprobs_idx.nbytes,
|
||||||
self.output_top_logprobs_val.nbytes,
|
self.output_top_logprobs_val.nbytes,
|
||||||
@@ -143,6 +147,7 @@ class MetadataBuffers:
|
|||||||
]
|
]
|
||||||
item_lens = [
|
item_lens = [
|
||||||
self.output_ids[0].nbytes,
|
self.output_ids[0].nbytes,
|
||||||
|
self.cached_tokens[0].nbytes,
|
||||||
self.output_token_logprobs_val[0].nbytes,
|
self.output_token_logprobs_val[0].nbytes,
|
||||||
self.output_token_logprobs_idx[0].nbytes,
|
self.output_token_logprobs_idx[0].nbytes,
|
||||||
self.output_top_logprobs_val[0].nbytes,
|
self.output_top_logprobs_val[0].nbytes,
|
||||||
@@ -154,6 +159,7 @@ class MetadataBuffers:
|
|||||||
def get_buf(self, idx: int):
|
def get_buf(self, idx: int):
|
||||||
return (
|
return (
|
||||||
self.output_ids[idx],
|
self.output_ids[idx],
|
||||||
|
self.cached_tokens[idx],
|
||||||
self.output_token_logprobs_val[idx],
|
self.output_token_logprobs_val[idx],
|
||||||
self.output_token_logprobs_idx[idx],
|
self.output_token_logprobs_idx[idx],
|
||||||
self.output_top_logprobs_val[idx],
|
self.output_top_logprobs_val[idx],
|
||||||
@@ -164,6 +170,7 @@ class MetadataBuffers:
|
|||||||
def set_buf(self, req: Req):
|
def set_buf(self, req: Req):
|
||||||
|
|
||||||
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
|
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
|
||||||
|
self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
if req.output_token_logprobs_val: # not none or empty list
|
if req.output_token_logprobs_val: # not none or empty list
|
||||||
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
|
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
|
||||||
|
|||||||
@@ -44,6 +44,9 @@ from sglang.srt.disaggregation.decode import (
|
|||||||
DecodeTransferQueue,
|
DecodeTransferQueue,
|
||||||
SchedulerDisaggregationDecodeMixin,
|
SchedulerDisaggregationDecodeMixin,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
|
||||||
|
DecodeKVCacheOffloadManager,
|
||||||
|
)
|
||||||
from sglang.srt.disaggregation.prefill import (
|
from sglang.srt.disaggregation.prefill import (
|
||||||
PrefillBootstrapQueue,
|
PrefillBootstrapQueue,
|
||||||
SchedulerDisaggregationPrefillMixin,
|
SchedulerDisaggregationPrefillMixin,
|
||||||
@@ -755,6 +758,24 @@ class Scheduler(
|
|||||||
eviction_policy=server_args.radix_eviction_policy,
|
eviction_policy=server_args.radix_eviction_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
server_args.disaggregation_mode == "decode"
|
||||||
|
and server_args.disaggregation_decode_enable_offload_kvcache
|
||||||
|
):
|
||||||
|
self.decode_offload_manager = DecodeKVCacheOffloadManager(
|
||||||
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
|
tp_group=(
|
||||||
|
self.attn_tp_cpu_group
|
||||||
|
if self.server_args.enable_dp_attention
|
||||||
|
else self.tp_cpu_group
|
||||||
|
),
|
||||||
|
tree_cache=self.tree_cache,
|
||||||
|
server_args=self.server_args,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.decode_offload_manager = None
|
||||||
|
|
||||||
self.decode_mem_cache_buf_multiplier = (
|
self.decode_mem_cache_buf_multiplier = (
|
||||||
1
|
1
|
||||||
if self.spec_algorithm.is_none()
|
if self.spec_algorithm.is_none()
|
||||||
|
|||||||
@@ -250,7 +250,13 @@ class SchedulerOutputProcessorMixin:
|
|||||||
|
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
if req.finished():
|
if req.finished():
|
||||||
self.tree_cache.cache_finished_req(req)
|
if self.server_args.disaggregation_decode_enable_offload_kvcache:
|
||||||
|
# Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
|
||||||
|
if not self.decode_offload_manager.offload_kv_cache(req):
|
||||||
|
self.tree_cache.cache_finished_req(req)
|
||||||
|
else:
|
||||||
|
self.tree_cache.cache_finished_req(req)
|
||||||
|
|
||||||
req.time_stats.completion_time = time.time()
|
req.time_stats.completion_time = time.time()
|
||||||
|
|
||||||
if req.return_logprob and batch.spec_algorithm.is_none():
|
if req.return_logprob and batch.spec_algorithm.is_none():
|
||||||
|
|||||||
@@ -421,6 +421,7 @@ class ServerArgs:
|
|||||||
disaggregation_decode_dp: Optional[int] = None
|
disaggregation_decode_dp: Optional[int] = None
|
||||||
disaggregation_prefill_pp: Optional[int] = 1
|
disaggregation_prefill_pp: Optional[int] = 1
|
||||||
disaggregation_ib_device: Optional[str] = None
|
disaggregation_ib_device: Optional[str] = None
|
||||||
|
disaggregation_decode_enable_offload_kvcache: bool = False
|
||||||
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
||||||
# FIXME: hack to reduce ITL when decode bs is small
|
# FIXME: hack to reduce ITL when decode bs is small
|
||||||
disaggregation_decode_polling_interval: int = 1
|
disaggregation_decode_polling_interval: int = 1
|
||||||
@@ -1074,6 +1075,14 @@ class ServerArgs:
|
|||||||
"and cannot be used at the same time. Please use only one of them."
|
"and cannot be used at the same time. Please use only one of them."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.disaggregation_decode_enable_offload_kvcache
|
||||||
|
and self.disaggregation_mode != "decode"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"The argument disaggregation-decode-enable-offload-kvcache is only supported for decode side."
|
||||||
|
)
|
||||||
|
|
||||||
def _handle_metrics_labels(self):
|
def _handle_metrics_labels(self):
|
||||||
if (
|
if (
|
||||||
not self.tokenizer_metrics_custom_labels_header
|
not self.tokenizer_metrics_custom_labels_header
|
||||||
@@ -2556,6 +2565,11 @@ class ServerArgs:
|
|||||||
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
|
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
|
||||||
"Default is None, which triggers automatic device detection when mooncake backend is enabled.",
|
"Default is None, which triggers automatic device detection when mooncake backend is enabled.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disaggregation-decode-enable-offload-kvcache",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable async KV cache offloading on decode server (PD mode).",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--num-reserved-decode-tokens",
|
"--num-reserved-decode-tokens",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
271
test/srt/hicache/test_disaggregation_hicache.py
Normal file
271
test/srt/hicache/test_disaggregation_hicache.py
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
from typing import Dict
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.bench_serving import get_tokenizer
|
||||||
|
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
popen_launch_pd_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DisaggregationHiCacheBase(TestDisaggregationBase):
|
||||||
|
"""Base class for disaggregation with HiCache tests"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
|
||||||
|
parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
|
||||||
|
cls.base_host = parsed_url.hostname
|
||||||
|
base_port = str(parsed_url.port)
|
||||||
|
cls.lb_port = base_port
|
||||||
|
cls.prefill_port = f"{int(base_port) + 100}"
|
||||||
|
cls.decode_port = f"{int(base_port) + 200}"
|
||||||
|
cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
|
||||||
|
cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
|
||||||
|
cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
|
||||||
|
print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
|
||||||
|
|
||||||
|
cls.tokenizer = get_tokenizer(cls.model)
|
||||||
|
cls.temp_dir = tempfile.mkdtemp()
|
||||||
|
cls.start_prefill()
|
||||||
|
cls.start_decode()
|
||||||
|
|
||||||
|
# Block until both
|
||||||
|
cls.wait_server_ready(cls.prefill_url + "/health")
|
||||||
|
cls.wait_server_ready(cls.decode_url + "/health")
|
||||||
|
|
||||||
|
cls.launch_lb()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def start_prefill(cls):
|
||||||
|
# Prefill with HiCache enabled
|
||||||
|
prefill_args = [
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--disaggregation-mode",
|
||||||
|
"prefill",
|
||||||
|
"--tp-size",
|
||||||
|
"1",
|
||||||
|
"--page-size",
|
||||||
|
"64",
|
||||||
|
"--enable-hierarchical-cache",
|
||||||
|
"--hicache-ratio",
|
||||||
|
"1.2",
|
||||||
|
"--hicache-size",
|
||||||
|
"0",
|
||||||
|
"--hicache-write-policy",
|
||||||
|
"write_through",
|
||||||
|
"--hicache-storage-backend",
|
||||||
|
"file",
|
||||||
|
"--hicache-storage-prefetch-policy",
|
||||||
|
"wait_complete",
|
||||||
|
"--mem-fraction-static",
|
||||||
|
"0.8",
|
||||||
|
"--disaggregation-ib-device",
|
||||||
|
"mlx5_roce0",
|
||||||
|
"--disaggregation-transfer-backend",
|
||||||
|
"mooncake",
|
||||||
|
]
|
||||||
|
env = {
|
||||||
|
**os.environ,
|
||||||
|
"SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR": cls.temp_dir,
|
||||||
|
}
|
||||||
|
cls.process_prefill = popen_launch_pd_server(
|
||||||
|
cls.model,
|
||||||
|
cls.prefill_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=prefill_args,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def start_decode(cls):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def gen_prompt(self, token_num: int) -> str:
|
||||||
|
all_available_tokens = list(self.tokenizer.get_vocab().values())
|
||||||
|
selected_tokens = random.choices(all_available_tokens, k=token_num)
|
||||||
|
return self.tokenizer.decode(selected_tokens)
|
||||||
|
|
||||||
|
def send_request(
|
||||||
|
self, prompt: str, max_tokens: int = 100, temperature: float = 0.0
|
||||||
|
) -> Dict:
|
||||||
|
"""Send a generate request and return response"""
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.lb_url}/generate",
|
||||||
|
json={
|
||||||
|
"text": prompt,
|
||||||
|
"sampling_params": {
|
||||||
|
"temperature": temperature,
|
||||||
|
"max_new_tokens": max_tokens,
|
||||||
|
"ignore_eos": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
timeout=60,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
response.status_code,
|
||||||
|
200,
|
||||||
|
f"Request failed: {response.status_code} - {response.text}",
|
||||||
|
)
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
def trigger_offloading_and_flush(self):
|
||||||
|
"""Helper method to trigger offloading and flush cache"""
|
||||||
|
# Trigger offloading
|
||||||
|
self.send_request(self.gen_prompt(1), max_tokens=150)
|
||||||
|
|
||||||
|
# Flush device cache to force remote storage access
|
||||||
|
time.sleep(2)
|
||||||
|
requests.post(self.prefill_url + "/flush_cache")
|
||||||
|
|
||||||
|
|
||||||
|
class TestDisaggregationPrefillWithHiCache(DisaggregationHiCacheBase):
|
||||||
|
"""Test disaggregation with HiCache enabled only on Prefill side"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def start_decode(cls):
|
||||||
|
# Decode without HiCache offload
|
||||||
|
decode_args = [
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--disaggregation-mode",
|
||||||
|
"decode",
|
||||||
|
"--tp-size",
|
||||||
|
"1",
|
||||||
|
"--page-size",
|
||||||
|
"64",
|
||||||
|
"--mem-fraction-static",
|
||||||
|
"0.8",
|
||||||
|
"--base-gpu-id",
|
||||||
|
"1",
|
||||||
|
"--disaggregation-ib-device",
|
||||||
|
"mlx5_roce0",
|
||||||
|
"--disaggregation-transfer-backend",
|
||||||
|
"mooncake",
|
||||||
|
]
|
||||||
|
env = {
|
||||||
|
**os.environ,
|
||||||
|
"SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR": cls.temp_dir,
|
||||||
|
}
|
||||||
|
cls.process_decode = popen_launch_pd_server(
|
||||||
|
cls.model,
|
||||||
|
cls.decode_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=decode_args,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_prefill_cache_hit(self):
|
||||||
|
"""Test that prefill cache works with repeated queries"""
|
||||||
|
|
||||||
|
repeated_prompt = self.gen_prompt(800)
|
||||||
|
|
||||||
|
# First request - should miss cache
|
||||||
|
self.send_request(repeated_prompt, max_tokens=100)
|
||||||
|
|
||||||
|
# Flush cache
|
||||||
|
self.trigger_offloading_and_flush()
|
||||||
|
|
||||||
|
# Second request - should hit cache (faster)
|
||||||
|
response2 = self.send_request(repeated_prompt, max_tokens=100)
|
||||||
|
|
||||||
|
# Assert cached tokens cnt
|
||||||
|
self.assertGreater(response2["meta_info"]["cached_tokens"], 700)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDisaggregationDecodeWithHiCache(DisaggregationHiCacheBase):
|
||||||
|
"""Test disaggregation with HiCache enabled on both Prefill and Decode sides"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def start_decode(cls):
|
||||||
|
# Decode with HiCache offload enabled
|
||||||
|
decode_args = [
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--disaggregation-mode",
|
||||||
|
"decode",
|
||||||
|
"--tp-size",
|
||||||
|
"1",
|
||||||
|
"--page-size",
|
||||||
|
"64",
|
||||||
|
"--mem-fraction-static",
|
||||||
|
"0.8",
|
||||||
|
"--base-gpu-id",
|
||||||
|
"1",
|
||||||
|
"--disaggregation-ib-device",
|
||||||
|
"mlx5_roce0",
|
||||||
|
"--disaggregation-transfer-backend",
|
||||||
|
"mooncake",
|
||||||
|
"--disaggregation-decode-enable-offload-kvcache",
|
||||||
|
"--hicache-ratio",
|
||||||
|
"1.2",
|
||||||
|
"--hicache-size",
|
||||||
|
"0",
|
||||||
|
"--hicache-storage-backend",
|
||||||
|
"file",
|
||||||
|
"--hicache-storage-prefetch-policy",
|
||||||
|
"wait_complete",
|
||||||
|
]
|
||||||
|
env = {
|
||||||
|
**os.environ,
|
||||||
|
"SGLANG_HICACHE_FILE_BACKEND_STORAGE_DIR": cls.temp_dir,
|
||||||
|
}
|
||||||
|
cls.process_decode = popen_launch_pd_server(
|
||||||
|
cls.model,
|
||||||
|
cls.decode_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=decode_args,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_multi_turn_conversation_cache(self):
|
||||||
|
"""Test multi-turn conversation scenario with cache hit improvement"""
|
||||||
|
|
||||||
|
print("=== Multi-turn Conversation Cache Test ===")
|
||||||
|
|
||||||
|
# Turn 1
|
||||||
|
initial_prompt = self.gen_prompt(300)
|
||||||
|
|
||||||
|
response1 = self.send_request(initial_prompt, max_tokens=200, temperature=0.1)
|
||||||
|
current_context = initial_prompt + response1["text"]
|
||||||
|
|
||||||
|
# Turns 2-4: Continue generation based on previous context
|
||||||
|
previous_cached_tokens = 0
|
||||||
|
|
||||||
|
for turn in range(2, 5):
|
||||||
|
print(f"\nTurn {turn}: Continuing from previous context")
|
||||||
|
|
||||||
|
response = self.send_request(
|
||||||
|
current_context, max_tokens=200, temperature=0.1
|
||||||
|
)
|
||||||
|
cached_tokens = response["meta_info"]["cached_tokens"]
|
||||||
|
|
||||||
|
print(f"Turn {turn} cached tokens: {cached_tokens}")
|
||||||
|
print(f"Improvement: {cached_tokens - previous_cached_tokens} tokens")
|
||||||
|
|
||||||
|
# Assert cache improvement
|
||||||
|
self.assertGreater(
|
||||||
|
cached_tokens,
|
||||||
|
previous_cached_tokens,
|
||||||
|
f"Turn {turn} should have more cached tokens than turn {turn-1}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update context and cached tokens for next iteration
|
||||||
|
current_context += response["text"]
|
||||||
|
previous_cached_tokens = cached_tokens
|
||||||
|
|
||||||
|
# Flush prefill cache
|
||||||
|
self.trigger_offloading_and_flush()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user