Enhancements for bench_one_batch (#8703)
Co-authored-by: root <root@gnr630186.jf.intel.com>
This commit is contained in:
@@ -43,6 +43,7 @@ I'm going to the park
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import copy
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import itertools
|
import itertools
|
||||||
import json
|
import json
|
||||||
@@ -84,12 +85,14 @@ class BenchArgs:
|
|||||||
batch_size: Tuple[int] = (1,)
|
batch_size: Tuple[int] = (1,)
|
||||||
input_len: Tuple[int] = (1024,)
|
input_len: Tuple[int] = (1024,)
|
||||||
output_len: Tuple[int] = (16,)
|
output_len: Tuple[int] = (16,)
|
||||||
|
prompt_filename: str = ""
|
||||||
result_filename: str = "result.jsonl"
|
result_filename: str = "result.jsonl"
|
||||||
correctness_test: bool = False
|
correctness_test: bool = False
|
||||||
# This is only used for correctness test
|
# This is only used for correctness test
|
||||||
cut_len: int = 4
|
cut_len: int = 4
|
||||||
log_decode_step: int = 0
|
log_decode_step: int = 0
|
||||||
profile: bool = False
|
profile: bool = False
|
||||||
|
profile_record_shapes: bool = False
|
||||||
profile_filename_prefix: str = "profile"
|
profile_filename_prefix: str = "profile"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -104,6 +107,9 @@ class BenchArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
|
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt-filename", type=str, default=BenchArgs.prompt_filename
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--result-filename", type=str, default=BenchArgs.result_filename
|
"--result-filename", type=str, default=BenchArgs.result_filename
|
||||||
)
|
)
|
||||||
@@ -118,6 +124,11 @@ class BenchArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--profile", action="store_true", help="Use Torch Profiler."
|
"--profile", action="store_true", help="Use Torch Profiler."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--profile-record-shapes",
|
||||||
|
action="store_true",
|
||||||
|
help="Record tensor shapes in profiling results.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--profile-filename-prefix",
|
"--profile-filename-prefix",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -165,12 +176,16 @@ def load_model(server_args, port_args, tp_rank):
|
|||||||
return model_runner, tokenizer
|
return model_runner, tokenizer
|
||||||
|
|
||||||
|
|
||||||
def prepare_inputs_for_correctness_test(bench_args, tokenizer):
|
def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
|
||||||
prompts = [
|
prompts = (
|
||||||
"The capital of France is",
|
custom_prompts
|
||||||
"The capital of the United Kindom is",
|
if custom_prompts
|
||||||
"Today is a sunny day and I like",
|
else [
|
||||||
]
|
"The capital of France is",
|
||||||
|
"The capital of the United Kindom is",
|
||||||
|
"Today is a sunny day and I like",
|
||||||
|
]
|
||||||
|
)
|
||||||
input_ids = [tokenizer.encode(p) for p in prompts]
|
input_ids = [tokenizer.encode(p) for p in prompts]
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
temperature=0,
|
temperature=0,
|
||||||
@@ -211,8 +226,14 @@ def prepare_extend_inputs_for_correctness_test(
|
|||||||
return reqs
|
return reqs
|
||||||
|
|
||||||
|
|
||||||
def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
|
def prepare_synthetic_inputs_for_latency_test(
|
||||||
input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
|
batch_size, input_len, custom_inputs=None
|
||||||
|
):
|
||||||
|
input_ids = (
|
||||||
|
custom_inputs
|
||||||
|
if custom_inputs
|
||||||
|
else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
|
||||||
|
)
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
temperature=0,
|
temperature=0,
|
||||||
max_new_tokens=BenchArgs.output_len,
|
max_new_tokens=BenchArgs.output_len,
|
||||||
@@ -284,6 +305,30 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_prompts_from_file(prompt_file, rank_print):
|
||||||
|
"""Read custom prompts from the file specified by `--prompt-filename`."""
|
||||||
|
if not prompt_file:
|
||||||
|
return []
|
||||||
|
if not os.path.exists(prompt_file):
|
||||||
|
rank_print(
|
||||||
|
f"Custom prompt file {prompt_file} not found. Using default inputs..."
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
with open(prompt_file, "r") as pf:
|
||||||
|
return pf.readlines()
|
||||||
|
|
||||||
|
|
||||||
|
def _save_profile_trace_results(profiler, filename):
|
||||||
|
parent_dir = os.path.dirname(os.path.abspath(filename))
|
||||||
|
os.makedirs(parent_dir, exist_ok=True)
|
||||||
|
profiler.export_chrome_trace(filename)
|
||||||
|
print(
|
||||||
|
profiler.key_averages(group_by_input_shape=True).table(
|
||||||
|
sort_by="self_cpu_time_total"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def correctness_test(
|
def correctness_test(
|
||||||
server_args,
|
server_args,
|
||||||
port_args,
|
port_args,
|
||||||
@@ -298,7 +343,10 @@ def correctness_test(
|
|||||||
model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
|
model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
|
||||||
|
|
||||||
# Prepare inputs
|
# Prepare inputs
|
||||||
input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
|
custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
|
||||||
|
input_ids, reqs = prepare_inputs_for_correctness_test(
|
||||||
|
bench_args, tokenizer, custom_prompts
|
||||||
|
)
|
||||||
rank_print(f"\n{input_ids=}\n")
|
rank_print(f"\n{input_ids=}\n")
|
||||||
|
|
||||||
if bench_args.cut_len > 0:
|
if bench_args.cut_len > 0:
|
||||||
@@ -344,6 +392,7 @@ def latency_test_run_once(
|
|||||||
device,
|
device,
|
||||||
log_decode_step,
|
log_decode_step,
|
||||||
profile,
|
profile,
|
||||||
|
profile_record_shapes,
|
||||||
profile_filename_prefix,
|
profile_filename_prefix,
|
||||||
):
|
):
|
||||||
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
|
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
|
||||||
@@ -374,6 +423,7 @@ def latency_test_run_once(
|
|||||||
torch.profiler.ProfilerActivity.CUDA,
|
torch.profiler.ProfilerActivity.CUDA,
|
||||||
],
|
],
|
||||||
with_stack=True,
|
with_stack=True,
|
||||||
|
record_shapes=profile_record_shapes,
|
||||||
)
|
)
|
||||||
profiler.start()
|
profiler.start()
|
||||||
|
|
||||||
@@ -391,10 +441,30 @@ def latency_test_run_once(
|
|||||||
measurement_results["prefill_latency"] = prefill_latency
|
measurement_results["prefill_latency"] = prefill_latency
|
||||||
measurement_results["prefill_throughput"] = throughput
|
measurement_results["prefill_throughput"] = throughput
|
||||||
|
|
||||||
|
if profile:
|
||||||
|
profiler.stop()
|
||||||
|
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
|
||||||
|
_save_profile_trace_results(profiler, profile_filename)
|
||||||
|
rank_print(
|
||||||
|
f"torch profiler chrome trace for prefill saved to {profile_filename}"
|
||||||
|
)
|
||||||
|
|
||||||
# Decode
|
# Decode
|
||||||
decode_latencies = []
|
decode_latencies = []
|
||||||
for i in range(output_len - 1):
|
for i in range(output_len - 1):
|
||||||
synchronize(device)
|
synchronize(device)
|
||||||
|
if profile and i == output_len / 2:
|
||||||
|
profiler = None
|
||||||
|
profiler = torch.profiler.profile(
|
||||||
|
activities=[
|
||||||
|
torch.profiler.ProfilerActivity.CPU,
|
||||||
|
torch.profiler.ProfilerActivity.CUDA,
|
||||||
|
],
|
||||||
|
with_stack=True,
|
||||||
|
record_shapes=profile_record_shapes,
|
||||||
|
)
|
||||||
|
profiler.start()
|
||||||
|
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
||||||
synchronize(device)
|
synchronize(device)
|
||||||
@@ -407,13 +477,13 @@ def latency_test_run_once(
|
|||||||
f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
||||||
)
|
)
|
||||||
|
|
||||||
if profile:
|
if profile and i == output_len / 2:
|
||||||
profiler.stop()
|
profiler.stop()
|
||||||
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
|
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
|
||||||
parent_dir = os.path.dirname(os.path.abspath(profile_filename))
|
_save_profile_trace_results(profiler, profile_filename)
|
||||||
os.makedirs(parent_dir, exist_ok=True)
|
rank_print(
|
||||||
profiler.export_chrome_trace(profile_filename)
|
f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
|
||||||
rank_print(f"torch profiler chrome trace saved to {profile_filename}")
|
)
|
||||||
|
|
||||||
# Record decode timing from 2nd output
|
# Record decode timing from 2nd output
|
||||||
if output_len > 1:
|
if output_len > 1:
|
||||||
@@ -469,17 +539,42 @@ def latency_test(
|
|||||||
server_args.device,
|
server_args.device,
|
||||||
log_decode_step=0,
|
log_decode_step=0,
|
||||||
profile=False,
|
profile=False,
|
||||||
|
profile_record_shapes=False,
|
||||||
profile_filename_prefix="", # not used
|
profile_filename_prefix="", # not used
|
||||||
)
|
)
|
||||||
|
|
||||||
rank_print("Benchmark ...")
|
rank_print("Benchmark ...")
|
||||||
|
|
||||||
|
custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
|
||||||
|
custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
|
||||||
|
custom_input_len = len(custom_inputs)
|
||||||
|
|
||||||
# Run the sweep
|
# Run the sweep
|
||||||
result_list = []
|
result_list = []
|
||||||
for bs, il, ol in itertools.product(
|
for bs, il, ol in itertools.product(
|
||||||
bench_args.batch_size, bench_args.input_len, bench_args.output_len
|
bench_args.batch_size, bench_args.input_len, bench_args.output_len
|
||||||
):
|
):
|
||||||
reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
|
bs_aligned_inputs = []
|
||||||
|
if custom_inputs:
|
||||||
|
if custom_input_len == bs:
|
||||||
|
bs_aligned_inputs = custom_inputs
|
||||||
|
elif custom_input_len > bs:
|
||||||
|
rank_print(
|
||||||
|
f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
|
||||||
|
f"Using the first {bs} prompts."
|
||||||
|
)
|
||||||
|
bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
|
||||||
|
else:
|
||||||
|
rank_print(
|
||||||
|
f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
|
||||||
|
f"Pad to the desired batch_size with the last prompt."
|
||||||
|
)
|
||||||
|
bs_aligned_inputs = copy.deepcopy(custom_inputs)
|
||||||
|
bs_aligned_inputs.extend(
|
||||||
|
[bs_aligned_inputs[-1]] * (bs - custom_input_len)
|
||||||
|
)
|
||||||
|
|
||||||
|
reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
|
||||||
ret = latency_test_run_once(
|
ret = latency_test_run_once(
|
||||||
bench_args.run_name,
|
bench_args.run_name,
|
||||||
model_runner,
|
model_runner,
|
||||||
@@ -491,6 +586,7 @@ def latency_test(
|
|||||||
server_args.device,
|
server_args.device,
|
||||||
bench_args.log_decode_step,
|
bench_args.log_decode_step,
|
||||||
bench_args.profile if tp_rank == 0 else None,
|
bench_args.profile if tp_rank == 0 else None,
|
||||||
|
bench_args.profile_record_shapes if tp_rank == 0 else None,
|
||||||
bench_args.profile_filename_prefix,
|
bench_args.profile_filename_prefix,
|
||||||
)
|
)
|
||||||
if ret is not None:
|
if ret is not None:
|
||||||
|
|||||||
Reference in New Issue
Block a user