Provide an argument to set the maximum batch size for cuda graph (#1809)
This commit is contained in:
@@ -113,12 +113,15 @@ class CudaGraphRunner:
|
||||
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
|
||||
|
||||
# Batch sizes to capture
|
||||
if self.model_runner.server_args.disable_cuda_graph_padding:
|
||||
if model_runner.server_args.disable_cuda_graph_padding:
|
||||
self.capture_bs = list(range(1, 32)) + [64, 128]
|
||||
else:
|
||||
self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
|
||||
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
||||
self.capture_bs = [
|
||||
bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
|
||||
bs
|
||||
for bs in self.capture_bs
|
||||
if bs <= model_runner.req_to_token_pool.size
|
||||
and bs <= model_runner.server_args.max_cuda_graph_bs
|
||||
]
|
||||
self.compile_bs = (
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user