Fix the port_args in bench_latency (#1597)
This commit is contained in:
@@ -123,11 +123,10 @@ class BenchArgs:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_model(server_args, tp_rank):
|
def load_model(server_args, port_args, tp_rank):
|
||||||
suppress_other_loggers()
|
suppress_other_loggers()
|
||||||
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
||||||
|
|
||||||
port_args = PortArgs.init_new(server_args)
|
|
||||||
model_config = ModelConfig(
|
model_config = ModelConfig(
|
||||||
server_args.model_path,
|
server_args.model_path,
|
||||||
server_args.trust_remote_code,
|
server_args.trust_remote_code,
|
||||||
@@ -248,13 +247,14 @@ def decode(input_token_ids, batch, model_runner):
|
|||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def correctness_test(
|
def correctness_test(
|
||||||
server_args,
|
server_args,
|
||||||
|
port_args,
|
||||||
bench_args,
|
bench_args,
|
||||||
tp_rank,
|
tp_rank,
|
||||||
):
|
):
|
||||||
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
model_runner, tokenizer = load_model(server_args, tp_rank)
|
model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
|
||||||
|
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
|
input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
|
||||||
@@ -362,6 +362,7 @@ def latency_test_run_once(
|
|||||||
|
|
||||||
def latency_test(
|
def latency_test(
|
||||||
server_args,
|
server_args,
|
||||||
|
port_args,
|
||||||
bench_args,
|
bench_args,
|
||||||
tp_rank,
|
tp_rank,
|
||||||
):
|
):
|
||||||
@@ -369,7 +370,7 @@ def latency_test(
|
|||||||
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
|
||||||
|
|
||||||
# Load the model
|
# Load the model
|
||||||
model_runner, tokenizer = load_model(server_args, tp_rank)
|
model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
|
||||||
|
|
||||||
# Prepare inputs for warm up
|
# Prepare inputs for warm up
|
||||||
reqs = prepare_synthetic_inputs_for_latency_test(
|
reqs = prepare_synthetic_inputs_for_latency_test(
|
||||||
@@ -487,8 +488,10 @@ def main(server_args, bench_args):
|
|||||||
"provide --result-filename for plotting the results"
|
"provide --result-filename for plotting the results"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
port_args = PortArgs.init_new(server_args)
|
||||||
|
|
||||||
if server_args.tp_size == 1:
|
if server_args.tp_size == 1:
|
||||||
work_func(server_args, bench_args, 0)
|
work_func(server_args, port_args, bench_args, 0)
|
||||||
else:
|
else:
|
||||||
workers = []
|
workers = []
|
||||||
for tp_rank in range(server_args.tp_size):
|
for tp_rank in range(server_args.tp_size):
|
||||||
@@ -496,6 +499,7 @@ def main(server_args, bench_args):
|
|||||||
target=work_func,
|
target=work_func,
|
||||||
args=(
|
args=(
|
||||||
server_args,
|
server_args,
|
||||||
|
port_args,
|
||||||
bench_args,
|
bench_args,
|
||||||
tp_rank,
|
tp_rank,
|
||||||
),
|
),
|
||||||
|
|||||||
Reference in New Issue
Block a user