sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
517
benchmark/hicache/bench_multiturn.py
Normal file
517
benchmark/hicache/bench_multiturn.py
Normal file
@@ -0,0 +1,517 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import queue
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import requests
|
||||
from tqdm.asyncio import tqdm
|
||||
|
||||
from sglang.bench_serving import (
|
||||
RequestFuncOutput,
|
||||
get_tokenizer,
|
||||
remove_prefix,
|
||||
sample_random_requests,
|
||||
)
|
||||
|
||||
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=20 * 60 * 60)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Script to benchmark concurrent requests to a server."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-clients",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Number of concurrent clients",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-parallel",
|
||||
type=int,
|
||||
default=128,
|
||||
help="Maximum number of parallel requests",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-length",
|
||||
type=int,
|
||||
default=512,
|
||||
help="Length of each new request",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-length",
|
||||
type=int,
|
||||
default=64,
|
||||
help="Length of each output",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-rounds",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Number of rounds per client",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--distribution",
|
||||
type=str,
|
||||
default="poisson",
|
||||
choices=["poisson", "uniform"],
|
||||
help="Distribution type for request intervals (poisson or uniform)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--request-rate",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Average number of requests per second",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="localhost",
|
||||
help="Server hostname or IP (default: localhost)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=30000,
|
||||
help="Server port (default: 30000)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-path",
|
||||
type=str,
|
||||
default="meta-llama/Llama-3.1-8B-Instruct",
|
||||
help="model path compatible with Hugging Face Transformers",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="local dataset to sample tokens from",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-file",
|
||||
type=str,
|
||||
default="performance_metrics.jsonl",
|
||||
help="File to log performance metrics",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-auto-run",
|
||||
action="store_true",
|
||||
help="If set, disable automatically testing with a range of request rates.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable-random-sample",
|
||||
action="store_true",
|
||||
help="If set, disable random sampling of requests from the ShareGPT dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sub-question-input-length",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Length of the sub question input for each request, if set 0 use request_length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ready-queue-policy",
|
||||
type=str,
|
||||
default="random",
|
||||
help="Policy for popping requests from the ready queue (random or fifo)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tag",
|
||||
type=str,
|
||||
default="",
|
||||
help="Tag of a certain run in the log file",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
||||
parser.add_argument(
|
||||
"--lora-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="String of LoRA path. Currently we only support benchmarking on a single LoRA adaptor.",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
async def async_request_sglang_generate(
|
||||
payload,
|
||||
url,
|
||||
pbar: Optional[tqdm] = None,
|
||||
):
|
||||
"""
|
||||
Sends a streaming request to the server. Gathers text token-by-token.
|
||||
"""
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
headers = {}
|
||||
generated_text = ""
|
||||
ttft = 0.0
|
||||
st = time.perf_counter()
|
||||
most_recent_timestamp = st
|
||||
output = RequestFuncOutput()
|
||||
|
||||
try:
|
||||
async with session.post(url=url, json=payload, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
prompt_tokens = 0
|
||||
cached_tokens = 0
|
||||
async for chunk_bytes in response.content:
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
|
||||
chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
|
||||
latency = time.perf_counter() - st
|
||||
if chunk == "[DONE]":
|
||||
pass
|
||||
else:
|
||||
data = json.loads(chunk)
|
||||
|
||||
if data["text"]:
|
||||
timestamp = time.perf_counter()
|
||||
# First token
|
||||
if ttft == 0.0:
|
||||
ttft = time.perf_counter() - st
|
||||
output.ttft = ttft
|
||||
prompt_tokens = (data.get("meta_info") or {}).get(
|
||||
"prompt_tokens", 0
|
||||
)
|
||||
cached_tokens = (data.get("meta_info") or {}).get(
|
||||
"cached_tokens", 0
|
||||
)
|
||||
|
||||
# Decoding phase
|
||||
else:
|
||||
output.itl.append(timestamp - most_recent_timestamp)
|
||||
|
||||
most_recent_timestamp = timestamp
|
||||
generated_text = data["text"]
|
||||
|
||||
output.generated_text = generated_text
|
||||
output.success = True
|
||||
output.latency = latency
|
||||
output.prompt_len = prompt_tokens
|
||||
output.cached_tokens = cached_tokens
|
||||
output.generated_len = len(output.itl) + 1
|
||||
else:
|
||||
output.error = response.reason or ""
|
||||
output.success = False
|
||||
except Exception as e:
|
||||
output.success = False
|
||||
output.error = str(e)
|
||||
print(f"Request failed: {e}")
|
||||
|
||||
if pbar:
|
||||
pbar.update(1)
|
||||
return output
|
||||
|
||||
|
||||
def gen_payload(prompt, output_len, lora_path=""):
|
||||
payload = {
|
||||
"text": prompt,
|
||||
"sampling_params": {
|
||||
"temperature": 0.0,
|
||||
"max_new_tokens": output_len,
|
||||
"ignore_eos": True,
|
||||
},
|
||||
"stream": True,
|
||||
"stream_options": {"include_usage": True},
|
||||
"lora_path": lora_path,
|
||||
"return_logprob": False,
|
||||
"logprob_start_len": -1,
|
||||
}
|
||||
return payload
|
||||
|
||||
|
||||
def log_to_jsonl_file(data, file_path="performance_metrics.jsonl", tag=""):
|
||||
"""Append the data with a timestamp and tag to the specified JSONL file."""
|
||||
timestamped_data = {"timestamp": datetime.now().isoformat(), "tag": tag, **data}
|
||||
try:
|
||||
with open(file_path, "a") as file:
|
||||
file.write(
|
||||
json.dumps(timestamped_data) + "\n"
|
||||
) # Write as a single line in JSONL format
|
||||
except IOError as e:
|
||||
print(f"Error writing to JSONL file: {e}")
|
||||
|
||||
|
||||
class ReadyQueue:
|
||||
"""
|
||||
Thread-safe queue that can pop requests in different orders based on given policy.
|
||||
"""
|
||||
|
||||
def __init__(self, init_requests=None, policy="random"):
|
||||
self.lock = threading.Lock()
|
||||
self.requests = init_requests or []
|
||||
self.policy = policy
|
||||
|
||||
def append(self, item):
|
||||
with self.lock:
|
||||
self.requests.append(item)
|
||||
|
||||
def pop(self):
|
||||
with self.lock:
|
||||
if not self.requests:
|
||||
return None
|
||||
if self.policy == "random":
|
||||
index = random.randrange(len(self.requests))
|
||||
return self.requests.pop(index)
|
||||
elif self.policy == "fifo":
|
||||
return self.requests.pop(0)
|
||||
else:
|
||||
# todo, varying thinking time of clients
|
||||
raise ValueError(f"{self.policy} not implemented")
|
||||
|
||||
|
||||
class WorkloadGenerator:
|
||||
def __init__(self, args):
|
||||
# Construct the base URL for requests
|
||||
self.url = f"http://{args.host}:{args.port}/generate"
|
||||
|
||||
self.tokenizer = get_tokenizer(args.model_path)
|
||||
self.distribution = args.distribution
|
||||
self.request_rate = args.request_rate
|
||||
self.start_time = None
|
||||
self.finished_time = None
|
||||
|
||||
self.sent_requests = 0
|
||||
self.completed_requests = 0
|
||||
|
||||
self.candidate_inputs = sample_random_requests(
|
||||
input_len=args.request_length,
|
||||
output_len=args.output_length,
|
||||
num_prompts=args.num_clients,
|
||||
range_ratio=1.0,
|
||||
tokenizer=self.tokenizer,
|
||||
dataset_path=args.dataset_path,
|
||||
random_sample=not args.disable_random_sample,
|
||||
)
|
||||
self.candidate_inputs = [i.prompt for i in self.candidate_inputs]
|
||||
|
||||
if args.sub_question_input_length != 0:
|
||||
sub_question_input_length = args.sub_question_input_length
|
||||
else:
|
||||
sub_question_input_length = args.request_length
|
||||
|
||||
self.sub_question_inputs = sample_random_requests(
|
||||
input_len=sub_question_input_length,
|
||||
output_len=args.output_length,
|
||||
num_prompts=args.num_clients * max(args.num_rounds - 1, 1),
|
||||
range_ratio=1.0,
|
||||
tokenizer=self.tokenizer,
|
||||
dataset_path=args.dataset_path,
|
||||
random_sample=not args.disable_random_sample,
|
||||
)
|
||||
|
||||
init_requests = [
|
||||
(
|
||||
i,
|
||||
gen_payload(
|
||||
self.candidate_inputs[i], args.output_length, args.lora_path
|
||||
),
|
||||
)
|
||||
for i in range(args.num_clients)
|
||||
]
|
||||
self.client_records = {
|
||||
i: {"round": 0, "history": init_requests[i][1]["text"]}
|
||||
for i in range(args.num_clients)
|
||||
}
|
||||
self.ready_queue = ReadyQueue(
|
||||
init_requests=init_requests, policy=args.ready_queue_policy
|
||||
)
|
||||
self.candidate_inputs = self.candidate_inputs[args.num_clients :]
|
||||
|
||||
self.response_queue = queue.Queue()
|
||||
self.pbar = tqdm(total=args.num_clients * args.num_rounds)
|
||||
self.performance_metrics = {
|
||||
"ttft": [],
|
||||
"latency": [],
|
||||
"prompt_len": [],
|
||||
"cached_tokens": [],
|
||||
"generated_len": [],
|
||||
}
|
||||
self.num_rounds = args.num_rounds
|
||||
self.max_parallel = args.max_parallel
|
||||
self.output_length = args.output_length
|
||||
|
||||
async def handle_request(self, item):
|
||||
try:
|
||||
client_id, payload = item
|
||||
response = await async_request_sglang_generate(payload, self.url, self.pbar)
|
||||
if self.pbar.n == self.pbar.total:
|
||||
self.finished_time = time.perf_counter()
|
||||
self.response_queue.put((client_id, response))
|
||||
except Exception as e:
|
||||
print(f"Request failed: {e}")
|
||||
|
||||
def request_sender(self):
|
||||
async def request_loop():
|
||||
while True:
|
||||
if self.sent_requests - self.completed_requests < self.max_parallel:
|
||||
new_request = self.ready_queue.pop()
|
||||
if new_request:
|
||||
asyncio.create_task(self.handle_request(new_request))
|
||||
self.sent_requests += 1
|
||||
else:
|
||||
await asyncio.sleep(0.05)
|
||||
continue
|
||||
|
||||
if self.pbar.n == self.pbar.total:
|
||||
break
|
||||
|
||||
# Calculate Poisson-distributed wait time
|
||||
if self.distribution == "poisson":
|
||||
sleep_time = random.expovariate(self.request_rate)
|
||||
elif self.distribution == "uniform":
|
||||
avg_interval = (
|
||||
1.0 / self.request_rate if self.request_rate > 0 else 1.0
|
||||
)
|
||||
sleep_time = random.uniform(0, 2 * avg_interval)
|
||||
else:
|
||||
raise ValueError("Invalid distribution type")
|
||||
await asyncio.sleep(sleep_time) # Wait before sending the next request
|
||||
|
||||
# Create and run the event loop for asynchronous requests
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(request_loop())
|
||||
loop.close()
|
||||
|
||||
def response_handler(self):
|
||||
while True:
|
||||
try:
|
||||
client_id, response = self.response_queue.get(
|
||||
timeout=10
|
||||
) # Block until response is available
|
||||
if not response.success:
|
||||
raise ValueError(f"Request failed with error: {response.error}")
|
||||
self.client_records[client_id]["history"] += response.generated_text
|
||||
self.client_records[client_id]["round"] += 1
|
||||
self.performance_metrics["ttft"].append(response.ttft)
|
||||
self.performance_metrics["latency"].append(response.latency)
|
||||
self.performance_metrics["prompt_len"].append(response.prompt_len)
|
||||
self.performance_metrics["cached_tokens"].append(response.cached_tokens)
|
||||
self.performance_metrics["generated_len"].append(response.generated_len)
|
||||
self.completed_requests += 1
|
||||
|
||||
if self.client_records[client_id]["round"] < self.num_rounds:
|
||||
# append new request to client's history
|
||||
self.client_records[client_id][
|
||||
"history"
|
||||
] += self.sub_question_inputs.pop().prompt
|
||||
self.ready_queue.append(
|
||||
(
|
||||
client_id,
|
||||
gen_payload(
|
||||
self.client_records[client_id]["history"],
|
||||
self.output_length,
|
||||
args.lora_path,
|
||||
),
|
||||
)
|
||||
)
|
||||
except queue.Empty:
|
||||
if self.pbar.n == self.pbar.total:
|
||||
break
|
||||
except ValueError as e:
|
||||
print(f"Error processing response for client {client_id}: {e}")
|
||||
continue
|
||||
|
||||
def run(self):
|
||||
request_thread = threading.Thread(target=self.request_sender, daemon=True)
|
||||
response_thread = threading.Thread(target=self.response_handler, daemon=True)
|
||||
|
||||
self.start_time = time.perf_counter()
|
||||
request_thread.start()
|
||||
response_thread.start()
|
||||
|
||||
request_thread.join()
|
||||
response_thread.join()
|
||||
self.pbar.close()
|
||||
|
||||
duration = self.finished_time - self.start_time
|
||||
performance_data = {
|
||||
"summary": {
|
||||
"total_requests": len(self.performance_metrics["ttft"]),
|
||||
"request_rate": self.request_rate,
|
||||
"average_ttft": sum(self.performance_metrics["ttft"])
|
||||
/ len(self.performance_metrics["ttft"]),
|
||||
"p90_ttft": sorted(self.performance_metrics["ttft"])[
|
||||
int(0.9 * len(self.performance_metrics["ttft"]))
|
||||
],
|
||||
"median_ttft": sorted(self.performance_metrics["ttft"])[
|
||||
len(self.performance_metrics["ttft"]) // 2
|
||||
],
|
||||
"average_latency": sum(self.performance_metrics["latency"])
|
||||
/ len(self.performance_metrics["latency"]),
|
||||
"p90_latency": sorted(self.performance_metrics["latency"])[
|
||||
int(0.9 * len(self.performance_metrics["latency"]))
|
||||
],
|
||||
"median_latency": sorted(self.performance_metrics["latency"])[
|
||||
len(self.performance_metrics["latency"]) // 2
|
||||
],
|
||||
"input_token_throughput": sum(self.performance_metrics["prompt_len"])
|
||||
/ duration,
|
||||
"output_token_throughput": sum(
|
||||
self.performance_metrics["generated_len"]
|
||||
)
|
||||
/ duration,
|
||||
"throughput": self.pbar.total / duration,
|
||||
"cache_hit_rate": (
|
||||
0
|
||||
if sum(self.performance_metrics["prompt_len"]) == 0
|
||||
else sum(self.performance_metrics["cached_tokens"])
|
||||
/ sum(self.performance_metrics["prompt_len"])
|
||||
),
|
||||
},
|
||||
}
|
||||
print("All requests completed")
|
||||
print("Performance metrics summary:")
|
||||
print(
|
||||
f" Total requests: {performance_data['summary']['total_requests']} at {performance_data['summary']['request_rate']} requests per second"
|
||||
)
|
||||
print(f" Average TTFT: {performance_data['summary']['average_ttft']:.2f}")
|
||||
print(f" P90 TTFT: {performance_data['summary']['p90_ttft']:.2f}")
|
||||
print(f" Median TTFT: {performance_data['summary']['median_ttft']:.2f}")
|
||||
print(
|
||||
f" Average latency: {performance_data['summary']['average_latency']:.2f}"
|
||||
)
|
||||
print(f" P90 latency: {performance_data['summary']['p90_latency']:.2f}")
|
||||
print(f" Median latency: {performance_data['summary']['median_latency']:.2f}")
|
||||
print(
|
||||
f" Input token throughput: {performance_data['summary']['input_token_throughput']:.2f} tokens per second"
|
||||
)
|
||||
print(
|
||||
f" Output token throughput: {performance_data['summary']['output_token_throughput']:.2f} tokens per second"
|
||||
)
|
||||
print(
|
||||
f" Request Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
|
||||
)
|
||||
print(f" Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
|
||||
return performance_data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
if args.disable_auto_run:
|
||||
print("Running with specified request rate...")
|
||||
request_rates = [args.request_rate]
|
||||
else:
|
||||
print("Auto-running with different request rates...")
|
||||
request_rates = [16, 14, 12, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1]
|
||||
|
||||
for rate in request_rates:
|
||||
args.request_rate = rate
|
||||
requests.post(flush_cache_url)
|
||||
time.sleep(1)
|
||||
performance_data = WorkloadGenerator(args).run()
|
||||
log_to_jsonl_file(performance_data, args.log_file, tag=args.tag)
|
||||
Reference in New Issue
Block a user