fix(hicache-long-bench): adjust context workload generator to use full query set (#9847)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
@@ -31,9 +31,10 @@ class ContextWorkloadGenerator(WorkloadGenerator):
         self.completed_requests = 0
 
         self.dataset = json.load(open(args.dataset_path))
+        num_requests = min(args.num_clients, len(self.dataset["queries"]))
 
         init_requests = []
-        for i in range(min(args.num_clients, len(self.dataset["queries"]))):
+        for i in range(num_requests):
             context_id = self.dataset["queries"][i]["context"]
             init_requests.append(
                 (
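This hunk hoists the request count into `num_requests`, so one value both caps the loop at the dataset size and (in the next hunk) drives the progress-bar total. A minimal sketch of the capping logic outside the benchmark class, assuming a dataset JSON shaped like {"queries": [{"context": ...}, ...]} (the helper name here is illustrative, not from the source):

    import json

    def build_init_requests(dataset_path: str, num_clients: int):
        with open(dataset_path) as f:
            dataset = json.load(f)
        # Never schedule more requests than the dataset can supply.
        num_requests = min(num_clients, len(dataset["queries"]))
        return [(i, dataset["queries"][i]["context"]) for i in range(num_requests)]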
@@ -52,13 +53,14 @@ class ContextWorkloadGenerator(WorkloadGenerator):
         self.ready_queue = ReadyQueue(init_requests=init_requests)
 
         self.response_queue = queue.Queue()
-        self.pbar = tqdm(total=args.num_clients * args.num_rounds)
+        self.pbar = tqdm(total=num_requests)
         self.performance_metrics = {
             "ttft": [],
             "latency": [],
             "itl": [],
             "prompt_len": [],
             "cached_tokens": [],
+            "generated_len": [],
         }
 
         self.max_parallel = args.max_parallel
@@ -75,6 +77,9 @@ class ContextWorkloadGenerator(WorkloadGenerator):
                 self.performance_metrics["ttft"].append(response.ttft)
                 self.performance_metrics["itl"].extend(response.itl)
                 self.performance_metrics["latency"].append(response.latency)
+                self.performance_metrics["prompt_len"].append(response.prompt_len)
+                self.performance_metrics["cached_tokens"].append(response.cached_tokens)
+                self.performance_metrics["generated_len"].append(response.generated_len)
                 self.completed_requests += 1
 
             except queue.Empty:
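This hunk assumes each response record carries the three new fields. A hedged sketch of that record as a dataclass; the real class lives elsewhere in the benchmark code and may differ in detail:

    from dataclasses import dataclass, field
    from typing import List

    @dataclass
    class ResponseRecord:
        ttft: float = 0.0      # time to first token (s)
        latency: float = 0.0   # end-to-end request latency (s)
        itl: List[float] = field(default_factory=list)  # inter-token latencies (s)
        prompt_len: int = 0    # prompt tokens reported by the server
        cached_tokens: int = 0 # prompt tokens served from cache
        generated_len: int = 0 # completion tokens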
@@ -85,7 +90,7 @@ class ContextWorkloadGenerator(WorkloadGenerator):
 if __name__ == "__main__":
     args = parse_args()
     args.num_rounds = 1
-    args.max_parallel = 128
+    args.max_parallel = 24
     flush_cache_url = f"http://{args.host}:{args.port}/flush_cache"
 
     for request_rate in [24, 16, 12, 8, 4, 2, 1]:
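Lowering `max_parallel` to 24 matches the highest request rate in the sweep, so concurrency never exceeds the offered load. A sketch of the sweep driver this main block implies, assuming the server's /flush_cache endpoint accepts a plain POST (as the URL above suggests); `run_once` is a hypothetical stand-in for building and running the generator at one rate:

    import requests

    def sweep(args, flush_cache_url: str, run_once):
        for request_rate in [24, 16, 12, 8, 4, 2, 1]:
            requests.post(flush_cache_url)  # start each rate from a cold cache
            args.request_rate = request_rate
            run_once(args)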
@@ -191,6 +191,7 @@ async def async_request_sglang_generate(
             output.latency = latency
             output.prompt_len = prompt_tokens
             output.cached_tokens = cached_tokens
+            output.generated_len = len(output.itl) + 1
         else:
             output.error = response.reason or ""
             output.success = False
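`generated_len` is derived rather than reported: when streaming, each inter-token latency entry marks one token after the first, so the completion length is `len(itl) + 1`. A small self-check of that identity:

    # For a stream of n tokens there are n - 1 gaps between consecutive
    # tokens, so tokens = gaps + 1.
    def generated_len_from_itl(itl: list) -> int:
        return len(itl) + 1  # valid once at least one token has arrived

    assert generated_len_from_itl([]) == 1            # single-token completion
    assert generated_len_from_itl([0.02, 0.03]) == 3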
@@ -321,6 +322,7 @@ class WorkloadGenerator:
             "latency": [],
             "prompt_len": [],
             "cached_tokens": [],
+            "generated_len": [],
         }
         self.num_rounds = args.num_rounds
         self.max_parallel = args.max_parallel
@@ -383,6 +385,7 @@ class WorkloadGenerator:
                 self.performance_metrics["latency"].append(response.latency)
                 self.performance_metrics["prompt_len"].append(response.prompt_len)
                 self.performance_metrics["cached_tokens"].append(response.cached_tokens)
+                self.performance_metrics["generated_len"].append(response.generated_len)
                 self.completed_requests += 1
 
                 if self.client_records[client_id]["round"] < self.num_rounds:
@@ -418,6 +421,7 @@ class WorkloadGenerator:
         response_thread.join()
         self.pbar.close()
 
+        duration = self.finished_time - self.start_time
         performance_data = {
             "summary": {
                 "total_requests": len(self.performance_metrics["ttft"]),
@@ -438,7 +442,13 @@ class WorkloadGenerator:
                 "median_latency": sorted(self.performance_metrics["latency"])[
                     len(self.performance_metrics["latency"]) // 2
                 ],
-                "throughput": self.pbar.total / (self.finished_time - self.start_time),
+                "input_token_throughput": sum(self.performance_metrics["prompt_len"])
+                / duration,
+                "output_token_throughput": sum(
+                    self.performance_metrics["generated_len"]
+                )
+                / duration,
+                "throughput": self.pbar.total / duration,
                 "cache_hit_rate": (
                     0
                     if sum(self.performance_metrics["prompt_len"]) == 0
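Hoisting `duration` lets the three throughput figures share one denominator. A standalone sketch of the summary arithmetic over the same metric lists; the cache-hit formula is my reading of the surrounding (truncated) expression, namely cached prompt tokens over total prompt tokens:

    def summarize(metrics: dict, total_requests: int, duration: float) -> dict:
        prompt_tokens = sum(metrics["prompt_len"])
        return {
            "input_token_throughput": prompt_tokens / duration,
            "output_token_throughput": sum(metrics["generated_len"]) / duration,
            "throughput": total_requests / duration,  # requests per second
            # Assumed: fraction of prompt tokens served from cache,
            # defined as 0 when no prompt tokens were recorded.
            "cache_hit_rate": (
                0 if prompt_tokens == 0
                else sum(metrics["cached_tokens"]) / prompt_tokens
            ),
        }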
@@ -461,7 +471,13 @@ class WorkloadGenerator:
         print(f"  P90 latency: {performance_data['summary']['p90_latency']:.2f}")
         print(f"  Median latency: {performance_data['summary']['median_latency']:.2f}")
         print(
-            f"  Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
+            f"  Input token throughput: {performance_data['summary']['input_token_throughput']:.2f} tokens per second"
+        )
+        print(
+            f"  Output token throughput: {performance_data['summary']['output_token_throughput']:.2f} tokens per second"
+        )
+        print(
+            f"  Request Throughput: {performance_data['summary']['throughput']:.2f} requests per second"
         )
         print(f"  Cache Hit Rate: {performance_data['summary']['cache_hit_rate']:.6f}")
         return performance_data