Multi-turn benchmark for hierarchical caching (#2942)
334	benchmark/hicache/bench_multiturn.py	Normal file
@@ -0,0 +1,334 @@
import argparse
import asyncio
import json
import queue
import random
import threading
import time
from typing import Optional

import aiohttp
import requests
from tqdm.asyncio import tqdm

from sglang.bench_serving import (
    RequestFuncOutput,
    get_tokenizer,
    remove_prefix,
    sample_random_requests,
)

def parse_args():
    parser = argparse.ArgumentParser(
        description="Script to benchmark concurrent requests to a server."
    )
    parser.add_argument(
        "--num-clients",
        type=int,
        default=200,
        help="Number of concurrent clients",
    )
    parser.add_argument(
        "--request-length",
        type=int,
        default=512,
        help="Length of each new request",
    )
    parser.add_argument(
        "--output-length",
        type=int,
        default=64,
        help="Length of each output",
    )
    parser.add_argument(
        "--num-rounds",
        type=int,
        default=5,
        help="Number of rounds per client",
    )
    parser.add_argument(
        "--distribution",
        type=str,
        default="poisson",
        choices=["poisson", "uniform"],
        help="Distribution type for request intervals (poisson or uniform)",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=1.0,
        help="Average number of requests per second",
    )
    parser.add_argument(
        "--host",
        type=str,
        default="localhost",
        help="Server hostname or IP (default: localhost)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=30000,
        help="Server port (default: 30000)",
    )
    parser.add_argument(
        "--model",
        type=str,
        default="meta-llama/Llama-3.1-8B-Instruct",
        help="Model path compatible with Hugging Face Transformers",
    )
    return parser.parse_args()

async def async_request_sglang_generate(
    payload,
    url,
    pbar: Optional[tqdm] = None,
):
    """
    Sends a streaming request to the server. Gathers text token-by-token.
    """
    async with aiohttp.ClientSession() as session:
        headers = {}
        generated_text = ""
        ttft = 0.0
        latency = 0.0  # guard against an empty stream leaving this unset
        st = time.perf_counter()
        most_recent_timestamp = st
        output = RequestFuncOutput()

        try:
            async with session.post(url=url, json=payload, headers=headers) as response:
                if response.status == 200:
                    async for chunk_bytes in response.content:
                        chunk_bytes = chunk_bytes.strip()
                        if not chunk_bytes:
                            continue

                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
                        latency = time.perf_counter() - st
                        if chunk == "[DONE]":
                            pass
                        else:
                            data = json.loads(chunk)

                            if data["text"]:
                                timestamp = time.perf_counter()
                                # First token: record time to first token (TTFT)
                                if ttft == 0.0:
                                    ttft = timestamp - st
                                    output.ttft = ttft

                                # Decoding phase: record inter-token latency
                                else:
                                    output.itl.append(timestamp - most_recent_timestamp)

                                most_recent_timestamp = timestamp
                                generated_text = data["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.error = response.reason or ""
                    output.success = False
        except Exception as e:
            output.success = False
            output.error = str(e)
            print(f"Request failed: {e}")

    if pbar:
        pbar.update(1)
    return output

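# A note on the stream format consumed above: the /generate endpoint emits
# server-sent events, one JSON object per chunk, ending with a [DONE] marker,
# roughly:
#   data: {"text": "...cumulative text so far...", ...}
#   data: [DONE]
# The "text" field is cumulative, which is why the handler overwrites
# generated_text rather than concatenating. Fields other than "text" are
# assumptions here and may vary across sglang versions.
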
def gen_payload(prompt, output_len):
    payload = {
        "text": prompt,
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": output_len,
            "ignore_eos": True,
        },
        "stream": True,
        "lora_path": "",
        "return_logprob": False,
        "logprob_start_len": -1,
    }
    return payload

class ReadyQueue:
    """
    Thread-safe queue that pops requests in different orders based on the given policy.
    """

    def __init__(self, init_requests=None, policy="random"):
        self.lock = threading.Lock()
        self.requests = init_requests or []
        self.policy = policy

    def append(self, item):
        with self.lock:
            self.requests.append(item)

    def pop(self):
        with self.lock:
            if not self.requests:
                return None
            if self.policy == "random":
                index = random.randrange(len(self.requests))
                return self.requests.pop(index)
            elif self.policy == "fifo":
                return self.requests.pop(0)
            else:
                # TODO: policies that explicitly model varying client think time
                raise ValueError(f"{self.policy} not implemented")

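# Usage sketch for ReadyQueue (hypothetical values; WorkloadGenerator below is
# the real consumer): seed it with (client_id, payload) tuples, then pop.
#
#   rq = ReadyQueue(init_requests=[(0, gen_payload("Hello", 64))], policy="fifo")
#   item = rq.pop()  # yields (client_id, payload) tuples until empty, then None
#
# The default "random" policy pops in random order, loosely emulating clients
# whose think time between turns varies.
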
class WorkloadGenerator:
    """
    Simulates num_clients concurrent chat clients, each holding a multi-turn
    conversation: every follow-up request replays the client's full history
    plus a new prompt, exercising (hierarchical) prefix caching.
    """

    def __init__(self, args):
        # Construct the base URL for requests
        self.url = f"http://{args.host}:{args.port}/generate"

        self.tokenizer = get_tokenizer(args.model)
        self.distribution = args.distribution
        self.request_rate = args.request_rate
        self.num_rounds = args.num_rounds
        self.output_length = args.output_length
        self.start_time = None
        self.finished_time = None

        self.candidate_inputs = sample_random_requests(
            input_len=args.request_length,
            output_len=args.output_length,
            num_prompts=args.num_clients * args.num_rounds,
            range_ratio=1.0,
            tokenizer=self.tokenizer,
            dataset_path="",
        )
        # Keep only the prompt text from each sampled request
        self.candidate_inputs = [i[0] for i in self.candidate_inputs]

        # Seed each client with its first-round payload
        init_requests = [
            (i, gen_payload(self.candidate_inputs[i], args.output_length))
            for i in range(args.num_clients)
        ]
        self.client_records = {
            i: {"round": 0, "history": init_requests[i][1]["text"]}
            for i in range(args.num_clients)
        }
        self.ready_queue = ReadyQueue(init_requests=init_requests)
        self.candidate_inputs = self.candidate_inputs[args.num_clients :]

        self.response_queue = queue.Queue()
        self.pbar = tqdm(total=args.num_clients * args.num_rounds)
        self.performance_metrics = {"ttft": [], "latency": []}
    async def handle_request(self, item):
        try:
            client_id, payload = item
            response = await async_request_sglang_generate(payload, self.url, self.pbar)
            # Record completion time once the last request has finished
            if self.pbar.n == self.pbar.total:
                self.finished_time = time.time()
            self.response_queue.put((client_id, response))
        except Exception as e:
            print(f"Request failed: {e}")
    def request_sender(self):
        async def request_loop():
            while True:
                # Draw the wait time from the configured arrival distribution
                if self.distribution == "poisson":
                    sleep_time = random.expovariate(self.request_rate)
                elif self.distribution == "uniform":
                    avg_interval = (
                        1.0 / self.request_rate if self.request_rate > 0 else 1.0
                    )
                    sleep_time = random.uniform(0, 2 * avg_interval)
                else:
                    raise ValueError("Invalid distribution type")
                await asyncio.sleep(sleep_time)  # Wait before sending the next request

                new_request = self.ready_queue.pop()
                # Submit async request
                if new_request:
                    asyncio.create_task(self.handle_request(new_request))
                else:
                    if self.pbar.n == self.pbar.total:
                        break

        # Create and run the event loop for asynchronous requests
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(request_loop())
        loop.close()
    def response_handler(self):
        while True:
            try:
                client_id, response = self.response_queue.get(
                    timeout=10
                )  # Block until a response is available
                if not response.success:
                    raise ValueError(f"Request failed with error: {response.error}")
                self.client_records[client_id]["history"] += response.generated_text
                self.client_records[client_id]["round"] += 1
                self.performance_metrics["ttft"].append(response.ttft)
                self.performance_metrics["latency"].append(response.latency)

                # If the client has rounds left, extend its history with a new
                # prompt and queue the follow-up request
                if self.client_records[client_id]["round"] < self.num_rounds:
                    self.client_records[client_id][
                        "history"
                    ] += self.candidate_inputs.pop()
                    self.ready_queue.append(
                        (
                            client_id,
                            gen_payload(
                                self.client_records[client_id]["history"],
                                self.output_length,
                            ),
                        )
                    )
            except queue.Empty:
                if self.pbar.n == self.pbar.total:
                    break
    def run(self):
        request_thread = threading.Thread(target=self.request_sender, daemon=True)
        response_thread = threading.Thread(target=self.response_handler, daemon=True)

        self.start_time = time.time()
        request_thread.start()
        response_thread.start()

        request_thread.join()
        response_thread.join()

        self.pbar.close()

        ttfts = self.performance_metrics["ttft"]
        latencies = self.performance_metrics["latency"]
        print("All requests completed.")
        print("Performance metrics summary:")
        print(
            f"  Total requests: {len(ttfts)} at {self.request_rate} requests per second"
        )
        print(f"  Average TTFT: {sum(ttfts) / len(ttfts):.2f} s")
        print(f"  Median TTFT: {sorted(ttfts)[len(ttfts) // 2]:.2f} s")
        print(f"  Average latency: {sum(latencies) / len(latencies):.2f} s")
        print(f"  Median latency: {sorted(latencies)[len(latencies) // 2]:.2f} s")
        throughput = self.pbar.total / (self.finished_time - self.start_time)
        print(f"Throughput: {throughput:.2f} requests per second")

if __name__ == "__main__":
    args = parse_args()
    flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"

    # Sweep request rates 1, 3, ..., 39 requests/s, flushing the server-side
    # cache between runs so each rate starts from a cold cache
    for request_rate in range(1, 41, 2):
        args.request_rate = request_rate
        requests.post(flush_cache_url)
        WorkloadGenerator(args).run()
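Note: as committed, the entry point sweeps request rates from 1 to 39 and
overwrites the --request-rate flag on each iteration. A minimal single-rate
variant, sketched here using only names defined in the file above, could be:

    if __name__ == "__main__":
        args = parse_args()
        # Flush the server-side cache once so the run starts cold
        requests.post(f"http://{args.host}:{args.port}/flush_cache")
        WorkloadGenerator(args).run()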