Format (#593)
This commit is contained in:
@@ -108,7 +108,7 @@ def prepare_inputs(bench_args, tokenizer):
|
||||
for i in range(len(prompts)):
|
||||
assert len(input_ids[i]) > bench_args.cut_len
|
||||
|
||||
tmp_input_ids = input_ids[i][:bench_args.cut_len]
|
||||
tmp_input_ids = input_ids[i][: bench_args.cut_len]
|
||||
req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
|
||||
req.prefix_indices = []
|
||||
req.sampling_params = sampling_params
|
||||
@@ -121,9 +121,9 @@ def prepare_inputs(bench_args, tokenizer):
|
||||
def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
|
||||
for i in range(len(reqs)):
|
||||
req = reqs[i]
|
||||
req.input_ids += input_ids[i][bench_args.cut_len:]
|
||||
req.input_ids += input_ids[i][bench_args.cut_len :]
|
||||
req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
|
||||
i, :bench_args.cut_len
|
||||
i, : bench_args.cut_len
|
||||
]
|
||||
return reqs
|
||||
|
||||
@@ -151,7 +151,8 @@ def extend(reqs, model_runner):
|
||||
reqs=reqs,
|
||||
req_to_token_pool=model_runner.req_to_token_pool,
|
||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
||||
tree_cache=None)
|
||||
tree_cache=None,
|
||||
)
|
||||
batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
|
||||
output = model_runner.forward(batch, ForwardMode.EXTEND)
|
||||
next_token_ids, _ = batch.sample(output.next_token_logits)
|
||||
@@ -212,7 +213,9 @@ def latency_test(
|
||||
|
||||
# Load the model
|
||||
model_runner, tokenizer = load_model(server_args, tp_rank)
|
||||
print(f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}")
|
||||
print(
|
||||
f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
|
||||
)
|
||||
|
||||
# Prepare inputs
|
||||
reqs = prepare_synthetic_inputs(bench_args, tokenizer)
|
||||
@@ -232,7 +235,9 @@ def latency_test(
|
||||
prefill_latency = time.time() - tic
|
||||
tot_latency += prefill_latency
|
||||
throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
|
||||
rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
|
||||
rank_print(
|
||||
f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
||||
)
|
||||
|
||||
# Decode
|
||||
for i in range(output_len):
|
||||
@@ -243,13 +248,24 @@ def latency_test(
|
||||
latency = time.time() - tic
|
||||
tot_latency += latency
|
||||
throughput = bench_args.batch_size / latency
|
||||
if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
|
||||
if i < 5:
|
||||
rank_print(
|
||||
f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
||||
)
|
||||
avg_decode_latency = (tot_latency - prefill_latency) / output_len
|
||||
avg_decode_throughput = bench_args.batch_size / avg_decode_latency
|
||||
rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
|
||||
|
||||
throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
|
||||
rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
|
||||
rank_print(
|
||||
f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
|
||||
)
|
||||
|
||||
throughput = (
|
||||
(bench_args.input_len + bench_args.output_len)
|
||||
* bench_args.batch_size
|
||||
/ tot_latency
|
||||
)
|
||||
rank_print(
|
||||
f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
|
||||
)
|
||||
|
||||
# Warm up
|
||||
run_once(4)
|
||||
@@ -298,4 +314,4 @@ if __name__ == "__main__":
|
||||
format="%(message)s",
|
||||
)
|
||||
|
||||
main(server_args, bench_args)
|
||||
main(server_args, bench_args)
|
||||
|
||||
Reference in New Issue
Block a user