[Bench] feat: mooncake trace integration (#9839)
Signed-off-by: Xuchun Shang <xuchun.shang@linux.alibaba.com> Signed-off-by: Teng Ma <sima.mt@alibaba-inc.com> Co-authored-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
This commit is contained in:
@@ -305,6 +305,21 @@ python3 -m sglang.bench_serving \
|
||||
--disable-ignore-eos
|
||||
```
|
||||
|
||||
9) Evaluating large-scale KVCache sharing with mooncake trace (sglang only):
|
||||
|
||||
```bash
|
||||
python3 -m sglang.bench_serving \
|
||||
--backend sglang \
|
||||
--host 127.0.0.1 --port 30000 \
|
||||
--model mode-name \
|
||||
--dataset-name mooncake \
|
||||
--mooncake-slowdown-factor 1.0 \
|
||||
--mooncake-num-rounds 1000 \
|
||||
--mooncake-workload conversation|mooncake|agent|synthetic
|
||||
--use-trace-timestamps true \
|
||||
--random-output-len 256
|
||||
```
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
- All requests failed: verify `--backend`, server URL/port, `--model`, and authentication. Check warmup errors printed by the script.
|
||||
|
||||
@@ -75,6 +75,7 @@ class RequestFuncInput:
|
||||
lora_name: str
|
||||
image_data: Optional[List[str]]
|
||||
extra_request_body: Dict[str, Any]
|
||||
timestamp: Optional[float] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -696,6 +697,22 @@ def get_dataset(args, tokenizer):
|
||||
apply_chat_template=args.apply_chat_template,
|
||||
random_sample=True,
|
||||
)
|
||||
elif args.dataset_name == "mooncake":
|
||||
# For mooncake, we don't generate the prompts here.
|
||||
# We just load the raw trace data. The async generator will handle the rest.
|
||||
if not args.dataset_path:
|
||||
local_path = os.path.join("/tmp", args.mooncake_workload + "_trace.jsonl")
|
||||
else:
|
||||
local_path = args.dataset_path
|
||||
|
||||
if not os.path.exists(local_path):
|
||||
download_and_cache_file(MOONCAKE_DATASET_URL[args.mooncake_workload], local_path)
|
||||
|
||||
with open(local_path, "r") as f:
|
||||
all_requests_data = [json.loads(line) for line in f if line.strip()]
|
||||
|
||||
# Limit the number of requests based on --num-prompts
|
||||
input_requests = all_requests_data[: args.num_prompts]
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||
return input_requests
|
||||
@@ -750,6 +767,12 @@ class BenchmarkMetrics:
|
||||
|
||||
|
||||
SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
|
||||
MOONCAKE_DATASET_URL = {
|
||||
"mooncake": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/arxiv-trace/mooncake_trace.jsonl",
|
||||
"conversation": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/conversation_trace.jsonl",
|
||||
"synthetic": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/synthetic_trace.jsonl",
|
||||
"toolagent": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/toolagent_trace.jsonl",
|
||||
}
|
||||
|
||||
|
||||
def download_and_cache_file(url: str, filename: Optional[str] = None):
|
||||
@@ -808,6 +831,80 @@ class DatasetRow:
|
||||
prompt_len: int
|
||||
output_len: int
|
||||
image_data: Optional[List[str]] = None
|
||||
timestamp: Optional[float] = None
|
||||
|
||||
|
||||
async def get_mooncake_request_over_time(
|
||||
input_requests: List[Dict],
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
slowdown_factor: float,
|
||||
num_rounds: int,
|
||||
) -> AsyncGenerator[DatasetRow, None]:
|
||||
"""
|
||||
An async generator that yields requests based on the timestamps in the Mooncake trace file,
|
||||
with support for multi-round sessions.
|
||||
"""
|
||||
if not input_requests:
|
||||
return
|
||||
|
||||
input_requests.sort(key=lambda r: r["timestamp"])
|
||||
|
||||
start_time = time.perf_counter()
|
||||
trace_start_time_ms = input_requests[0]["timestamp"]
|
||||
|
||||
for record in input_requests:
|
||||
# Calculate when this entire session should start
|
||||
relative_arrival_time_s = (record["timestamp"] - trace_start_time_ms) / 1000.0
|
||||
target_arrival_time_s = relative_arrival_time_s * slowdown_factor
|
||||
|
||||
current_elapsed_time_s = time.perf_counter() - start_time
|
||||
sleep_duration_s = target_arrival_time_s - current_elapsed_time_s
|
||||
if sleep_duration_s > 0:
|
||||
await asyncio.sleep(sleep_duration_s)
|
||||
|
||||
# Once the session starts, generate all rounds for it as a burst
|
||||
# This simulates a user engaging in a multi-turn conversation
|
||||
|
||||
# Base user query constructed from hash_ids
|
||||
user_query_base = ""
|
||||
hash_ids = record.get("hash_ids", [])
|
||||
for hash_id in hash_ids:
|
||||
user_query_base += f"{hash_id}" + " ".join(
|
||||
["hi"] * 128
|
||||
) # Shorter for multi-round
|
||||
user_query_base += "Tell me a story based on this context."
|
||||
|
||||
output_len_per_round = record.get("output_length", 256)
|
||||
chat_history = []
|
||||
|
||||
for i in range(num_rounds):
|
||||
# Add user query for the current round
|
||||
chat_history.append(
|
||||
{"role": "user", "content": f"Round {i+1}: {user_query_base}"}
|
||||
)
|
||||
|
||||
# Form the full prompt from history
|
||||
try:
|
||||
full_prompt_text = tokenizer.apply_chat_template(
|
||||
chat_history, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
except Exception:
|
||||
full_prompt_text = "\n".join(
|
||||
[f"{msg['role']}: {msg['content']}" for msg in chat_history]
|
||||
)
|
||||
|
||||
prompt_len = len(tokenizer.encode(full_prompt_text))
|
||||
|
||||
yield DatasetRow(
|
||||
prompt=full_prompt_text,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len_per_round,
|
||||
)
|
||||
|
||||
# Add a placeholder assistant response for the next round's context
|
||||
# We use a placeholder because we don't know the real response
|
||||
placeholder_response = " ".join(["story"] * output_len_per_round)
|
||||
chat_history.append({"role": "assistant", "content": placeholder_response})
|
||||
|
||||
|
||||
def sample_mmmu_requests(
|
||||
@@ -1359,19 +1456,41 @@ def sample_generated_shared_prefix_requests(
|
||||
async def get_request(
|
||||
input_requests: List[DatasetRow],
|
||||
request_rate: float,
|
||||
use_trace_timestamps: bool = False,
|
||||
slowdown_factor: float = 1.0,
|
||||
) -> AsyncGenerator[DatasetRow, None]:
|
||||
input_requests = iter(input_requests)
|
||||
for request in input_requests:
|
||||
yield request
|
||||
if use_trace_timestamps:
|
||||
print(
|
||||
f"Using trace timestamps for request generation with slowdown factor {slowdown_factor}."
|
||||
)
|
||||
# Sort requests by timestamp for correct replay
|
||||
input_requests.sort(key=lambda r: r.timestamp)
|
||||
|
||||
if request_rate == float("inf"):
|
||||
# If the request rate is infinity, then we don't need to wait.
|
||||
continue
|
||||
start_time = time.perf_counter()
|
||||
trace_start_time_ms = input_requests[0].timestamp if input_requests else 0
|
||||
|
||||
# Sample the request interval from the exponential distribution.
|
||||
interval = np.random.exponential(1.0 / request_rate)
|
||||
# The next request will be sent after the interval.
|
||||
await asyncio.sleep(interval)
|
||||
for request in input_requests:
|
||||
trace_time_s = (request.timestamp - trace_start_time_ms) / 1000.0
|
||||
target_arrival_time = start_time + (trace_time_s * slowdown_factor)
|
||||
|
||||
sleep_duration = target_arrival_time - time.perf_counter()
|
||||
if sleep_duration > 0:
|
||||
await asyncio.sleep(sleep_duration)
|
||||
|
||||
yield request
|
||||
else:
|
||||
input_requests_iter = iter(input_requests)
|
||||
for request in input_requests_iter:
|
||||
yield request
|
||||
|
||||
if request_rate == float("inf"):
|
||||
# If the request rate is infinity, then we don't need to wait.
|
||||
continue
|
||||
|
||||
# Sample the request interval from the exponential distribution.
|
||||
interval = np.random.exponential(1.0 / request_rate)
|
||||
# The next request will be sent after the interval.
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
|
||||
def calculate_metrics(
|
||||
@@ -1397,7 +1516,7 @@ def calculate_metrics(
|
||||
tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
|
||||
)
|
||||
retokenized_output_lens.append(retokenized_output_len)
|
||||
total_input += input_requests[i].prompt_len
|
||||
total_input += outputs[i].prompt_len
|
||||
if output_len > 1:
|
||||
tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
|
||||
itls += outputs[i].itl
|
||||
@@ -1469,6 +1588,9 @@ async def benchmark(
|
||||
pd_separated: bool = False,
|
||||
flush_cache: bool = False,
|
||||
warmup_requests: int = 1,
|
||||
use_trace_timestamps: bool = False,
|
||||
mooncake_slowdown_factor=1.0,
|
||||
mooncake_num_rounds=1,
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@@ -1488,8 +1610,32 @@ async def benchmark(
|
||||
# Warmup
|
||||
print(f"Starting warmup with {warmup_requests} sequences...")
|
||||
|
||||
# Use the first request for all warmup iterations
|
||||
test_request = input_requests[0]
|
||||
# Handle the data structure difference for the warmup request
|
||||
if args.dataset_name == "mooncake":
|
||||
# For mooncake, input_requests is a list of dicts.
|
||||
# We need to build a temporary DatasetRow for the warmup phase.
|
||||
warmup_record = input_requests[0]
|
||||
|
||||
# Build prompt from hash_ids, just like in the async generator
|
||||
hash_ids = warmup_record.get("hash_ids", [])
|
||||
prompt_text = ""
|
||||
for hash_id in hash_ids:
|
||||
prompt_text += f"{hash_id}" + " ".join(["hi"] * 512)
|
||||
prompt_text += "Can you tell me a detailed story in 1000 words?"
|
||||
|
||||
output_len = warmup_record.get("output_length", 32)
|
||||
prompt_len = len(tokenizer.encode(prompt_text))
|
||||
|
||||
# Create a temporary DatasetRow object for warmup
|
||||
test_request = DatasetRow(
|
||||
prompt=prompt_text,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
image_data=None, # Mooncake doesn't have image data
|
||||
)
|
||||
else:
|
||||
# For all other datasets, input_requests is a list of DatasetRow objects
|
||||
test_request = input_requests[0]
|
||||
|
||||
if lora_names is not None and len(lora_names) != 0:
|
||||
lora_name = lora_names[0]
|
||||
@@ -1543,12 +1689,26 @@ async def benchmark(
|
||||
if profile_output.success:
|
||||
print("Profiler started")
|
||||
|
||||
pbar = None if disable_tqdm else tqdm(total=len(input_requests))
|
||||
|
||||
# Run all requests
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks: List[asyncio.Task] = []
|
||||
async for request in get_request(input_requests, request_rate):
|
||||
pbar_total = len(input_requests)
|
||||
if (
|
||||
backend == "sglang" and args.dataset_name == "mooncake"
|
||||
): # Assuming mooncake is mainly for sglang or similar backends
|
||||
print("Using time-based Mooncake request scheduler, ignoring --request-rate.")
|
||||
request_generator = get_mooncake_request_over_time(
|
||||
input_requests, tokenizer, mooncake_slowdown_factor, mooncake_num_rounds
|
||||
)
|
||||
print(
|
||||
f"Starting Mooncake trace replay. Sessions: {len(input_requests)}, Rounds per session: {mooncake_num_rounds}. Slowdown factor: {mooncake_slowdown_factor}"
|
||||
)
|
||||
pbar_total *= args.mooncake_num_rounds
|
||||
else:
|
||||
request_generator = get_request(input_requests, request_rate)
|
||||
|
||||
pbar = None if disable_tqdm else tqdm(total=pbar_total)
|
||||
async for request in request_generator:
|
||||
if lora_names is not None and len(lora_names) != 0:
|
||||
idx = random.randint(0, len(lora_names) - 1)
|
||||
lora_name = lora_names[idx]
|
||||
@@ -1564,6 +1724,7 @@ async def benchmark(
|
||||
lora_name=lora_name,
|
||||
image_data=request.image_data,
|
||||
extra_request_body=extra_request_body,
|
||||
timestamp=request.timestamp,
|
||||
)
|
||||
|
||||
tasks.append(
|
||||
@@ -1609,7 +1770,11 @@ async def benchmark(
|
||||
|
||||
print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
|
||||
print("{:<40} {:<10}".format("Backend:", backend))
|
||||
print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
|
||||
print(
|
||||
"{:<40} {:<10}".format(
|
||||
"Traffic request rate:", "trace" if use_trace_timestamps else request_rate
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10}".format(
|
||||
"Max request concurrency:",
|
||||
@@ -1678,7 +1843,7 @@ async def benchmark(
|
||||
# Arguments
|
||||
"backend": args.backend,
|
||||
"dataset_name": args.dataset_name,
|
||||
"request_rate": request_rate,
|
||||
"request_rate": "trace" if use_trace_timestamps else request_rate,
|
||||
"max_concurrency": max_concurrency,
|
||||
"sharegpt_output_len": args.sharegpt_output_len,
|
||||
"random_input_len": args.random_input_len,
|
||||
@@ -1731,7 +1896,9 @@ async def benchmark(
|
||||
elif args.dataset_name.startswith("random"):
|
||||
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
|
||||
else:
|
||||
output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
|
||||
output_file_name = (
|
||||
f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl"
|
||||
)
|
||||
|
||||
result_details = {
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
@@ -1786,6 +1953,17 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
if not hasattr(args, "tokenize_prompt"):
|
||||
args.tokenize_prompt = False
|
||||
|
||||
if not hasattr(args, "use_trace_timestamps"):
|
||||
args.use_trace_timestamps = False
|
||||
if not hasattr(args, "mooncake_slowdown_factor"):
|
||||
args.mooncake_slowdown_factor = 1.0
|
||||
|
||||
if not hasattr(args, "mooncake_slowdown_factor"):
|
||||
args.mooncake_slowdown_factor = 1.0
|
||||
|
||||
if not hasattr(args, "mooncake_num_rounds"):
|
||||
args.mooncake_num_rounds = 1
|
||||
|
||||
print(f"benchmark_args={args}")
|
||||
|
||||
# Set global environments
|
||||
@@ -1919,6 +2097,9 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
pd_separated=args.pd_separated,
|
||||
flush_cache=args.flush_cache,
|
||||
warmup_requests=args.warmup_requests,
|
||||
use_trace_timestamps=args.use_trace_timestamps,
|
||||
mooncake_slowdown_factor=args.mooncake_slowdown_factor,
|
||||
mooncake_num_rounds=args.mooncake_num_rounds,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1975,6 +2156,7 @@ if __name__ == "__main__":
|
||||
"generated-shared-prefix",
|
||||
"mmmu",
|
||||
"random-image",
|
||||
"mooncake",
|
||||
],
|
||||
help="Name of the dataset to benchmark on.",
|
||||
)
|
||||
@@ -2051,6 +2233,11 @@ if __name__ == "__main__":
|
||||
help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
|
||||
"Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use-trace-timestamps",
|
||||
action="store_true",
|
||||
help="Use timestamps from the trace file for request scheduling. Only valid for 'mooncake' dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-concurrency",
|
||||
type=int,
|
||||
@@ -2174,5 +2361,33 @@ if __name__ == "__main__":
|
||||
default=256,
|
||||
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
||||
)
|
||||
mooncake_group = parser.add_argument_group("mooncake dataset arguments")
|
||||
mooncake_group.add_argument(
|
||||
"--mooncake-slowdown-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Slowdown factor for replaying the mooncake trace. "
|
||||
"A value of 2.0 means the replay is twice as slow. "
|
||||
"NOTE: --request-rate is IGNORED in mooncake mode.",
|
||||
)
|
||||
mooncake_group.add_argument(
|
||||
"--mooncake-num-rounds",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of conversation rounds for each session in the mooncake dataset. "
|
||||
"A value > 1 will enable true multi-turn session benchmarking.",
|
||||
)
|
||||
mooncake_group.add_argument(
|
||||
"--mooncake-workload",
|
||||
type=str,
|
||||
default="conversation",
|
||||
choices=[
|
||||
"mooncake",
|
||||
"conversation",
|
||||
"synthetic",
|
||||
"toolagent",
|
||||
],
|
||||
help="Underlying workload for the mooncake dataset.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
run_benchmark(args)
|
||||
|
||||
Reference in New Issue
Block a user