Fix bench latency (#607)

This commit is contained in:
Lianmin Zheng
2024-07-11 14:37:01 -07:00
committed by GitHub
parent ad872feb14
commit d9a6902986

View File

@@ -30,8 +30,10 @@ import argparse
import dataclasses
import logging
import multiprocessing
import os
import time
import numpy as np
import torch
import torch.distributed as dist
@@ -70,6 +72,7 @@ class BenchArgs:
def load_model(server_args, tp_rank):
suppress_other_loggers()
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
model_config = ModelConfig(path=server_args.model_path)
model_runner = ModelRunner(
@@ -81,7 +84,7 @@ def load_model(server_args, tp_rank):
nccl_port=28888,
server_args=server_args,
)
print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
tokenizer = get_tokenizer(
server_args.tokenizer_path,
tokenizer_mode=server_args.tokenizer_mode,
@@ -201,7 +204,7 @@ def correctness_test(
# Print
for i in range(len(reqs)):
print(tokenizer.decode(output_ids[i]))
rank_print(tokenizer.decode(output_ids[i]))
def latency_test(
@@ -213,7 +216,7 @@ def latency_test(
# Load the model
model_runner, tokenizer = load_model(server_args, tp_rank)
print(
rank_print(
f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
)
@@ -299,6 +302,8 @@ def main(server_args, bench_args):
for proc in workers:
proc.join()
proc.terminate()
if __name__ == "__main__":
parser = argparse.ArgumentParser()