commit dc1b8bcfaa (parent 5a57b8addd)
Author: Ying Sheng
Date: 2024-07-05 10:06:17 -07:00
Committed by: GitHub

21 changed files with 487 additions and 354 deletions


@@ -108,7 +108,7 @@ def prepare_inputs(bench_args, tokenizer):
     for i in range(len(prompts)):
         assert len(input_ids[i]) > bench_args.cut_len
-        tmp_input_ids = input_ids[i][:bench_args.cut_len]
+        tmp_input_ids = input_ids[i][: bench_args.cut_len]
         req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
         req.prefix_indices = []
         req.sampling_params = sampling_params
@@ -121,9 +121,9 @@ def prepare_inputs(bench_args, tokenizer):
 def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
     for i in range(len(reqs)):
         req = reqs[i]
-        req.input_ids += input_ids[i][bench_args.cut_len:]
+        req.input_ids += input_ids[i][bench_args.cut_len :]
         req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
-            i, :bench_args.cut_len
+            i, : bench_args.cut_len
         ]
     return reqs
@@ -151,7 +151,8 @@ def extend(reqs, model_runner):
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
-        tree_cache=None)
+        tree_cache=None,
+    )
     batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
     output = model_runner.forward(batch, ForwardMode.EXTEND)
     next_token_ids, _ = batch.sample(output.next_token_logits)
@@ -212,7 +213,9 @@ def latency_test(
     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-    print(f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}")
+    print(
+        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
+    )
 
     # Prepare inputs
     reqs = prepare_synthetic_inputs(bench_args, tokenizer)
@@ -232,7 +235,9 @@ def latency_test(
         prefill_latency = time.time() - tic
         tot_latency += prefill_latency
         throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
-        rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+        rank_print(
+            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+        )
 
         # Decode
         for i in range(output_len):
@@ -243,13 +248,24 @@ def latency_test(
             latency = time.time() - tic
             tot_latency += latency
             throughput = bench_args.batch_size / latency
-            if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+            if i < 5:
+                rank_print(
+                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+                )
 
         avg_decode_latency = (tot_latency - prefill_latency) / output_len
         avg_decode_throughput = bench_args.batch_size / avg_decode_latency
-        rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
-        throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
-        rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
+        rank_print(
+            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+        )
+        throughput = (
+            (bench_args.input_len + bench_args.output_len)
+            * bench_args.batch_size
+            / tot_latency
+        )
+        rank_print(
+            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+        )
 
     # Warm up
     run_once(4)
@@ -298,4 +314,4 @@ if __name__ == "__main__":
         format="%(message)s",
     )
 
-    main(server_args, bench_args)
\ No newline at end of file
+    main(server_args, bench_args)
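
For reference, the throughput figures printed by this benchmark fall straight out of the batch size and sequence lengths. Below is a minimal standalone sketch of the same arithmetic; the helper name `report_throughput` and its plain arguments are hypothetical stand-ins for the `bench_args` fields and are not part of this commit:

def report_throughput(batch_size, input_len, output_len, prefill_latency, decode_latencies):
    # Prefill processes input_len tokens for every request in the batch in one pass.
    prefill_tp = batch_size * input_len / prefill_latency
    print(f"Prefill throughput: {prefill_tp:9.2f} token/s")

    # Each decode step emits one token per request; average over all measured steps.
    avg_decode_latency = sum(decode_latencies) / len(decode_latencies)
    print(f"Decode avg throughput: {batch_size / avg_decode_latency:9.2f} token/s")

    # Total throughput counts both input and output tokens over the whole run.
    tot_latency = prefill_latency + sum(decode_latencies)
    total_tp = (input_len + output_len) * batch_size / tot_latency
    print(f"Total throughput: {total_tp:9.2f} token/s")

# Example with made-up numbers: batch of 16, 512 input tokens, 128 decode steps.
report_throughput(16, 512, 128, 0.8, [0.02] * 128)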