Hierarchical Caching for SGLang (#2693)

Co-authored-by: Wenxuan Tan <wenxuan.tan@wisc.edu>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
Zhiqiang Xie
2025-02-23 21:56:30 -08:00
committed by GitHub
parent 4d2a88bdff
commit 6c7a152c5a
7 changed files with 732 additions and 91 deletions

View File

@@ -0,0 +1,25 @@
## Run synthetic multi-turn benchmark
```
# SGLang server with radix cache disabled
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --disable-radix-cache
# SGLang server with radix cache on and first-come-first-serve policy
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --schedule-policy fcfs
# The default SGLang server with radix cache on and long-prefix-match policy
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000
# SGLang server with hierarchical radix cache enabled
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --enable-hierarchical-cache
```
```
python bench_multiturn.py --model-path Qwen/Qwen2.5-14B-Instruct
```
Note: The performance gain of hierarchical caching depends on the ratio of reusable tokens to GPU memory capacity. The more tokens to be reused, the larger the model, and the more constrained the GPU memory size, the greater the benefit one can expect from hierarchical caching.
## More benchmarks to be added

View File

@@ -5,6 +5,7 @@ import queue
import random
import threading
import time
from datetime import datetime
from typing import Optional
import aiohttp
@@ -26,9 +27,15 @@ def parse_args():
parser.add_argument(
"--num-clients",
type=int,
default=200,
default=256,
help="Number of concurrent clients",
)
parser.add_argument(
"--max-parallel",
type=int,
default=128,
help="Maximum number of parallel requests",
)
parser.add_argument(
"--request-length",
type=int,
@@ -73,11 +80,17 @@ def parse_args():
help="Server port (default: 30000)",
)
parser.add_argument(
"--model",
"--model-path",
type=str,
default="meta-llama/Llama-3.1-8B-Instruct",
help="model path compatible with Hugging Face Transformers",
)
parser.add_argument(
"--log-file",
type=str,
default="performance_metrics.jsonl",
help="File to log performance metrics",
)
return parser.parse_args()
@@ -158,6 +171,18 @@ def gen_payload(prompt, output_len):
return payload
def log_to_jsonl_file(data, file_path="performance_metrics.jsonl"):
"""Append the data with a timestamp to the specified JSONL file."""
timestamped_data = {"timestamp": datetime.now().isoformat(), **data}
try:
with open(file_path, "a") as file:
file.write(
json.dumps(timestamped_data) + "\n"
) # Write as a single line in JSONL format
except IOError as e:
print(f"Error writing to JSONL file: {e}")
class ReadyQueue:
"""
Thread-safe queue that can pop requests in different orders based on given policy.
@@ -191,12 +216,15 @@ class WorkloadGenerator:
# Construct the base URL for requests
self.url = f"http://{args.host}:{args.port}/generate"
self.tokenizer = get_tokenizer(args.model)
self.tokenizer = get_tokenizer(args.model_path)
self.distribution = args.distribution
self.request_rate = args.request_rate
self.start_time = None
self.finished_time = None
self.sent_requests = 0
self.completed_requests = 0
self.candidate_inputs = sample_random_requests(
input_len=args.request_length,
output_len=args.output_length,
@@ -235,6 +263,18 @@ class WorkloadGenerator:
def request_sender(self):
async def request_loop():
while True:
if self.sent_requests - self.completed_requests < args.max_parallel:
new_request = self.ready_queue.pop()
if new_request:
asyncio.create_task(self.handle_request(new_request))
self.sent_requests += 1
else:
await asyncio.sleep(0.05)
continue
if self.pbar.n == self.pbar.total:
break
# Calculate Poisson-distributed wait time
if self.distribution == "poisson":
sleep_time = random.expovariate(self.request_rate)
@@ -247,14 +287,6 @@ class WorkloadGenerator:
raise ValueError("Invalid distribution type")
await asyncio.sleep(sleep_time) # Wait before sending the next request
new_request = self.ready_queue.pop()
# Submit async request
if new_request:
asyncio.create_task(self.handle_request(new_request))
else:
if self.pbar.n == self.pbar.total:
break
# Create and run the event loop for asynchronous requests
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@@ -273,6 +305,7 @@ class WorkloadGenerator:
self.client_records[client_id]["round"] += 1
self.performance_metrics["ttft"].append(response.ttft)
self.performance_metrics["latency"].append(response.latency)
self.completed_requests += 1
if self.client_records[client_id]["round"] < args.num_rounds:
self.client_records[client_id][
@@ -301,34 +334,56 @@ class WorkloadGenerator:
request_thread.join()
response_thread.join()
self.pbar.close()
print("All requests completed.")
performance_data = {
"summary": {
"total_requests": len(self.performance_metrics["ttft"]),
"request_rate": self.request_rate,
"average_ttft": sum(self.performance_metrics["ttft"])
/ len(self.performance_metrics["ttft"]),
"p90_ttft": sorted(self.performance_metrics["ttft"])[
int(0.9 * len(self.performance_metrics["ttft"]))
],
"median_ttft": sorted(self.performance_metrics["ttft"])[
len(self.performance_metrics["ttft"]) // 2
],
"average_latency": sum(self.performance_metrics["latency"])
/ len(self.performance_metrics["latency"]),
"p90_latency": sorted(self.performance_metrics["latency"])[
int(0.9 * len(self.performance_metrics["latency"]))
],
"median_latency": sorted(self.performance_metrics["latency"])[
len(self.performance_metrics["latency"]) // 2
],
"throughput": self.pbar.total / (self.finished_time - self.start_time),
},
}
print("All requests completed")
print("Performance metrics summary:")
print(
f" Total requests: {len(self.performance_metrics['ttft'])} at {self.request_rate} requests per second"
f" Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second"
)
print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
print(
f" Average TTFT: {sum(self.performance_metrics['ttft']) / len(self.performance_metrics['ttft']):.2f}"
f" Average latency: {performance_data['summary']['average_latency']:.2f}"
)
print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
print(
f" Median TTFT: {sorted(self.performance_metrics['ttft'])[len(self.performance_metrics['ttft']) // 2]:.2f}"
f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
)
print(
f" Average latency: {sum(self.performance_metrics['latency']) / len(self.performance_metrics['latency']):.2f}"
)
print(
f" Median latency: {sorted(self.performance_metrics['latency'])[len(self.performance_metrics['latency']) // 2]:.2f}"
)
throughput = self.pbar.total / (self.finished_time - self.start_time)
print(f"Throughput: {throughput:.2f} requests per second")
log_to_jsonl_file(performance_data, args.log_file)
if __name__ == "__main__":
args = parse_args()
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
for request_rate in range(1, 41, 2):
for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
args.request_rate = request_rate
requests.post(flush_cache_url)
time.sleep(1)
WorkloadGenerator(args).run()