Allow skipping warmup in bench_offline_throughput.py (#2103)
@@ -57,6 +57,7 @@ class BenchArgs:
     disable_ignore_eos: bool = False
     extra_request_body: Optional[str] = None
     seed: int = 1
+    skip_warmup: bool = False
     do_not_exit: bool = False

     @staticmethod
@@ -152,6 +153,11 @@ class BenchArgs:
             "additional generate params like sampling params.",
         )
         parser.add_argument("--seed", type=int, default=1, help="The random seed.")
+        parser.add_argument(
+            "--skip-warmup",
+            action="store_true",
+            help="Skip the warmup batches.",
+        )
         parser.add_argument(
             "--do-not-exit",
             action="store_true",
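As a quick illustration of how the new option flows from the CLI into BenchArgs, the sketch below wires a minimal argparse parser to a stripped-down dataclass. Only the --skip-warmup flag and the skip_warmup field come from this change; the BenchArgs stub and the construction step are illustrative assumptions, not the actual bench_offline_throughput.py code.

    # Minimal sketch (assumed wiring): --skip-warmup -> BenchArgs.skip_warmup
    import argparse
    from dataclasses import dataclass

    @dataclass
    class BenchArgs:                # stripped-down stand-in for the real class
        skip_warmup: bool = False   # new field added by this commit

    parser = argparse.ArgumentParser()
    parser.add_argument("--skip-warmup", action="store_true", help="Skip the warmup batches.")
    args = parser.parse_args(["--skip-warmup"])

    bench_args = BenchArgs(skip_warmup=args.skip_warmup)
    assert bench_args.skip_warmup  # warmup batches would now be skipped

On the command line this simply means passing --skip-warmup to bench_offline_throughput.py alongside the existing arguments.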
@@ -261,14 +267,15 @@ def throughput_test(
     )

     # Warm up
-    logging.info("\nWarmup...")
-    throughput_test_once(
-        backend_name=bench_args.backend,
-        backend=backend,
-        reqs=warmup_requests,
-        ignore_eos=not bench_args.disable_ignore_eos,
-        extra_request_body=extra_request_body,
-    )
+    if not bench_args.skip_warmup:
+        logging.info("\nWarmup...")
+        throughput_test_once(
+            backend_name=bench_args.backend,
+            backend=backend,
+            reqs=warmup_requests,
+            ignore_eos=not bench_args.disable_ignore_eos,
+            extra_request_body=extra_request_body,
+        )

     logging.info("\nBenchmark...")
     result = throughput_test_once(
@@ -156,9 +156,6 @@ class TpModelWorkerClient:
         return logits_output, next_token_ids

     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
-        # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.cuda.current_stream().synchronize()
-
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()
@@ -169,6 +166,9 @@ class TpModelWorkerClient:
             linear_penalties=sampling_info.linear_penalties,
         )

+        # A cuda stream sync here to avoid the cuda illegal memory access error.
+        torch.cuda.current_stream().synchronize()
+
         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
