Hierarchical Caching for SGLang (#2693)

Co-authored-by: Wenxuan Tan <wenxuan.tan@wisc.edu>
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
Zhiqiang Xie
2025-02-23 21:56:30 -08:00
committed by GitHub
parent 4d2a88bdff
commit 6c7a152c5a
7 changed files with 732 additions and 91 deletions

View File

@@ -5,6 +5,7 @@ import queue
import random
import threading
import time
from datetime import datetime
from typing import Optional
import aiohttp
@@ -26,9 +27,15 @@ def parse_args():
parser.add_argument(
"--num-clients",
type=int,
default=200,
default=256,
help="Number of concurrent clients",
)
parser.add_argument(
"--max-parallel",
type=int,
default=128,
help="Maximum number of parallel requests",
)
parser.add_argument(
"--request-length",
type=int,
@@ -73,11 +80,17 @@ def parse_args():
help="Server port (default: 30000)",
)
parser.add_argument(
"--model",
"--model-path",
type=str,
default="meta-llama/Llama-3.1-8B-Instruct",
help="model path compatible with Hugging Face Transformers",
)
parser.add_argument(
"--log-file",
type=str,
default="performance_metrics.jsonl",
help="File to log performance metrics",
)
return parser.parse_args()
@@ -158,6 +171,18 @@ def gen_payload(prompt, output_len):
return payload
def log_to_jsonl_file(data, file_path="performance_metrics.jsonl"):
"""Append the data with a timestamp to the specified JSONL file."""
timestamped_data = {"timestamp": datetime.now().isoformat(), **data}
try:
with open(file_path, "a") as file:
file.write(
json.dumps(timestamped_data) + "\n"
) # Write as a single line in JSONL format
except IOError as e:
print(f"Error writing to JSONL file: {e}")
class ReadyQueue:
"""
Thread-safe queue that can pop requests in different orders based on given policy.
@@ -191,12 +216,15 @@ class WorkloadGenerator:
# Construct the base URL for requests
self.url = f"http://{args.host}:{args.port}/generate"
self.tokenizer = get_tokenizer(args.model)
self.tokenizer = get_tokenizer(args.model_path)
self.distribution = args.distribution
self.request_rate = args.request_rate
self.start_time = None
self.finished_time = None
self.sent_requests = 0
self.completed_requests = 0
self.candidate_inputs = sample_random_requests(
input_len=args.request_length,
output_len=args.output_length,
@@ -235,6 +263,18 @@ class WorkloadGenerator:
def request_sender(self):
async def request_loop():
while True:
if self.sent_requests - self.completed_requests < args.max_parallel:
new_request = self.ready_queue.pop()
if new_request:
asyncio.create_task(self.handle_request(new_request))
self.sent_requests += 1
else:
await asyncio.sleep(0.05)
continue
if self.pbar.n == self.pbar.total:
break
# Calculate Poisson-distributed wait time
if self.distribution == "poisson":
sleep_time = random.expovariate(self.request_rate)
@@ -247,14 +287,6 @@ class WorkloadGenerator:
raise ValueError("Invalid distribution type")
await asyncio.sleep(sleep_time) # Wait before sending the next request
new_request = self.ready_queue.pop()
# Submit async request
if new_request:
asyncio.create_task(self.handle_request(new_request))
else:
if self.pbar.n == self.pbar.total:
break
# Create and run the event loop for asynchronous requests
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@@ -273,6 +305,7 @@ class WorkloadGenerator:
self.client_records[client_id]["round"] += 1
self.performance_metrics["ttft"].append(response.ttft)
self.performance_metrics["latency"].append(response.latency)
self.completed_requests += 1
if self.client_records[client_id]["round"] < args.num_rounds:
self.client_records[client_id][
@@ -301,34 +334,56 @@ class WorkloadGenerator:
request_thread.join()
response_thread.join()
self.pbar.close()
print("All requests completed.")
performance_data = {
"summary": {
"total_requests": len(self.performance_metrics["ttft"]),
"request_rate": self.request_rate,
"average_ttft": sum(self.performance_metrics["ttft"])
/ len(self.performance_metrics["ttft"]),
"p90_ttft": sorted(self.performance_metrics["ttft"])[
int(0.9 * len(self.performance_metrics["ttft"]))
],
"median_ttft": sorted(self.performance_metrics["ttft"])[
len(self.performance_metrics["ttft"]) // 2
],
"average_latency": sum(self.performance_metrics["latency"])
/ len(self.performance_metrics["latency"]),
"p90_latency": sorted(self.performance_metrics["latency"])[
int(0.9 * len(self.performance_metrics["latency"]))
],
"median_latency": sorted(self.performance_metrics["latency"])[
len(self.performance_metrics["latency"]) // 2
],
"throughput": self.pbar.total / (self.finished_time - self.start_time),
},
}
print("All requests completed")
print("Performance metrics summary:")
print(
f" Total requests: {len(self.performance_metrics['ttft'])} at {self.request_rate} requests per second"
f" Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second"
)
print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
print(
f" Average TTFT: {sum(self.performance_metrics['ttft']) / len(self.performance_metrics['ttft']):.2f}"
f" Average latency: {performance_data['summary']['average_latency']:.2f}"
)
print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
print(
f" Median TTFT: {sorted(self.performance_metrics['ttft'])[len(self.performance_metrics['ttft']) // 2]:.2f}"
f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
)
print(
f" Average latency: {sum(self.performance_metrics['latency']) / len(self.performance_metrics['latency']):.2f}"
)
print(
f" Median latency: {sorted(self.performance_metrics['latency'])[len(self.performance_metrics['latency']) // 2]:.2f}"
)
throughput = self.pbar.total / (self.finished_time - self.start_time)
print(f"Throughput: {throughput:.2f} requests per second")
log_to_jsonl_file(performance_data, args.log_file)
if __name__ == "__main__":
args = parse_args()
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
for request_rate in range(1, 41, 2):
for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
args.request_rate = request_rate
requests.post(flush_cache_url)
time.sleep(1)
WorkloadGenerator(args).run()