Hierarchical Caching for SGLang (#2693)
Co-authored-by: Wenxuan Tan <wenxuan.tan@wisc.edu> Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import queue
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
@@ -26,9 +27,15 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--num-clients",
|
||||
type=int,
|
||||
default=200,
|
||||
default=256,
|
||||
help="Number of concurrent clients",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-parallel",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Maximum number of parallel requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-length",
|
||||
type=int,
|
||||
@@ -73,11 +80,17 @@ def parse_args():
|
||||
help="Server port (default: 30000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="meta-llama/Llama-3.1-8B-Instruct",
|
||||
help="model path compatible with Hugging Face Transformers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-file",
|
||||
type=str,
|
||||
default="performance_metrics.jsonl",
|
||||
help="File to log performance metrics",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -158,6 +171,18 @@ def gen_payload(prompt, output_len):
|
||||
return payload
|
||||
|
||||
|
||||
def log_to_jsonl_file(data, file_path="performance_metrics.jsonl"):
    """Append ``data`` as one timestamped JSON line to ``file_path``.

    Args:
        data: Mapping of JSON-serializable values. Its items are merged after
            a generated ``"timestamp"`` entry, so a ``"timestamp"`` key in
            ``data`` overrides the generated one.
        file_path: JSONL file to append to; it is created if missing.

    File-system errors are printed and swallowed so that a logging problem
    does not abort the caller. Errors raised by ``json.dumps`` (for example
    on non-serializable values) propagate.
    """
    record = {"timestamp": datetime.now().isoformat(), **data}
    # Serialize before opening the file so the try block guards only file I/O.
    line = json.dumps(record) + "\n"
    try:
        # Explicit encoding keeps the output independent of the platform locale.
        with open(file_path, "a", encoding="utf-8") as file:
            file.write(line)  # One JSON document per line (JSONL).
    except OSError as e:
        print(f"Error writing to JSONL file: {e}")
|
||||
|
||||
|
||||
class ReadyQueue:
|
||||
"""
|
||||
Thread-safe queue that can pop requests in different orders based on given policy.
|
||||
@@ -191,12 +216,15 @@ class WorkloadGenerator:
|
||||
# Construct the base URL for requests
|
||||
self.url = f"http://{args.host}:{args.port}/generate"
|
||||
|
||||
self.tokenizer = get_tokenizer(args.model)
|
||||
self.tokenizer = get_tokenizer(args.model_path)
|
||||
self.distribution = args.distribution
|
||||
self.request_rate = args.request_rate
|
||||
self.start_time = None
|
||||
self.finished_time = None
|
||||
|
||||
self.sent_requests = 0
|
||||
self.completed_requests = 0
|
||||
|
||||
self.candidate_inputs = sample_random_requests(
|
||||
input_len=args.request_length,
|
||||
output_len=args.output_length,
|
||||
@@ -235,6 +263,18 @@ class WorkloadGenerator:
|
||||
def request_sender(self):
|
||||
async def request_loop():
|
||||
while True:
|
||||
if self.sent_requests - self.completed_requests < args.max_parallel:
|
||||
new_request = self.ready_queue.pop()
|
||||
if new_request:
|
||||
asyncio.create_task(self.handle_request(new_request))
|
||||
self.sent_requests += 1
|
||||
else:
|
||||
await asyncio.sleep(0.05)
|
||||
continue
|
||||
|
||||
if self.pbar.n == self.pbar.total:
|
||||
break
|
||||
|
||||
# Calculate Poisson-distributed wait time
|
||||
if self.distribution == "poisson":
|
||||
sleep_time = random.expovariate(self.request_rate)
|
||||
@@ -247,14 +287,6 @@ class WorkloadGenerator:
|
||||
raise ValueError("Invalid distribution type")
|
||||
await asyncio.sleep(sleep_time) # Wait before sending the next request
|
||||
|
||||
new_request = self.ready_queue.pop()
|
||||
# Submit async request
|
||||
if new_request:
|
||||
asyncio.create_task(self.handle_request(new_request))
|
||||
else:
|
||||
if self.pbar.n == self.pbar.total:
|
||||
break
|
||||
|
||||
# Create and run the event loop for asynchronous requests
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
@@ -273,6 +305,7 @@ class WorkloadGenerator:
|
||||
self.client_records[client_id]["round"] += 1
|
||||
self.performance_metrics["ttft"].append(response.ttft)
|
||||
self.performance_metrics["latency"].append(response.latency)
|
||||
self.completed_requests += 1
|
||||
|
||||
if self.client_records[client_id]["round"] < args.num_rounds:
|
||||
self.client_records[client_id][
|
||||
@@ -301,34 +334,56 @@ class WorkloadGenerator:
|
||||
|
||||
request_thread.join()
|
||||
response_thread.join()
|
||||
|
||||
self.pbar.close()
|
||||
print("All requests completed.")
|
||||
|
||||
performance_data = {
|
||||
"summary": {
|
||||
"total_requests": len(self.performance_metrics["ttft"]),
|
||||
"request_rate": self.request_rate,
|
||||
"average_ttft": sum(self.performance_metrics["ttft"])
|
||||
/ len(self.performance_metrics["ttft"]),
|
||||
"p90_ttft": sorted(self.performance_metrics["ttft"])[
|
||||
int(0.9 * len(self.performance_metrics["ttft"]))
|
||||
],
|
||||
"median_ttft": sorted(self.performance_metrics["ttft"])[
|
||||
len(self.performance_metrics["ttft"]) // 2
|
||||
],
|
||||
"average_latency": sum(self.performance_metrics["latency"])
|
||||
/ len(self.performance_metrics["latency"]),
|
||||
"p90_latency": sorted(self.performance_metrics["latency"])[
|
||||
int(0.9 * len(self.performance_metrics["latency"]))
|
||||
],
|
||||
"median_latency": sorted(self.performance_metrics["latency"])[
|
||||
len(self.performance_metrics["latency"]) // 2
|
||||
],
|
||||
"throughput": self.pbar.total / (self.finished_time - self.start_time),
|
||||
},
|
||||
}
|
||||
print("All requests completed")
|
||||
print("Performance metrics summary:")
|
||||
print(
|
||||
f" Total requests: {len(self.performance_metrics['ttft'])} at {self.request_rate} requests per second"
|
||||
f" Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second"
|
||||
)
|
||||
print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
|
||||
print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
|
||||
print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
|
||||
print(
|
||||
f" Average TTFT: {sum(self.performance_metrics['ttft']) / len(self.performance_metrics['ttft']):.2f}"
|
||||
f" Average latency: {performance_data['summary']['average_latency']:.2f}"
|
||||
)
|
||||
print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
|
||||
print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
|
||||
print(
|
||||
f" Median TTFT: {sorted(self.performance_metrics['ttft'])[len(self.performance_metrics['ttft']) // 2]:.2f}"
|
||||
f" Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
|
||||
)
|
||||
print(
|
||||
f" Average latency: {sum(self.performance_metrics['latency']) / len(self.performance_metrics['latency']):.2f}"
|
||||
)
|
||||
print(
|
||||
f" Median latency: {sorted(self.performance_metrics['latency'])[len(self.performance_metrics['latency']) // 2]:.2f}"
|
||||
)
|
||||
throughput = self.pbar.total / (self.finished_time - self.start_time)
|
||||
print(f"Throughput: {throughput:.2f} requests per second")
|
||||
log_to_jsonl_file(performance_data, args.log_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    args = parse_args()
    flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"

    # Run one full benchmark per request rate, from the highest rate down.
    for request_rate in [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]:
        args.request_rate = request_rate
        # POST to the server's /flush_cache endpoint before each run,
        # presumably so that every rate starts from an empty cache
        # (TODO confirm against the server implementation).
        requests.post(flush_cache_url)
        time.sleep(1)  # Brief pause before starting the next workload.
        WorkloadGenerator(args).run()
|
||||
|
||||
Reference in New Issue
Block a user