diff --git a/benchmark/hicache/bench_multiturn.py b/benchmark/hicache/bench_multiturn.py
index 311632525..287ce52bd 100644
--- a/benchmark/hicache/bench_multiturn.py
+++ b/benchmark/hicache/bench_multiturn.py
@@ -20,6 +20,8 @@ from sglang.bench_serving import (
     sample_random_requests,
 )
 
+AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
+
 
 def parse_args():
     parser = argparse.ArgumentParser(
@@ -139,7 +141,7 @@ async def async_request_sglang_generate(
     """
     Sends a streaming request to the server. Gathers text token-by-token.
     """
-    async with aiohttp.ClientSession() as session:
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
         headers = {}
         generated_text = ""
         ttft = 0.0
@@ -150,6 +152,8 @@ async def async_request_sglang_generate(
         try:
             async with session.post(url=url, json=payload, headers=headers) as response:
                 if response.status == 200:
+                    prompt_tokens = 0
+                    cached_tokens = 0
                     async for chunk_bytes in response.content:
                         chunk_bytes = chunk_bytes.strip()
                         if not chunk_bytes:
@@ -168,6 +172,12 @@ async def async_request_sglang_generate(
                             if ttft == 0.0:
                                 ttft = time.perf_counter() - st
                                 output.ttft = ttft
+                            prompt_tokens = (data.get("meta_info") or {}).get(
+                                "prompt_tokens", 0
+                            )
+                            cached_tokens = (data.get("meta_info") or {}).get(
+                                "cached_tokens", 0
+                            )
 
                             # Decoding phase
                             else:
@@ -179,6 +189,8 @@ async def async_request_sglang_generate(
                     output.generated_text = generated_text
                     output.success = True
                     output.latency = latency
+                    output.prompt_len = prompt_tokens
+                    output.cached_tokens = cached_tokens
                 else:
                     output.error = response.reason or ""
                     output.success = False
@@ -201,6 +213,7 @@ def gen_payload(prompt, output_len):
             "ignore_eos": True,
         },
         "stream": True,
+        "stream_options": {"include_usage": True},
         "lora_path": "",
         "return_logprob": False,
         "logprob_start_len": -1,
@@ -303,7 +316,12 @@ class WorkloadGenerator:
 
         self.response_queue = queue.Queue()
         self.pbar = tqdm(total=args.num_clients * args.num_rounds)
-        self.performance_metrics = {"ttft": [], "latency": []}
+        self.performance_metrics = {
+            "ttft": [],
+            "latency": [],
+            "prompt_len": [],
+            "cached_tokens": [],
+        }
 
     async def handle_request(self, item):
         try:
@@ -360,6 +378,8 @@ class WorkloadGenerator:
                 self.client_records[client_id]["round"] += 1
                 self.performance_metrics["ttft"].append(response.ttft)
                 self.performance_metrics["latency"].append(response.latency)
+                self.performance_metrics["prompt_len"].append(response.prompt_len)
+                self.performance_metrics["cached_tokens"].append(response.cached_tokens)
                 self.completed_requests += 1
 
                 if self.client_records[client_id]["round"] < args.num_rounds:
@@ -416,6 +436,12 @@ class WorkloadGenerator:
                     len(self.performance_metrics["latency"]) // 2
                 ],
                 "throughput": self.pbar.total / (self.finished_time - self.start_time),
+                "cache_hit_rate": (
+                    0
+                    if sum(self.performance_metrics["prompt_len"]) == 0
+                    else sum(self.performance_metrics["cached_tokens"])
+                    / sum(self.performance_metrics["prompt_len"])
+                ),
             },
         }
         print("All requests completed")
@@ -434,6 +460,7 @@ class WorkloadGenerator:
         print(
             f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
        )
+        print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
 
         log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
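Note on the benchmark change above: the new `cache_hit_rate` summary is token-weighted, i.e. total `cached_tokens` divided by total `prompt_tokens` across all completed requests, both read from the `meta_info` of the streamed responses. A minimal standalone sketch of the aggregation (made-up numbers, no benchmark imports):

```python
# Token-weighted cache hit rate, as aggregated in the summary above.
prompt_len = [512, 1024, 2048]  # prompt tokens reported per completed request
cached_tokens = [0, 512, 1536]  # of those, tokens served from the KV cache

total_prompt = sum(prompt_len)
cache_hit_rate = 0 if total_prompt == 0 else sum(cached_tokens) / total_prompt
print(f"Cache Hit Rate: {cache_hit_rate:.6f}")  # 0.571429 for these numbers
```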
diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py
index 7e572dcbc..b518f42c5 100644
--- a/python/sglang/srt/managers/cache_controller.py
+++ b/python/sglang/srt/managers/cache_controller.py
@@ -16,6 +16,7 @@ limitations under the License.
 import logging
 import math
 import threading
+import time
 from queue import Empty, Full, PriorityQueue, Queue
 from typing import TYPE_CHECKING, List, Optional
@@ -195,6 +196,8 @@ class PrefetchOperation(StorageOperation):
         self._done_flag = False
         self._lock = threading.Lock()
 
+        self.start_time = time.monotonic()
+
         super().__init__(host_indices, token_ids, last_hash)
 
     def increment(self, num_tokens: int):
@@ -278,6 +281,12 @@ class HiCacheController:
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            self.prefetch_capacity_limit = int(
+                0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
+            )
+            # number of tokens locked by in-flight prefetches; updated by the main scheduler thread
+            self.prefetch_tokens_occupied = 0
+
             # create a new communication group for synchronizing storage operations across TP workers
             self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
             if self.tp_world_size > 1:
@@ -525,7 +534,7 @@ class HiCacheController:
         host_indices: torch.Tensor,
         new_input_tokens: List[int],
         last_hash: Optional[str] = None,
-    ) -> int:
+    ) -> PrefetchOperation:
         """
         Prefetch KV caches from storage backend to host memory.
         """
@@ -586,11 +595,23 @@ class HiCacheController:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
                 if self.is_mooncake_backend():
                     self.mooncake_page_transfer(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    self.generic_page_transfer(operation, batch_size=128)
                 else:
                     self.generic_page_transfer(operation)
             except Empty:
                 continue
 
+    def prefetch_rate_limit_check(self) -> bool:
+        """
+        Rate limit prefetch operations to avoid overwhelming the storage backend.
+        """
+        # cancel prefetch if too much host memory is already locked by prefetching
+        if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
+            return False
+        # todo: more sophisticated rate limiting based on storage backend performance
+        return True
+
     def prefetch_thread_func(self):
         """
         Manage prefetching operations from storage backend to host memory.
         """
@@ -604,34 +625,37 @@ class HiCacheController:
                 operation = self.prefetch_queue.get(block=True, timeout=1)
                 if operation is None:
                     continue
 
-                last_hash = operation.last_hash
-                tokens_to_fetch = operation.token_ids
-                storage_hit_count = 0
-                remaining_tokens = len(tokens_to_fetch)
-                hash_value = []
-                while remaining_tokens >= self.page_size:
-                    last_hash = self.get_hash_str(
-                        tokens_to_fetch[
-                            storage_hit_count : storage_hit_count + self.page_size
-                        ],
-                        last_hash,
-                    )
+                storage_hit_count = 0
+                if self.prefetch_rate_limit_check():
+                    last_hash = operation.last_hash
+                    tokens_to_fetch = operation.token_ids
 
-                    # todo, more unified interface
-                    if not self.is_mooncake_backend():
-                        if not self.storage_backend.exists(last_hash):
-                            break
-                    hash_value.append(last_hash)
-                    storage_hit_count += self.page_size
-                    remaining_tokens -= self.page_size
+                    remaining_tokens = len(tokens_to_fetch)
+                    hash_value = []
+                    while remaining_tokens >= self.page_size:
+                        last_hash = self.get_hash_str(
+                            tokens_to_fetch[
+                                storage_hit_count : storage_hit_count + self.page_size
+                            ],
+                            last_hash,
+                        )
 
-                if self.is_mooncake_backend():
-                    # deferring to batch exists for mooncake store
-                    exist_result = self.storage_backend.exists(hash_value)
-                    storage_hit_count = (
-                        sum(1 for v in exist_result.values() if v != 0) * self.page_size
-                    )
+                        # todo, more unified interface
+                        if not self.is_mooncake_backend():
+                            if not self.storage_backend.exists(last_hash):
+                                break
+                        hash_value.append(last_hash)
+                        storage_hit_count += self.page_size
+                        remaining_tokens -= self.page_size
+
+                    if self.is_mooncake_backend():
+                        # deferring to batch exists for mooncake store
+                        exist_result = self.storage_backend.exists(hash_value)
+                        storage_hit_count = (
+                            sum(1 for v in exist_result.values() if v != 0)
+                            * self.page_size
+                        )
 
                 if self.tp_world_size > 1:
                     storage_hit_count_tensor = torch.tensor(
@@ -750,6 +774,8 @@ class HiCacheController:
 
             if self.is_mooncake_backend():
                 self.mooncake_page_backup(operation)
+            elif self.storage_backend_type == "hf3fs":
+                self.generic_page_backup(operation, batch_size=128)
             else:
                 self.generic_page_backup(operation)
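Note on the rate limiter above: `prefetch_rate_limit_check` is a simple admission check; once the tokens locked by in-flight prefetches reach 80% of the host-only headroom (host pool size minus device pool size), new prefetch attempts are skipped. (The original hunk also dropped the `storage_hit_count = 0` initialization when wrapping the probing loop in this check, which would raise `NameError` whenever the check fails; the reconstruction above restores it before the branch.) A standalone sketch of the check with mock pool sizes (names and numbers are illustrative, not the sglang API):

```python
# Mock admission check mirroring prefetch_rate_limit_check (illustrative numbers).
host_pool_tokens = 1_000_000   # token capacity of the host memory pool (assumed)
device_pool_tokens = 200_000   # token capacity of the device pool (assumed)
capacity_limit = int(0.8 * (host_pool_tokens - device_pool_tokens))  # 640_000

prefetch_tokens_occupied = 650_000  # tokens locked by in-flight prefetches


def rate_limit_check() -> bool:
    # Deny new prefetches once in-flight ones lock too much host memory.
    return prefetch_tokens_occupied < capacity_limit


assert not rate_limit_check()  # over the 80% limit, so prefetch is skipped
```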
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index e8b1ae6fc..5aef261c1 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -619,6 +619,7 @@ class Scheduler(
                 ),
                 hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
+                hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
@@ -1572,7 +1573,10 @@ class Scheduler(
                 break
 
             if self.enable_hicache_storage:
-                self.tree_cache.check_prefetch_progress(req.rid)
+                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
+                if not prefetch_done:
+                    # skip staging this request while its prefetch is still in flight
+                    continue
 
             req.init_next_round_input(self.tree_cache)
             res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py
index 7b26fa8a7..0f51712eb 100644
--- a/python/sglang/srt/mem_cache/hiradix_cache.py
+++ b/python/sglang/srt/mem_cache/hiradix_cache.py
@@ -2,11 +2,12 @@ import heapq
 import logging
 import threading
 import time
+from queue import Queue
 from typing import List, Optional
 
 import torch
 
-from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
         hicache_io_backend: str,
         hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
+        hicache_storage_prefetch_policy: Optional[str] = "best_effort",
     ):
 
         if hicache_io_backend == "direct":
@@ -85,6 +87,13 @@ class HiRadixCache(RadixCache):
             prefetch_threshold=self.prefetch_threshold,
         )
 
+        self.prefetch_stop_policy = hicache_storage_prefetch_policy
+        # todo: customizable storage prefetch timeout
+        self.prefetch_timeout = 3  # seconds
+        logger.info(
+            f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
+        )
+
         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
@@ -385,9 +394,10 @@ class HiRadixCache(RadixCache):
         for _ in range(queue_size.item()):
             req_id = self.cache_controller.prefetch_revoke_queue.get()
             if req_id in self.ongoing_prefetch:
-                last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
+                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
                 last_host_node.release_host()
                 del self.ongoing_prefetch[req_id]
+                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
             else:
                 # the revoked operation already got terminated
                 pass
@@ -419,10 +429,41 @@ class HiRadixCache(RadixCache):
             host_node.release_host()
             del self.ongoing_backup[ack_id]
 
-    def check_prefetch_progress(self, req_id: str):
+    def can_terminate_prefetch(self, operation: PrefetchOperation):
+        can_terminate = True
+
+        if self.prefetch_stop_policy == "best_effort":
+            return can_terminate
+
+        completed = (
+            operation.completed_tokens == len(operation.hash_value) * self.page_size
+        )
+
+        if self.prefetch_stop_policy == "wait_complete":
+            can_terminate = completed
+        elif self.prefetch_stop_policy == "timeout":
+            can_terminate = completed or (
+                time.monotonic() - operation.start_time > self.prefetch_timeout
+            )
+        else:
+            # unknown prefetch stop policy, fall back to best-effort termination
+            return True
+
+        if self.tp_world_size > 1:
+            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
+            torch.distributed.all_reduce(
+                can_terminate,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+            can_terminate = bool(can_terminate.item())
+
+        return can_terminate
+
+    def check_prefetch_progress(self, req_id: str) -> bool:
         if req_id not in self.ongoing_prefetch:
             # there is no ongoing prefetch for this request or it has been revoked
-            return
+            return True
 
         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
@@ -430,13 +471,16 @@ class HiRadixCache(RadixCache):
         operation, token_ids, host_indices, device_indices = self.ongoing_prefetch[
             req_id
         ]
 
+        if not self.can_terminate_prefetch(operation):
+            return False
+
         completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
             operation
         )
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
         min_completed_tokens = completed_tokens
-        if self.tp_world_size > 1:
+        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
             # synchronize TP workers to make the same update to hiradix cache
             completed_tokens_tensor = torch.tensor(
                 min_completed_tokens, dtype=torch.int
@@ -464,6 +508,9 @@ class HiRadixCache(RadixCache):
         )
         last_host_node.release_host()
         del self.ongoing_prefetch[req_id]
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
+
+        return True
 
     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
@@ -531,6 +578,7 @@ class HiRadixCache(RadixCache):
             host_indices,
             operation,
         )
+        self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
 
     def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
         node.last_access_time = time.monotonic()
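Note on the policy logic above: the three prefetch stop policies differ only in when an in-flight prefetch counts as terminal, and `can_terminate_prefetch` all-reduces the flag with MIN across TP ranks so that every worker defers or terminates together. A standalone sketch of the single-rank decision (mock operation object, no sglang imports):

```python
import time
from dataclasses import dataclass, field

PAGE_SIZE = 64          # illustrative page size
PREFETCH_TIMEOUT = 3.0  # seconds, matching the hard-coded default above


@dataclass
class MockOp:
    hash_value: list           # one hash per page the backend reported as present
    completed_tokens: int = 0  # tokens that actually landed in host memory
    start_time: float = field(default_factory=time.monotonic)


def can_terminate(op: MockOp, policy: str) -> bool:
    if policy == "best_effort":
        return True  # terminate now and keep whatever has arrived
    completed = op.completed_tokens == len(op.hash_value) * PAGE_SIZE
    if policy == "wait_complete":
        return completed
    if policy == "timeout":
        return completed or (time.monotonic() - op.start_time > PREFETCH_TIMEOUT)
    return True  # unknown policy: behave like best_effort


op = MockOp(hash_value=["h0", "h1"], completed_tokens=64)  # 64 of 128 tokens in
print(can_terminate(op, "best_effort"))    # True
print(can_terminate(op, "wait_complete"))  # False until all 128 tokens land
```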
diff --git a/python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py b/python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py
index e38facf3c..399a90118 100644
--- a/python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py
+++ b/python/sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py
@@ -96,6 +96,8 @@ class Hf3fsClient:
         )
         self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
         self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
+        self.shm_r.unlink()
+        self.shm_w.unlink()
 
         self.rlock = threading.RLock()
         self.wlock = threading.RLock()
@@ -176,8 +178,6 @@ class Hf3fsClient:
         del self.iov_w
         self.shm_r.close()
         self.shm_w.close()
-        self.shm_r.unlink()
-        self.shm_w.unlink()
 
     def flush(self) -> None:
         os.fsync(self.file)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index a7f5a4c0f..e6d2f9c57 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -203,6 +203,7 @@ class ServerArgs:
     hicache_io_backend: str = "kernel"
     hicache_mem_layout: str = "layer_first"
     hicache_storage_backend: Optional[str] = None
+    hicache_storage_prefetch_policy: str = "best_effort"
 
     # Double Sparsity
     enable_double_sparsity: bool = False
@@ -1626,6 +1627,13 @@ class ServerArgs:
             default=ServerArgs.hicache_storage_backend,
             help="The storage backend for hierarchical KV cache.",
         )
+        parser.add_argument(
+            "--hicache-storage-prefetch-policy",
+            type=str,
+            choices=["best_effort", "wait_complete", "timeout"],
+            default=ServerArgs.hicache_storage_prefetch_policy,
+            help="Control when prefetching from the storage backend should stop.",
+        )
 
         # Double Sparsity
         parser.add_argument(
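Note on the `Hf3fsClient` change above: moving `unlink()` from `close()` to `__init__` applies the unlink-early pattern; the segment's name disappears from `/dev/shm` immediately, while the memory stays valid through the existing mappings until `close()`, so a crashed process can no longer leak named shared memory. A minimal standalone demonstration of the pattern (POSIX semantics, as on Linux):

```python
from multiprocessing import shared_memory

shm = shared_memory.SharedMemory(create=True, size=4096)
shm.unlink()               # drop the name now; memory survives via the open mapping
shm.buf[:5] = b"hello"     # the buffer remains fully usable
print(bytes(shm.buf[:5]))  # b'hello'
shm.close()                # backing memory is freed with the last mapping
```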