diff --git a/benchmark/hicache/bench_long_context.py b/benchmark/hicache/bench_long_context.py
new file mode 100644
index 000000000..dc153b8a9
--- /dev/null
+++ b/benchmark/hicache/bench_long_context.py
@@ -0,0 +1,96 @@
+import json
+import queue
+import time
+
+import requests
+from bench_multiturn import (
+    ReadyQueue,
+    WorkloadGenerator,
+    gen_payload,
+    log_to_jsonl_file,
+    parse_args,
+)
+from tqdm.asyncio import tqdm
+
+from sglang.bench_serving import get_tokenizer
+
+
+class ContextWorkloadGenerator(WorkloadGenerator):
+    def __init__(self, args):
+        # Construct the base URL for requests
+        self.baseurl = f"http://{args.host}:{args.port}/"
+        self.url = self.baseurl + "generate"
+
+        self.tokenizer = get_tokenizer(args.model_path)
+        self.distribution = args.distribution
+        self.request_rate = args.request_rate
+        self.start_time = None
+        self.finished_time = None
+
+        self.sent_requests = 0
+        self.completed_requests = 0
+
+        self.dataset = json.load(open(args.dataset_path))
+
+        init_requests = []
+        for i in range(min(args.num_clients, len(self.dataset["queries"]))):
+            context_id = self.dataset["queries"][i]["context"]
+            init_requests.append(
+                (
+                    i,
+                    gen_payload(
+                        self.dataset["contexts"][context_id]
+                        + self.dataset["queries"][i]["question"],
+                        len(
+                            self.tokenizer(
+                                self.dataset["queries"][i]["reference_answer"]
+                            )["input_ids"]
+                        ),
+                    ),
+                )
+            )
+        self.ready_queue = ReadyQueue(init_requests=init_requests)
+
+        self.response_queue = queue.Queue()
+        self.pbar = tqdm(total=args.num_clients * args.num_rounds)
+        self.performance_metrics = {
+            "ttft": [],
+            "latency": [],
+            "itl": [],
+            "prompt_len": [],
+            "cached_tokens": [],
+        }
+
+        self.max_parallel = args.max_parallel
+        self.logfile = args.log_file
+
+    def response_handler(self):
+        while True:
+            try:
+                client_id, response = self.response_queue.get(
+                    timeout=10
+                )  # Block until response is available
+                if not response.success:
+                    raise ValueError(f"Request failed with error: {response.error}")
+                self.performance_metrics["ttft"].append(response.ttft)
+                self.performance_metrics["itl"].extend(response.itl)
+                self.performance_metrics["latency"].append(response.latency)
+                self.completed_requests += 1
+
+            except queue.Empty:
+                if self.pbar.n == self.pbar.total:
+                    break
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    args.num_rounds = 1
+    args.max_parallel = 128
+    flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
+
+    for request_rate in [24, 16, 12, 8, 4, 2, 1]:
+        args.request_rate = request_rate
+        requests.post(flush_cache_url)
+        time.sleep(1)
+        performance_data = ContextWorkloadGenerator(args).run()
+        log_to_jsonl_file(performance_data, args.log_file, args.tag)
diff --git a/benchmark/hicache/bench_multiturn.py b/benchmark/hicache/bench_multiturn.py
index 287ce52bd..35e638d33 100644
--- a/benchmark/hicache/bench_multiturn.py
+++ b/benchmark/hicache/bench_multiturn.py
@@ -322,6 +322,9 @@ class WorkloadGenerator:
             "prompt_len": [],
             "cached_tokens": [],
         }
+        self.num_rounds = args.num_rounds
+        self.max_parallel = args.max_parallel
+        self.output_length = args.output_length
 
     async def handle_request(self, item):
         try:
@@ -336,7 +339,7 @@ class WorkloadGenerator:
     def request_sender(self):
         async def request_loop():
             while True:
-                if self.sent_requests - self.completed_requests < args.max_parallel:
+                if self.sent_requests - self.completed_requests < self.max_parallel:
                     new_request = self.ready_queue.pop()
                     if new_request:
                         asyncio.create_task(self.handle_request(new_request))
@@ -382,7 +385,7 @@ class WorkloadGenerator:
                 self.performance_metrics["cached_tokens"].append(response.cached_tokens)
                 self.completed_requests += 1
 
-                if self.client_records[client_id]["round"] < args.num_rounds:
+                if self.client_records[client_id]["round"] < self.num_rounds:
                     # append new request to client's history
                     self.client_records[client_id][
                         "history"
@@ -392,7 +395,7 @@ class WorkloadGenerator:
                             client_id,
                             gen_payload(
                                 self.client_records[client_id]["history"],
-                                args.output_length,
+                                self.output_length,
                             ),
                         )
                     )
@@ -461,7 +464,7 @@ class WorkloadGenerator:
            f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
         )
         print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
-        log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
+        return performance_data
 
 
 if __name__ == "__main__":
@@ -482,4 +485,5 @@ if __name__ == "__main__":
         args.request_rate = rate
         requests.post(flush_cache_url)
         time.sleep(1)
-        WorkloadGenerator(args).run()
+        performance_data = WorkloadGenerator(args).run()
+        log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
diff --git a/docs/advanced_features/hyperparameter_tuning.md b/docs/advanced_features/hyperparameter_tuning.md
index a80e85ba0..e15ddd21c 100644
--- a/docs/advanced_features/hyperparameter_tuning.md
+++ b/docs/advanced_features/hyperparameter_tuning.md
@@ -44,9 +44,9 @@ Look for log entries like this:
 [2025-08-11 17:17:03] max_total_num_tokens=665690, chunked_prefill_size=8192, max_prefill_tokens=16384, max_running_requests=4096, context_len=65536, available_gpu_mem=13.50 GB
 ```
 
-Check the `available_gpu_mem` value. 
-- If it is between 5–8 GB, the setting is good. 
-- If it is too high (e.g., 10 - 20 GB), increase `--mem-fraction-static` to allocate more memory to the KV cache. 
+Check the `available_gpu_mem` value.
+- If it is between 5–8 GB, the setting is good.
+- If it is too high (e.g., 10 - 20 GB), increase `--mem-fraction-static` to allocate more memory to the KV cache.
 - If it is too low, you risk out-of-memory (OOM) errors later, so decrease `--mem-fraction-static`.
 
 Another straightforward approach is to increase `--mem-fraction-static` in increments of 0.01 until you encounter OOM errors for your workloads.
diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py
index 342ca7dd2..d4ff703ba 100644
--- a/python/sglang/srt/mem_cache/hiradix_cache.py
+++ b/python/sglang/srt/mem_cache/hiradix_cache.py
@@ -71,8 +71,10 @@ class HiRadixCache(RadixCache):
         self.tp_group = tp_cache_group
         self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
         self.enable_storage = hicache_storage_backend is not None
-        # todo: customizable storage prefetch threshold
+        # todo: customizable storage prefetch threshold and timeout
         self.prefetch_threshold = 256
+        self.prefetch_timeout = 3  # seconds
+        self.prefetch_stop_policy = hicache_storage_prefetch_policy
 
         self.load_cache_event = threading.Event()
         self.cache_controller = HiCacheController(
@@ -87,13 +89,6 @@ class HiRadixCache(RadixCache):
             prefetch_threshold=self.prefetch_threshold,
         )
 
-        self.prefetch_stop_policy = hicache_storage_prefetch_policy
-        # todo: customizable storage prefetch timeout
-        self.prefetch_timeout = 3  # seconds
-        logger.info(
-            f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
-        )
-
        # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
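For reference, here is a minimal sketch of the dataset layout that `ContextWorkloadGenerator` appears to expect, inferred from its loader code in `bench_long_context.py` above. The file name, the field values, and the assumption that `contexts` is a list indexed by each query's integer `context` field are illustrative, not confirmed by the benchmark:

```python
import json

# Hypothetical example dataset for bench_long_context.py, inferred from
# ContextWorkloadGenerator.__init__: each query's "context" field indexes
# into "contexts", the "question" is appended to that context to form the
# prompt, and the tokenized length of "reference_answer" sets the requested
# output length. All names and values here are illustrative.
dataset = {
    "contexts": [
        "A long shared document that multiple queries can reuse ...",
    ],
    "queries": [
        {
            "context": 0,  # index into "contexts" (assumed integer index)
            "question": " What does the document say about caching?",
            "reference_answer": "An answer whose token count sets the output length.",
        },
    ],
}

with open("long_context_dataset.json", "w") as f:
    json.dump(dataset, f)
```

With a file like this, each of the first `num_clients` queries becomes a single-round request whose prompt is `contexts[context] + question`, so queries that share a `context` value reuse the same long prefix, presumably to exercise HiCache prefix reuse as the driver sweeps request rates from 24 down to 1 and flushes the cache between runs.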