Fix bench latency (#607)

This commit is contained in:
Lianmin Zheng
2024-07-11 14:37:01 -07:00
committed by GitHub
parent ad872feb14
commit d9a6902986

View File

@@ -30,8 +30,10 @@ import argparse
 import dataclasses
 import logging
 import multiprocessing
+import os
 import time
 import numpy as np
 import torch
 import torch.distributed as dist
@@ -70,6 +72,7 @@ class BenchArgs:
 def load_model(server_args, tp_rank):
 suppress_other_loggers()
+rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
 model_config = ModelConfig(path=server_args.model_path)
 model_runner = ModelRunner(
@@ -81,7 +84,7 @@ def load_model(server_args, tp_rank):
 nccl_port=28888,
 server_args=server_args,
 )
-print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
+rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
 tokenizer = get_tokenizer(
 server_args.tokenizer_path,
 tokenizer_mode=server_args.tokenizer_mode,
@@ -201,7 +204,7 @@ def correctness_test(
 # Print
 for i in range(len(reqs)):
-print(tokenizer.decode(output_ids[i]))
+rank_print(tokenizer.decode(output_ids[i]))
 def latency_test(
@@ -213,7 +216,7 @@ def latency_test(
 # Load the model
 model_runner, tokenizer = load_model(server_args, tp_rank)
-print(
+rank_print(
 f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
 )
@@ -299,6 +302,8 @@ def main(server_args, bench_args):
 for proc in workers:
 proc.join()
+proc.terminate()
 if __name__ == "__main__":
 parser = argparse.ArgumentParser()