[fix] reduce dp capture bs (#5634)
Co-authored-by: alcanerian <alcanerian@gmail.com>
This commit is contained in:
@@ -134,7 +134,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||
)
|
||||
|
||||
gpu_mem = get_device_memory_capacity()
|
||||
if gpu_mem is not None and gpu_mem > 81920:
|
||||
# Batch size of each rank will not become so large when DP is on
|
||||
if gpu_mem is not None and gpu_mem > 81920 and server_args.dp_size == 1:
|
||||
capture_bs += list(range(160, 257, 8))
|
||||
|
||||
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
||||
|
||||
Reference in New Issue
Block a user