enhance latency test - part 2 (#915)
@@ -220,6 +220,68 @@ def correctness_test(
         rank_print(tokenizer.decode(output_ids[i]))


+@torch.inference_mode()
+def latency_test_run_once(
+    model_runner, rank_print, reqs, batch_size, input_len, output_len
+):
+
+    # Clear the pools.
+    model_runner.req_to_token_pool.clear()
+    model_runner.token_to_kv_pool.clear()
+
+    measurement_results = {
+        "run_name": "before",
+        "batch_size": batch_size,
+        "input_len": input_len,
+        "output_len": output_len,
+    }
+
+    tot_latency = 0
+
+    # Prefill
+    torch.cuda.synchronize()
+    tic = time.time()
+    next_token_ids, _, batch = extend(reqs, model_runner)
+    torch.cuda.synchronize()
+    prefill_latency = time.time() - tic
+    tot_latency += prefill_latency
+    throughput = input_len * batch_size / prefill_latency
+    rank_print(
+        f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["prefill_latency"] = prefill_latency
+    measurement_results["prefill_throughput"] = throughput
+
+    # Decode
+    for i in range(output_len):
+        torch.cuda.synchronize()
+        tic = time.time()
+        next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        torch.cuda.synchronize()
+        latency = time.time() - tic
+        tot_latency += latency
+        throughput = batch_size / latency
+        if i < 5:
+            rank_print(
+                f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+            )
+    avg_decode_latency = (tot_latency - prefill_latency) / output_len
+    avg_decode_throughput = batch_size / avg_decode_latency
+    rank_print(
+        f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+    )
+    measurement_results["avg_decode_latency"] = avg_decode_latency
+    measurement_results["avg_decode_throughput"] = avg_decode_throughput
+
+    throughput = (input_len + output_len) * batch_size / tot_latency
+    rank_print(
+        f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+    )
+    measurement_results["total_latency"] = tot_latency
+    measurement_results["total_throughput"] = throughput
+    return measurement_results
+
+
 def latency_test(
     server_args,
     bench_args,
@@ -241,72 +303,23 @@ def latency_test(
         bench_args.batch_size, bench_args.input_len
     )

-    def clear():
-        model_runner.req_to_token_pool.clear()
-        model_runner.token_to_kv_pool.clear()
-
-    @torch.inference_mode()
-    def run_once(output_len):
-        measurement_results = {
-            "batch_size": bench_args.batch_size,
-            "output_len": output_len,
-        }
-
-        # Prefill
-        torch.cuda.synchronize()
-        tot_latency = 0
-        tic = time.time()
-        next_token_ids, _, batch = extend(reqs, model_runner)
-        torch.cuda.synchronize()
-        prefill_latency = time.time() - tic
-        tot_latency += prefill_latency
-        throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
-        rank_print(
-            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
-        )
-        measurement_results["prefill_latency"] = prefill_latency
-        measurement_results["prefill_throughput"] = throughput
-
-        # Decode
-        for i in range(output_len):
-            torch.cuda.synchronize()
-            tic = time.time()
-            next_token_ids, _ = decode(next_token_ids, batch, model_runner)
-            torch.cuda.synchronize()
-            latency = time.time() - tic
-            tot_latency += latency
-            throughput = bench_args.batch_size / latency
-            if i < 5:
-                rank_print(
-                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
-                )
-        avg_decode_latency = (tot_latency - prefill_latency) / output_len
-        avg_decode_throughput = bench_args.batch_size / avg_decode_latency
-        rank_print(
-            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
-        )
-        measurement_results["avg_decode_latency"] = avg_decode_latency
-        measurement_results["avg_decode_throughput"] = avg_decode_throughput
-
-        throughput = (
-            (bench_args.input_len + bench_args.output_len)
-            * bench_args.batch_size
-            / tot_latency
-        )
-        rank_print(
-            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
-        )
-        measurement_results["total_latency"] = tot_latency
-        measurement_results["total_throughput"] = throughput
-        return measurement_results
-
     # Warm up
-    run_once(4)
-    clear()
+    latency_test_run_once(
+        model_runner, rank_print, reqs, bench_args.batch_size, bench_args.input_len, 4
+    )

     # Run again
     result_list = []
-    result_list.append(run_once(bench_args.output_len))
+    result_list.append(
+        latency_test_run_once(
+            model_runner,
+            rank_print,
+            reqs,
+            bench_args.batch_size,
+            bench_args.input_len,
+            bench_args.output_len,
+        )
+    )

     # Write results in jsonlines format.
     if bench_args.result_filename:
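Note on the timing pattern in latency_test_run_once: CUDA kernels launch asynchronously, so time.time() alone would mostly measure kernel-launch overhead rather than execution time. That is why every timed region above is bracketed by torch.cuda.synchronize(). Below is a minimal, self-contained sketch of the same pattern; the helper name time_gpu_op and the matmul workload are illustrative, not part of this PR.

import time

import torch


def time_gpu_op(op, warmup: int = 3, iters: int = 10) -> float:
    """Return the average wall-clock latency of `op` in seconds."""
    for _ in range(warmup):
        op()
    torch.cuda.synchronize()  # drain pending kernels before starting the clock
    tic = time.time()
    for _ in range(iters):
        op()
    torch.cuda.synchronize()  # wait for the timed kernels to actually finish
    return (time.time() - tic) / iters


if __name__ == "__main__":
    x = torch.randn(4096, 4096, device="cuda")
    print(f"matmul: {time_gpu_op(lambda: x @ x):6.5f} s")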
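Each call to latency_test_run_once returns its measurement_results dict, and latency_test collects them in result_list. The hunk is truncated right after `if bench_args.result_filename:`, so the actual write is not shown; the sketch below is only an assumption about what a jsonlines writer for these dicts typically looks like (one JSON object per line), not the PR's code.

import json


def write_results(result_list, result_filename):
    # Hypothetical helper: append one JSON object per line so that
    # repeated benchmark runs accumulate in the same file.
    with open(result_filename, "a") as f:
        for result in result_list:
            f.write(json.dumps(result) + "\n")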