Provide an argument to set the maximum batch size for CUDA graph (#1809)

This commit is contained in:
Lianmin Zheng
2024-10-26 15:09:33 -07:00
committed by GitHub
parent 9d6fb08457
commit 2b80978859
4 changed files with 25 additions and 10 deletions

View File

@@ -113,12 +113,15 @@ class CudaGraphRunner:
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
# Batch sizes to capture
if self.model_runner.server_args.disable_cuda_graph_padding:
if model_runner.server_args.disable_cuda_graph_padding:
self.capture_bs = list(range(1, 32)) + [64, 128]
else:
self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
self.capture_bs = [
bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
bs
for bs in self.capture_bs
if bs <= model_runner.req_to_token_pool.size
and bs <= model_runner.server_args.max_cuda_graph_bs
]
self.compile_bs = (
[