Hierarchical Caching for SGLang (#2693)
Co-authored-by: Wenxuan Tan <wenxuan.tan@wisc.edu> Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
Files changed in this commit:
- benchmark/hicache/README.md (new file, 25 lines)
- benchmark/hicache/bench_multiturn.py (modified)
@@ -0,0 +1,25 @@
## Run synthetic multi-turn benchmark

```
# SGLang server with radix cache disabled
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --disable-radix-cache

# SGLang server with radix cache on and first-come-first-serve policy
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --schedule-policy fcfs

# The default SGLang server with radix cache on and long-prefix-match policy
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000

# SGLang server with hierarchical radix cache enabled
python -m sglang.launch_server --model-path Qwen/Qwen2.5-14B-Instruct --port 30000 --enable-hierarchical-cache
```

```
python bench_multiturn.py --model-path Qwen/Qwen2.5-14B-Instruct
```

Note: The performance gain of hierarchical caching depends on the ratio of reusable tokens to GPU memory capacity. The more tokens to be reused, the larger the model, and the more constrained the GPU memory size, the greater the benefit one can expect from hierarchical caching.

## More benchmarks to be added
||||
@@ -5,6 +5,7 @@ import queue
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
@@ -26,9 +27,15 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--num-clients",
|
||||
type=int,
|
||||
default=200,
|
||||
default=256,
|
||||
help="Number of concurrent clients",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-parallel",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Maximum number of parallel requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-length",
|
||||
type=int,
|
||||
@@ -73,11 +80,17 @@ def parse_args():
|
||||
help="Server port (default: 30000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="meta-llama/Llama-3.1-8B-Instruct",
|
||||
help="model path compatible with Hugging Face Transformers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-file",
|
||||
type=str,
|
||||
default="performance_metrics.jsonl",
|
||||
help="File to log performance metrics",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -158,6 +171,18 @@ def gen_payload(prompt, output_len):
|
||||
return payload
|
||||
|
||||
|
||||
def log_to_jsonl_file(data, file_path="performance_metrics.jsonl"):
    """Append the data with a timestamp to the specified JSONL file."""
    # Stamp the record first so every line in the file is self-describing.
    record = {"timestamp": datetime.now().isoformat(), **data}
    try:
        with open(file_path, "a") as out:
            # JSONL convention: exactly one JSON object per line.
            out.write(json.dumps(record) + "\n")
    except IOError as e:
        # Best-effort logging — a failed write is reported, not fatal.
        print(f"Error writing to JSONL file: {e}")
|
||||
|
||||
|
||||
class ReadyQueue:
|
||||
"""
|
||||
Thread-safe queue that can pop requests in different orders based on given policy.
|
||||
@@ -191,12 +216,15 @@ class WorkloadGenerator:
|
||||
# Construct the base URL for requests
|
||||
self.url = f"http://{args.host}:{args.port}/generate"
|
||||
|
||||
self.tokenizer = get_tokenizer(args.model)
|
||||
self.tokenizer = get_tokenizer(args.model_path)
|
||||
self.distribution = args.distribution
|
||||
self.request_rate = args.request_rate
|
||||
self.start_time = None
|
||||
self.finished_time = None
|
||||
|
||||
self.sent_requests = 0
|
||||
self.completed_requests = 0
|
||||
|
||||
self.candidate_inputs = sample_random_requests(
|
||||
input_len=args.request_length,
|
||||
output_len=args.output_length,
|
||||
@@ -235,6 +263,18 @@ class WorkloadGenerator:
|
||||
def request_sender(self):
|
||||
async def request_loop():
|
||||
while True:
|
||||
if self.sent_requests - self.completed_requests < args.max_parallel:
|
||||
new_request = self.ready_queue.pop()
|
||||
if new_request:
|
||||
asyncio.create_task(self.handle_request(new_request))
|
||||
self.sent_requests += 1
|
||||
else:
|
||||
await asyncio.sleep(0.05)
|
||||
continue
|
||||
|
||||
if self.pbar.n == self.pbar.total:
|
||||
break
|
||||
|
||||
# Calculate Poisson-distributed wait time
|
||||
if self.distribution == "poisson":
|
||||
sleep_time = random.expovariate(self.request_rate)
|
||||
@@ -247,14 +287,6 @@ class WorkloadGenerator:
|
||||
raise ValueError("Invalid distribution type")
|
||||
await asyncio.sleep(sleep_time) # Wait before sending the next request
|
||||
|
||||
new_request = self.ready_queue.pop()
|
||||
# Submit async request
|
||||
if new_request:
|
||||
asyncio.create_task(self.handle_request(new_request))
|
||||
else:
|
||||
if self.pbar.n == self.pbar.total:
|
||||
break
|
||||
|
||||
# Create and run the event loop for asynchronous requests
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
@@ -273,6 +305,7 @@ class WorkloadGenerator:
|
||||
self.client_records[client_id]["round"] += 1
|
||||
self.performance_metrics["ttft"].append(response.ttft)
|
||||
self.performance_metrics["latency"].append(response.latency)
|
||||
self.completed_requests += 1
|
||||
|
||||
if self.client_records[client_id]["round"] < args.num_rounds:
|
||||
self.client_records[client_id][
|
||||
@@ -301,34 +334,56 @@ class WorkloadGenerator:
|
||||
|
||||
request_thread.join()
|
||||
response_thread.join()
|
||||
|
||||
self.pbar.close()
|
||||
print("All requests completed.")
|
||||
|
||||
performance_data = {
|
||||
"summary": {
|
||||
"total_requests": len(self.performance_metrics["ttft"]),
|
||||
"request_rate": self.request_rate,
|
||||
"average_ttft": sum(self.performance_metrics["ttft"])
|
||||
/ len(self.performance_metrics["ttft"]),
|
||||
"p90_ttft": sorted(self.performance_metrics["ttft"])[
|
||||
int(0.9 * len(self.performance_metrics["ttft"]))
|
||||
],
|
||||
"median_ttft": sorted(self.performance_metrics["ttft"])[
|
||||
len(self.performance_metrics["ttft"]) // 2
|
||||
],
|
||||
"average_latency": sum(self.performance_metrics["latency"])
|
||||
/ len(self.performance_metrics["latency"]),
|
||||
"p90_latency": sorted(self.performance_metrics["latency"])[
|
||||
int(0.9 * len(self.performance_metrics["latency"]))
|
||||
],
|
||||
"median_latency": sorted(self.performance_metrics["latency"])[
|
||||
len(self.performance_metrics["latency"]) // 2
|
||||
],
|
||||
"throughput": self.pbar.total / (self.finished_time - self.start_time),
|
||||
},
|
||||
}
|
||||
print("All requests completed")
|
||||
print("Performance metrics summary:")
|
||||
print(
|
||||
f" Total requests: {len(self.performance_metrics['ttft'])} at {self.request_rate} requests per second"
|
||||
f" Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second"
|
||||
)
|
||||
print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
|
||||
print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
|
||||
print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
|
||||
print(
|
||||
f" Average TTFT: {sum(self.performance_metrics['ttft']) / len(self.performance_metrics['ttft']):.2f}"
|
||||
f" Average latency: {performance_data['summary']['average_latency']:.2f}"
|
||||
)
|
||||
print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
|
||||
print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
|
||||
print(
|
||||
f" Median TTFT: {sorted(self.performance_metrics['ttft'])[len(self.performance_metrics['ttft']) // 2]:.2f}"
|
||||
f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
|
||||
)
|
||||
print(
|
||||
f" Average latency: {sum(self.performance_metrics['latency']) / len(self.performance_metrics['latency']):.2f}"
|
||||
)
|
||||
print(
|
||||
f" Median latency: {sorted(self.performance_metrics['latency'])[len(self.performance_metrics['latency']) // 2]:.2f}"
|
||||
)
|
||||
throughput = self.pbar.total / (self.finished_time - self.start_time)
|
||||
print(f"Throughput: {throughput:.2f} requests per second")
|
||||
log_to_jsonl_file(performance_data, args.log_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parse_args()
    flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"

    # Sweep request rates from high to low. The scraped diff contained both
    # the old loop header (range(1, 41, 2)) and this list; the list is the
    # final state of the change, so only it is kept here.
    for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
        args.request_rate = request_rate
        # Flush the server-side radix cache between runs so no iteration
        # benefits from prefixes cached by an earlier, faster-rate run.
        requests.post(flush_cache_url)
        time.sleep(1)  # give the server a moment to finish flushing
        WorkloadGenerator(args).run()
|
||||
|
||||
Reference in New Issue
Block a user