enhance latency test - part 2 (#915)
This commit is contained in:
@@ -220,6 +220,68 @@ def correctness_test(
|
||||
rank_print(tokenizer.decode(output_ids[i]))
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def latency_test_run_once(
|
||||
model_runner, rank_print, reqs, batch_size, input_len, output_len
|
||||
):
|
||||
|
||||
# Clear the pools.
|
||||
model_runner.req_to_token_pool.clear()
|
||||
model_runner.token_to_kv_pool.clear()
|
||||
|
||||
measurement_results = {
|
||||
"run_name": "before",
|
||||
"batch_size": batch_size,
|
||||
"input_len": input_len,
|
||||
"output_len": output_len,
|
||||
}
|
||||
|
||||
tot_latency = 0
|
||||
|
||||
# Prefill
|
||||
torch.cuda.synchronize()
|
||||
tic = time.time()
|
||||
next_token_ids, _, batch = extend(reqs, model_runner)
|
||||
torch.cuda.synchronize()
|
||||
prefill_latency = time.time() - tic
|
||||
tot_latency += prefill_latency
|
||||
throughput = input_len * batch_size / prefill_latency
|
||||
rank_print(
|
||||
f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
||||
)
|
||||
measurement_results["prefill_latency"] = prefill_latency
|
||||
measurement_results["prefill_throughput"] = throughput
|
||||
|
||||
# Decode
|
||||
for i in range(output_len):
|
||||
torch.cuda.synchronize()
|
||||
tic = time.time()
|
||||
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
||||
torch.cuda.synchronize()
|
||||
latency = time.time() - tic
|
||||
tot_latency += latency
|
||||
throughput = batch_size / latency
|
||||
if i < 5:
|
||||
rank_print(
|
||||
f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
||||
)
|
||||
avg_decode_latency = (tot_latency - prefill_latency) / output_len
|
||||
avg_decode_throughput = batch_size / avg_decode_latency
|
||||
rank_print(
|
||||
f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
|
||||
)
|
||||
measurement_results["avg_decode_latency"] = avg_decode_latency
|
||||
measurement_results["avg_decode_throughput"] = avg_decode_throughput
|
||||
|
||||
throughput = (input_len + output_len) * batch_size / tot_latency
|
||||
rank_print(
|
||||
f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
|
||||
)
|
||||
measurement_results["total_latency"] = tot_latency
|
||||
measurement_results["total_throughput"] = throughput
|
||||
return measurement_results
|
||||
|
||||
|
||||
def latency_test(
|
||||
server_args,
|
||||
bench_args,
|
||||
@@ -241,72 +303,23 @@ def latency_test(
|
||||
bench_args.batch_size, bench_args.input_len
|
||||
)
|
||||
|
||||
def clear():
|
||||
model_runner.req_to_token_pool.clear()
|
||||
model_runner.token_to_kv_pool.clear()
|
||||
|
||||
@torch.inference_mode()
|
||||
def run_once(output_len):
|
||||
measurement_results = {
|
||||
"batch_size": bench_args.batch_size,
|
||||
"output_len": output_len,
|
||||
}
|
||||
|
||||
# Prefill
|
||||
torch.cuda.synchronize()
|
||||
tot_latency = 0
|
||||
tic = time.time()
|
||||
next_token_ids, _, batch = extend(reqs, model_runner)
|
||||
torch.cuda.synchronize()
|
||||
prefill_latency = time.time() - tic
|
||||
tot_latency += prefill_latency
|
||||
throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
|
||||
rank_print(
|
||||
f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
||||
)
|
||||
measurement_results["prefill_latency"] = prefill_latency
|
||||
measurement_results["prefill_throughput"] = throughput
|
||||
|
||||
# Decode
|
||||
for i in range(output_len):
|
||||
torch.cuda.synchronize()
|
||||
tic = time.time()
|
||||
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
||||
torch.cuda.synchronize()
|
||||
latency = time.time() - tic
|
||||
tot_latency += latency
|
||||
throughput = bench_args.batch_size / latency
|
||||
if i < 5:
|
||||
rank_print(
|
||||
f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
||||
)
|
||||
avg_decode_latency = (tot_latency - prefill_latency) / output_len
|
||||
avg_decode_throughput = bench_args.batch_size / avg_decode_latency
|
||||
rank_print(
|
||||
f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
|
||||
)
|
||||
measurement_results["avg_decode_latency"] = avg_decode_latency
|
||||
measurement_results["avg_decode_throughput"] = avg_decode_throughput
|
||||
|
||||
throughput = (
|
||||
(bench_args.input_len + bench_args.output_len)
|
||||
* bench_args.batch_size
|
||||
/ tot_latency
|
||||
)
|
||||
rank_print(
|
||||
f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
|
||||
)
|
||||
measurement_results["total_latency"] = tot_latency
|
||||
measurement_results["total_throughput"] = throughput
|
||||
return measurement_results
|
||||
|
||||
# Warm up
|
||||
run_once(4)
|
||||
clear()
|
||||
latency_test_run_once(
|
||||
model_runner, rank_print, reqs, bench_args.batch_size, bench_args.input_len, 4
|
||||
)
|
||||
|
||||
# Run again
|
||||
result_list = []
|
||||
result_list.append(run_once(bench_args.output_len))
|
||||
result_list.append(
|
||||
latency_test_run_once(
|
||||
model_runner,
|
||||
rank_print,
|
||||
reqs,
|
||||
bench_args.batch_size,
|
||||
bench_args.input_len,
|
||||
bench_args.output_len,
|
||||
)
|
||||
)
|
||||
|
||||
# Write results in jsonlines format.
|
||||
if bench_args.result_filename:
|
||||
|
||||
Reference in New Issue
Block a user