Fix bench latency (#607)
This commit is contained in:
@@ -30,8 +30,10 @@ import argparse
|
||||
import dataclasses
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -70,6 +72,7 @@ class BenchArgs:
|
||||
|
||||
def load_model(server_args, tp_rank):
|
||||
suppress_other_loggers()
|
||||
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
||||
|
||||
model_config = ModelConfig(path=server_args.model_path)
|
||||
model_runner = ModelRunner(
|
||||
@@ -81,7 +84,7 @@ def load_model(server_args, tp_rank):
|
||||
nccl_port=28888,
|
||||
server_args=server_args,
|
||||
)
|
||||
print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
|
||||
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
|
||||
tokenizer = get_tokenizer(
|
||||
server_args.tokenizer_path,
|
||||
tokenizer_mode=server_args.tokenizer_mode,
|
||||
@@ -201,7 +204,7 @@ def correctness_test(
|
||||
|
||||
# Print
|
||||
for i in range(len(reqs)):
|
||||
print(tokenizer.decode(output_ids[i]))
|
||||
rank_print(tokenizer.decode(output_ids[i]))
|
||||
|
||||
|
||||
def latency_test(
|
||||
@@ -213,7 +216,7 @@ def latency_test(
|
||||
|
||||
# Load the model
|
||||
model_runner, tokenizer = load_model(server_args, tp_rank)
|
||||
print(
|
||||
rank_print(
|
||||
f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
|
||||
)
|
||||
|
||||
@@ -299,6 +302,8 @@ def main(server_args, bench_args):
|
||||
for proc in workers:
|
||||
proc.join()
|
||||
|
||||
proc.terminate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
Reference in New Issue
Block a user