Fix bench latency (#607)
This commit is contained in:
@@ -30,8 +30,10 @@ import argparse
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -70,6 +72,7 @@ class BenchArgs:
|
|||||||
|
|
||||||
def load_model(server_args, tp_rank):
|
def load_model(server_args, tp_rank):
|
||||||
suppress_other_loggers()
|
suppress_other_loggers()
|
||||||
|
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
||||||
|
|
||||||
model_config = ModelConfig(path=server_args.model_path)
|
model_config = ModelConfig(path=server_args.model_path)
|
||||||
model_runner = ModelRunner(
|
model_runner = ModelRunner(
|
||||||
@@ -81,7 +84,7 @@ def load_model(server_args, tp_rank):
|
|||||||
nccl_port=28888,
|
nccl_port=28888,
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
)
|
)
|
||||||
print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
|
rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
|
||||||
tokenizer = get_tokenizer(
|
tokenizer = get_tokenizer(
|
||||||
server_args.tokenizer_path,
|
server_args.tokenizer_path,
|
||||||
tokenizer_mode=server_args.tokenizer_mode,
|
tokenizer_mode=server_args.tokenizer_mode,
|
||||||
@@ -201,7 +204,7 @@ def correctness_test(
|
|||||||
|
|
||||||
# Print
|
# Print
|
||||||
for i in range(len(reqs)):
|
for i in range(len(reqs)):
|
||||||
print(tokenizer.decode(output_ids[i]))
|
rank_print(tokenizer.decode(output_ids[i]))
|
||||||
|
|
||||||
|
|
||||||
def latency_test(
|
def latency_test(
|
||||||
@@ -213,7 +216,7 @@ def latency_test(
|
|||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
model_runner, tokenizer = load_model(server_args, tp_rank)
|
model_runner, tokenizer = load_model(server_args, tp_rank)
|
||||||
print(
|
rank_print(
|
||||||
f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
|
f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -299,6 +302,8 @@ def main(server_args, bench_args):
|
|||||||
for proc in workers:
|
for proc in workers:
|
||||||
proc.join()
|
proc.join()
|
||||||
|
|
||||||
|
proc.terminate()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|||||||
Reference in New Issue
Block a user