From d81ac4434e18b9a199aa2dea923a89c31a34d094 Mon Sep 17 00:00:00 2001 From: HAI Date: Tue, 11 Feb 2025 11:04:38 -0800 Subject: [PATCH] MI30x: More graph captures for larger batch sizes and concurrencies (#3420) --- python/sglang/srt/model_executor/cuda_graph_runner.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 249bf82bd..db103162f 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -33,6 +33,9 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, ) +from sglang.srt.utils import is_hip + +is_hip_ = is_hip() if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner @@ -129,6 +132,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): if bs <= model_runner.req_to_token_pool.size and bs <= server_args.cuda_graph_max_bs ] + if is_hip_: + capture_bs += [i * 8 for i in range(21, 33)] compile_bs = ( [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs] if server_args.enable_torch_compile