Enhancements for bench_one_batch (#8703)
Co-authored-by: root <root@gnr630186.jf.intel.com>
This commit is contained in:
@@ -43,6 +43,7 @@ I'm going to the park
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import copy
|
||||
import dataclasses
|
||||
import itertools
|
||||
import json
|
||||
@@ -84,12 +85,14 @@ class BenchArgs:
|
||||
batch_size: Tuple[int] = (1,)
|
||||
input_len: Tuple[int] = (1024,)
|
||||
output_len: Tuple[int] = (16,)
|
||||
prompt_filename: str = ""
|
||||
result_filename: str = "result.jsonl"
|
||||
correctness_test: bool = False
|
||||
# This is only used for correctness test
|
||||
cut_len: int = 4
|
||||
log_decode_step: int = 0
|
||||
profile: bool = False
|
||||
profile_record_shapes: bool = False
|
||||
profile_filename_prefix: str = "profile"
|
||||
|
||||
@staticmethod
|
||||
@@ -104,6 +107,9 @@ class BenchArgs:
|
||||
parser.add_argument(
|
||||
"--output-len", type=int, nargs="+", default=BenchArgs.output_len
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-filename", type=str, default=BenchArgs.prompt_filename
|
||||
)
|
||||
parser.add_argument(
|
||||
"--result-filename", type=str, default=BenchArgs.result_filename
|
||||
)
|
||||
@@ -118,6 +124,11 @@ class BenchArgs:
|
||||
parser.add_argument(
|
||||
"--profile", action="store_true", help="Use Torch Profiler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile-record-shapes",
|
||||
action="store_true",
|
||||
help="Record tensor shapes in profiling results.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile-filename-prefix",
|
||||
type=str,
|
||||
@@ -165,12 +176,16 @@ def load_model(server_args, port_args, tp_rank):
|
||||
return model_runner, tokenizer
|
||||
|
||||
|
||||
def prepare_inputs_for_correctness_test(bench_args, tokenizer):
|
||||
prompts = [
|
||||
"The capital of France is",
|
||||
"The capital of the United Kindom is",
|
||||
"Today is a sunny day and I like",
|
||||
]
|
||||
def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
|
||||
prompts = (
|
||||
custom_prompts
|
||||
if custom_prompts
|
||||
else [
|
||||
"The capital of France is",
|
||||
"The capital of the United Kindom is",
|
||||
"Today is a sunny day and I like",
|
||||
]
|
||||
)
|
||||
input_ids = [tokenizer.encode(p) for p in prompts]
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
@@ -211,8 +226,14 @@ def prepare_extend_inputs_for_correctness_test(
|
||||
return reqs
|
||||
|
||||
|
||||
def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
|
||||
input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
|
||||
def prepare_synthetic_inputs_for_latency_test(
|
||||
batch_size, input_len, custom_inputs=None
|
||||
):
|
||||
input_ids = (
|
||||
custom_inputs
|
||||
if custom_inputs
|
||||
else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
|
||||
)
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0,
|
||||
max_new_tokens=BenchArgs.output_len,
|
||||
@@ -284,6 +305,30 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
|
||||
)
|
||||
|
||||
|
||||
def _read_prompts_from_file(prompt_file, rank_print):
|
||||
"""Read custom prompts from the file specified by `--prompt-filename`."""
|
||||
if not prompt_file:
|
||||
return []
|
||||
if not os.path.exists(prompt_file):
|
||||
rank_print(
|
||||
f"Custom prompt file {prompt_file} not found. Using default inputs..."
|
||||
)
|
||||
return []
|
||||
with open(prompt_file, "r") as pf:
|
||||
return pf.readlines()
|
||||
|
||||
|
||||
def _save_profile_trace_results(profiler, filename):
|
||||
parent_dir = os.path.dirname(os.path.abspath(filename))
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
profiler.export_chrome_trace(filename)
|
||||
print(
|
||||
profiler.key_averages(group_by_input_shape=True).table(
|
||||
sort_by="self_cpu_time_total"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def correctness_test(
|
||||
server_args,
|
||||
port_args,
|
||||
@@ -298,7 +343,10 @@ def correctness_test(
|
||||
model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
|
||||
|
||||
# Prepare inputs
|
||||
input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
|
||||
custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
|
||||
input_ids, reqs = prepare_inputs_for_correctness_test(
|
||||
bench_args, tokenizer, custom_prompts
|
||||
)
|
||||
rank_print(f"\n{input_ids=}\n")
|
||||
|
||||
if bench_args.cut_len > 0:
|
||||
@@ -344,6 +392,7 @@ def latency_test_run_once(
|
||||
device,
|
||||
log_decode_step,
|
||||
profile,
|
||||
profile_record_shapes,
|
||||
profile_filename_prefix,
|
||||
):
|
||||
max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
|
||||
@@ -374,6 +423,7 @@ def latency_test_run_once(
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
record_shapes=profile_record_shapes,
|
||||
)
|
||||
profiler.start()
|
||||
|
||||
@@ -391,10 +441,30 @@ def latency_test_run_once(
|
||||
measurement_results["prefill_latency"] = prefill_latency
|
||||
measurement_results["prefill_throughput"] = throughput
|
||||
|
||||
if profile:
|
||||
profiler.stop()
|
||||
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
|
||||
_save_profile_trace_results(profiler, profile_filename)
|
||||
rank_print(
|
||||
f"torch profiler chrome trace for prefill saved to {profile_filename}"
|
||||
)
|
||||
|
||||
# Decode
|
||||
decode_latencies = []
|
||||
for i in range(output_len - 1):
|
||||
synchronize(device)
|
||||
if profile and i == output_len / 2:
|
||||
profiler = None
|
||||
profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
record_shapes=profile_record_shapes,
|
||||
)
|
||||
profiler.start()
|
||||
|
||||
tic = time.perf_counter()
|
||||
next_token_ids, _ = decode(next_token_ids, batch, model_runner)
|
||||
synchronize(device)
|
||||
@@ -407,13 +477,13 @@ def latency_test_run_once(
|
||||
f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
|
||||
)
|
||||
|
||||
if profile:
|
||||
profiler.stop()
|
||||
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
|
||||
parent_dir = os.path.dirname(os.path.abspath(profile_filename))
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
profiler.export_chrome_trace(profile_filename)
|
||||
rank_print(f"torch profiler chrome trace saved to {profile_filename}")
|
||||
if profile and i == output_len / 2:
|
||||
profiler.stop()
|
||||
profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
|
||||
_save_profile_trace_results(profiler, profile_filename)
|
||||
rank_print(
|
||||
f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
|
||||
)
|
||||
|
||||
# Record decode timing from 2nd output
|
||||
if output_len > 1:
|
||||
@@ -469,17 +539,42 @@ def latency_test(
|
||||
server_args.device,
|
||||
log_decode_step=0,
|
||||
profile=False,
|
||||
profile_record_shapes=False,
|
||||
profile_filename_prefix="", # not used
|
||||
)
|
||||
|
||||
rank_print("Benchmark ...")
|
||||
|
||||
custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
|
||||
custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
|
||||
custom_input_len = len(custom_inputs)
|
||||
|
||||
# Run the sweep
|
||||
result_list = []
|
||||
for bs, il, ol in itertools.product(
|
||||
bench_args.batch_size, bench_args.input_len, bench_args.output_len
|
||||
):
|
||||
reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
|
||||
bs_aligned_inputs = []
|
||||
if custom_inputs:
|
||||
if custom_input_len == bs:
|
||||
bs_aligned_inputs = custom_inputs
|
||||
elif custom_input_len > bs:
|
||||
rank_print(
|
||||
f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
|
||||
f"Using the first {bs} prompts."
|
||||
)
|
||||
bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
|
||||
else:
|
||||
rank_print(
|
||||
f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
|
||||
f"Pad to the desired batch_size with the last prompt."
|
||||
)
|
||||
bs_aligned_inputs = copy.deepcopy(custom_inputs)
|
||||
bs_aligned_inputs.extend(
|
||||
[bs_aligned_inputs[-1]] * (bs - custom_input_len)
|
||||
)
|
||||
|
||||
reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
|
||||
ret = latency_test_run_once(
|
||||
bench_args.run_name,
|
||||
model_runner,
|
||||
@@ -491,6 +586,7 @@ def latency_test(
|
||||
server_args.device,
|
||||
bench_args.log_decode_step,
|
||||
bench_args.profile if tp_rank == 0 else None,
|
||||
bench_args.profile_record_shapes if tp_rank == 0 else None,
|
||||
bench_args.profile_filename_prefix,
|
||||
)
|
||||
if ret is not None:
|
||||
|
||||
Reference in New Issue
Block a user