2x performance improvement for large prefill & Fix workspace conflicts (#579)

This commit is contained in:
Ying Sheng
2024-07-03 16:14:57 -07:00
committed by GitHub
parent 96c503eb60
commit 2a754e57b0
6 changed files with 88 additions and 25 deletions

View File

@@ -230,7 +230,7 @@ def latency_test(
prefill_latency = time.time() - tic
tot_latency += prefill_latency
throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
rank_print(f"Prefill. latency: {prefill_latency:6.5f} ms, throughput: {throughput:9.2f} token/s")
rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
# Decode
for i in range(output_len):
@@ -241,13 +241,13 @@ def latency_test(
latency = time.time() - tic
tot_latency += latency
throughput = bench_args.batch_size / latency
if i < 5: rank_print(f"Decode. latency: {latency:6.5f} ms, throughput: {throughput:9.2f} token/s")
if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
avg_decode_latency = (tot_latency - prefill_latency) / output_len
avg_decode_throughput = bench_args.batch_size / avg_decode_latency
rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} ms, avg throughput: {avg_decode_throughput:9.2f} token/s")
rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
rank_print(f"Total. latency: {tot_latency:6.3f} ms, throughput: {throughput:9.2f} token/s")
rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
# Warm up
run_once(4)