From a02071a12cb29d91ac7bb376e0a0744cad3cbb69 Mon Sep 17 00:00:00 2001
From: Teng Ma
Date: Tue, 9 Sep 2025 02:50:54 +0800
Subject: [PATCH] [Bench] feat: mooncake trace integration (#9839)

Signed-off-by: Xuchun Shang
Signed-off-by: Teng Ma
Co-authored-by: Xuchun Shang
---
 docs/developer_guide/bench_serving.md |  15 ++
 python/sglang/bench_serving.py        | 253 ++++++++++++++++++++++++--
 2 files changed, 249 insertions(+), 19 deletions(-)

diff --git a/docs/developer_guide/bench_serving.md b/docs/developer_guide/bench_serving.md
index 35c9b2b0f..82f7aa2af 100644
--- a/docs/developer_guide/bench_serving.md
+++ b/docs/developer_guide/bench_serving.md
@@ -305,6 +305,21 @@ python3 -m sglang.bench_serving \
   --disable-ignore-eos
 ```
 
+9) Evaluating large-scale KVCache sharing with a mooncake trace (sglang only). `--mooncake-workload` selects the trace to replay (`mooncake`, `conversation`, `synthetic`, or `toolagent`):
+
+```bash
+python3 -m sglang.bench_serving \
+  --backend sglang \
+  --host 127.0.0.1 --port 30000 \
+  --model model-name \
+  --dataset-name mooncake \
+  --mooncake-slowdown-factor 1.0 \
+  --mooncake-num-rounds 1000 \
+  --mooncake-workload conversation \
+  --use-trace-timestamps \
+  --random-output-len 256
+```
+
 ### Troubleshooting
 
 - All requests failed: verify `--backend`, server URL/port, `--model`, and authentication. Check warmup errors printed by the script.
diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index 8386bb66c..6767e9c2e 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -75,6 +75,7 @@ class RequestFuncInput:
     lora_name: str
     image_data: Optional[List[str]]
     extra_request_body: Dict[str, Any]
+    timestamp: Optional[float] = None
 
 
 @dataclass
@@ -696,6 +697,22 @@ def get_dataset(args, tokenizer):
             apply_chat_template=args.apply_chat_template,
             random_sample=True,
         )
+    elif args.dataset_name == "mooncake":
+        # For mooncake, we don't generate the prompts here.
+        # We just load the raw trace data. The async generator will handle the rest.
+        if not args.dataset_path:
+            local_path = os.path.join("/tmp", args.mooncake_workload + "_trace.jsonl")
+        else:
+            local_path = args.dataset_path
+
+        if not os.path.exists(local_path):
+            download_and_cache_file(MOONCAKE_DATASET_URL[args.mooncake_workload], local_path)
+
+        with open(local_path, "r") as f:
+            all_requests_data = [json.loads(line) for line in f if line.strip()]
+
+        # Limit the number of requests based on --num-prompts
+        input_requests = all_requests_data[: args.num_prompts]
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
     return input_requests
@@ -750,6 +767,12 @@ class BenchmarkMetrics:
 
 SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
 
+MOONCAKE_DATASET_URL = {
+    "mooncake": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/arxiv-trace/mooncake_trace.jsonl",
+    "conversation": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/conversation_trace.jsonl",
+    "synthetic": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/synthetic_trace.jsonl",
+    "toolagent": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/toolagent_trace.jsonl",
+}
 
 
 def download_and_cache_file(url: str, filename: Optional[str] = None):
@@ -808,6 +831,80 @@ class DatasetRow:
     prompt_len: int
     output_len: int
     image_data: Optional[List[str]] = None
+    timestamp: Optional[float] = None
+
+
+async def get_mooncake_request_over_time(
+    input_requests: List[Dict],
+    tokenizer: PreTrainedTokenizerBase,
+    slowdown_factor: float,
+    num_rounds: int,
+) -> AsyncGenerator[DatasetRow, None]:
+    """
+    An async generator that yields requests based on the timestamps in the Mooncake trace file,
+    with support for multi-round sessions.
+    """
+    if not input_requests:
+        return
+
+    input_requests.sort(key=lambda r: r["timestamp"])
+
+    start_time = time.perf_counter()
+    trace_start_time_ms = input_requests[0]["timestamp"]
+
+    for record in input_requests:
+        # Calculate when this entire session should start
+        relative_arrival_time_s = (record["timestamp"] - trace_start_time_ms) / 1000.0
+        target_arrival_time_s = relative_arrival_time_s * slowdown_factor
+
+        current_elapsed_time_s = time.perf_counter() - start_time
+        sleep_duration_s = target_arrival_time_s - current_elapsed_time_s
+        if sleep_duration_s > 0:
+            await asyncio.sleep(sleep_duration_s)
+
+        # Once the session starts, generate all rounds for it as a burst.
+        # This simulates a user engaging in a multi-turn conversation.
+
+        # Base user query constructed from hash_ids
+        user_query_base = ""
+        hash_ids = record.get("hash_ids", [])
+        for hash_id in hash_ids:
+            user_query_base += f"{hash_id}" + " ".join(
+                ["hi"] * 128
+            )  # Shorter for multi-round
+        user_query_base += "Tell me a story based on this context."
+
+        output_len_per_round = record.get("output_length", 256)
+        chat_history = []
+
+        for i in range(num_rounds):
+            # Add user query for the current round
+            chat_history.append(
+                {"role": "user", "content": f"Round {i+1}: {user_query_base}"}
+            )
+
+            # Form the full prompt from history
+            try:
+                full_prompt_text = tokenizer.apply_chat_template(
+                    chat_history, tokenize=False, add_generation_prompt=True
+                )
+            except Exception:
+                full_prompt_text = "\n".join(
+                    [f"{msg['role']}: {msg['content']}" for msg in chat_history]
+                )
+
+            prompt_len = len(tokenizer.encode(full_prompt_text))
+
+            yield DatasetRow(
+                prompt=full_prompt_text,
+                prompt_len=prompt_len,
+                output_len=output_len_per_round,
+            )
+
+            # Add a placeholder assistant response for the next round's context.
+            # We use a placeholder because we don't know the real response.
+            placeholder_response = " ".join(["story"] * output_len_per_round)
+            chat_history.append({"role": "assistant", "content": placeholder_response})
 
 
 def sample_mmmu_requests(
@@ -1359,19 +1456,41 @@ def sample_generated_shared_prefix_requests(
 async def get_request(
     input_requests: List[DatasetRow],
     request_rate: float,
+    use_trace_timestamps: bool = False,
+    slowdown_factor: float = 1.0,
 ) -> AsyncGenerator[DatasetRow, None]:
-    input_requests = iter(input_requests)
-    for request in input_requests:
-        yield request
+    if use_trace_timestamps:
+        print(
+            f"Using trace timestamps for request generation with slowdown factor {slowdown_factor}."
+        )
+        # Sort requests by timestamp for correct replay
+        input_requests.sort(key=lambda r: r.timestamp)
 
-        if request_rate == float("inf"):
-            # If the request rate is infinity, then we don't need to wait.
-            continue
+        start_time = time.perf_counter()
+        trace_start_time_ms = input_requests[0].timestamp if input_requests else 0
 
-        # Sample the request interval from the exponential distribution.
-        interval = np.random.exponential(1.0 / request_rate)
-        # The next request will be sent after the interval.
-        await asyncio.sleep(interval)
+        for request in input_requests:
+            trace_time_s = (request.timestamp - trace_start_time_ms) / 1000.0
+            target_arrival_time = start_time + (trace_time_s * slowdown_factor)
+
+            sleep_duration = target_arrival_time - time.perf_counter()
+            if sleep_duration > 0:
+                await asyncio.sleep(sleep_duration)
+
+            yield request
+    else:
+        input_requests_iter = iter(input_requests)
+        for request in input_requests_iter:
+            yield request
+
+            if request_rate == float("inf"):
+                # If the request rate is infinity, then we don't need to wait.
+                continue
+
+            # Sample the request interval from the exponential distribution.
+            interval = np.random.exponential(1.0 / request_rate)
+            # The next request will be sent after the interval.
+            await asyncio.sleep(interval)
 
 
 def calculate_metrics(
@@ -1397,7 +1516,7 @@ def calculate_metrics(
                 tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
             )
             retokenized_output_lens.append(retokenized_output_len)
-            total_input += input_requests[i].prompt_len
+            total_input += outputs[i].prompt_len
             if output_len > 1:
                 tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
             itls += outputs[i].itl
@@ -1469,6 +1588,9 @@ async def benchmark(
     pd_separated: bool = False,
     flush_cache: bool = False,
     warmup_requests: int = 1,
+    use_trace_timestamps: bool = False,
+    mooncake_slowdown_factor=1.0,
+    mooncake_num_rounds=1,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -1488,8 +1610,32 @@ async def benchmark(
 
     # Warmup
     print(f"Starting warmup with {warmup_requests} sequences...")
-    # Use the first request for all warmup iterations
-    test_request = input_requests[0]
+    # Handle the data structure difference for the warmup request
+    if args.dataset_name == "mooncake":
+        # For mooncake, input_requests is a list of dicts.
+        # We need to build a temporary DatasetRow for the warmup phase.
+        warmup_record = input_requests[0]
+
+        # Build prompt from hash_ids, just like in the async generator
+        hash_ids = warmup_record.get("hash_ids", [])
+        prompt_text = ""
+        for hash_id in hash_ids:
+            prompt_text += f"{hash_id}" + " ".join(["hi"] * 512)
+        prompt_text += "Can you tell me a detailed story in 1000 words?"
+
+        output_len = warmup_record.get("output_length", 32)
+        prompt_len = len(tokenizer.encode(prompt_text))
+
+        # Create a temporary DatasetRow object for warmup
+        test_request = DatasetRow(
+            prompt=prompt_text,
+            prompt_len=prompt_len,
+            output_len=output_len,
+            image_data=None,  # Mooncake doesn't have image data
+        )
+    else:
+        # For all other datasets, input_requests is a list of DatasetRow objects
+        test_request = input_requests[0]
 
     if lora_names is not None and len(lora_names) != 0:
         lora_name = lora_names[0]
@@ -1543,12 +1689,26 @@ async def benchmark(
     if profile_output.success:
         print("Profiler started")
 
-    pbar = None if disable_tqdm else tqdm(total=len(input_requests))
-
     # Run all requests
     benchmark_start_time = time.perf_counter()
    tasks: List[asyncio.Task] = []
-    async for request in get_request(input_requests, request_rate):
+    pbar_total = len(input_requests)
+    if (
+        backend == "sglang" and args.dataset_name == "mooncake"
+    ):  # Assuming mooncake is mainly for sglang or similar backends
+        print("Using time-based Mooncake request scheduler, ignoring --request-rate.")
+        request_generator = get_mooncake_request_over_time(
+            input_requests, tokenizer, mooncake_slowdown_factor, mooncake_num_rounds
+        )
+        print(
+            f"Starting Mooncake trace replay. Sessions: {len(input_requests)}, Rounds per session: {mooncake_num_rounds}. Slowdown factor: {mooncake_slowdown_factor}"
+        )
+        pbar_total *= args.mooncake_num_rounds
+    else:
+        request_generator = get_request(input_requests, request_rate)
+
+    pbar = None if disable_tqdm else tqdm(total=pbar_total)
+    async for request in request_generator:
         if lora_names is not None and len(lora_names) != 0:
             idx = random.randint(0, len(lora_names) - 1)
             lora_name = lora_names[idx]
@@ -1564,6 +1724,7 @@ async def benchmark(
             lora_name=lora_name,
             image_data=request.image_data,
             extra_request_body=extra_request_body,
+            timestamp=request.timestamp,
         )
 
         tasks.append(
@@ -1609,7 +1770,11 @@ async def benchmark(
 
     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
     print("{:<40} {:<10}".format("Backend:", backend))
-    print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
+    print(
+        "{:<40} {:<10}".format(
+            "Traffic request rate:", "trace" if use_trace_timestamps else request_rate
+        )
+    )
     print(
         "{:<40} {:<10}".format(
             "Max request concurrency:",
@@ -1678,7 +1843,7 @@ async def benchmark(
             # Arguments
             "backend": args.backend,
             "dataset_name": args.dataset_name,
-            "request_rate": request_rate,
+            "request_rate": "trace" if use_trace_timestamps else request_rate,
             "max_concurrency": max_concurrency,
             "sharegpt_output_len": args.sharegpt_output_len,
             "random_input_len": args.random_input_len,
@@ -1731,7 +1896,9 @@ async def benchmark(
     elif args.dataset_name.startswith("random"):
         output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
     else:
-        output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
+        output_file_name = (
+            f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl"
+        )
 
     result_details = {
         "input_lens": [output.prompt_len for output in outputs],
@@ -1786,6 +1953,17 @@ def run_benchmark(args_: argparse.Namespace):
     if not hasattr(args, "tokenize_prompt"):
         args.tokenize_prompt = False
 
+    if not hasattr(args, "use_trace_timestamps"):
+        args.use_trace_timestamps = False
+
+    if not hasattr(args, "mooncake_slowdown_factor"):
+        args.mooncake_slowdown_factor = 1.0
+
+    if not hasattr(args, "mooncake_num_rounds"):
+        args.mooncake_num_rounds = 1
+
     print(f"benchmark_args={args}")
 
     # Set global environments
@@ -1919,6 +2097,9 @@ def run_benchmark(args_: argparse.Namespace):
             pd_separated=args.pd_separated,
             flush_cache=args.flush_cache,
             warmup_requests=args.warmup_requests,
+            use_trace_timestamps=args.use_trace_timestamps,
+            mooncake_slowdown_factor=args.mooncake_slowdown_factor,
+            mooncake_num_rounds=args.mooncake_num_rounds,
         )
     )
 
@@ -1975,6 +2156,7 @@ if __name__ == "__main__":
             "generated-shared-prefix",
             "mmmu",
             "random-image",
+            "mooncake",
         ],
         help="Name of the dataset to benchmark on.",
     )
@@ -2051,6 +2233,11 @@ if __name__ == "__main__":
         help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
         "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
     )
+    parser.add_argument(
+        "--use-trace-timestamps",
+        action="store_true",
+        help="Use timestamps from the trace file for request scheduling. Only valid for 'mooncake' dataset.",
+    )
     parser.add_argument(
         "--max-concurrency",
         type=int,
@@ -2174,5 +2361,33 @@ if __name__ == "__main__":
         default=256,
         help="Target length in tokens for outputs in generated-shared-prefix dataset",
     )
+    mooncake_group = parser.add_argument_group("mooncake dataset arguments")
+    mooncake_group.add_argument(
+        "--mooncake-slowdown-factor",
+        type=float,
+        default=1.0,
+        help="Slowdown factor for replaying the mooncake trace. "
+        "A value of 2.0 means the replay is twice as slow. "
+        "NOTE: --request-rate is IGNORED in mooncake mode.",
+    )
+    mooncake_group.add_argument(
+        "--mooncake-num-rounds",
+        type=int,
+        default=1,
+        help="Number of conversation rounds for each session in the mooncake dataset. "
+        "A value > 1 will enable true multi-turn session benchmarking.",
+    )
+    mooncake_group.add_argument(
+        "--mooncake-workload",
+        type=str,
+        default="conversation",
+        choices=[
+            "mooncake",
+            "conversation",
+            "synthetic",
+            "toolagent",
+        ],
+        help="Underlying workload for the mooncake dataset.",
+    )
     args = parser.parse_args()
     run_benchmark(args)
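
For reference, a minimal sketch of the trace format the new mooncake path assumes: the only fields the patched code reads from each JSONL record are timestamp (milliseconds), hash_ids, and output_length (see get_mooncake_request_over_time and the mooncake warmup branch above). The snippet below is illustrative only; the file path and the record values are hypothetical, and it simply writes a tiny trace that --dataset-name mooncake with --dataset-path should be able to load.

# Illustrative sketch only: field names (timestamp in ms, hash_ids, output_length)
# mirror what the patched bench_serving.py reads; the path and values are made up.
import json

records = [
    {"timestamp": 0, "hash_ids": [1, 2, 3], "output_length": 128},
    {"timestamp": 1500, "hash_ids": [1, 2, 4], "output_length": 256},  # shares prefix blocks 1 and 2
    {"timestamp": 4000, "hash_ids": [5], "output_length": 64},
]

with open("/tmp/tiny_mooncake_trace.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")

# Replay it with the new flags, e.g.:
#   python3 -m sglang.bench_serving --backend sglang --host 127.0.0.1 --port 30000 \
#     --model model-name --dataset-name mooncake --dataset-path /tmp/tiny_mooncake_trace.jsonl \
#     --use-trace-timestamps --num-prompts 3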