Restructure gpu_memory_settings in a unified function and relax max_cuda_graph_bs (#10372)
Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
Co-authored-by: sglang-bot <sglangbot@gmail.com>
This commit is contained in:
@@ -167,29 +167,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||
server_args = model_runner.server_args
|
||||
capture_bs = server_args.cuda_graph_bs
|
||||
|
||||
if capture_bs is None:
|
||||
if server_args.speculative_algorithm is None:
|
||||
if server_args.disable_cuda_graph_padding:
|
||||
capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
|
||||
else:
|
||||
capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
|
||||
else:
|
||||
# Since speculative decoding requires more cuda graph memory, we
|
||||
# capture less.
|
||||
capture_bs = (
|
||||
list(range(1, 9))
|
||||
+ list(range(10, 33, 2))
|
||||
+ list(range(40, 65, 8))
|
||||
+ list(range(80, 161, 16))
|
||||
)
|
||||
|
||||
gpu_mem = get_device_memory_capacity()
|
||||
if gpu_mem is not None:
|
||||
if gpu_mem > 90 * 1024: # H200, H20
|
||||
capture_bs += list(range(160, 257, 8))
|
||||
if gpu_mem > 160 * 1000: # B200, MI300
|
||||
capture_bs += list(range(256, 513, 16))
|
||||
|
||||
if max(capture_bs) > model_runner.req_to_token_pool.size:
|
||||
# In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
|
||||
# is very small. We add more values here to make sure we capture the maximum bs.
|
||||
@@ -205,12 +182,6 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||
|
||||
capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
|
||||
|
||||
if server_args.cuda_graph_max_bs:
|
||||
capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
|
||||
if max(capture_bs) < server_args.cuda_graph_max_bs:
|
||||
capture_bs += list(
|
||||
range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
|
||||
)
|
||||
capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
|
||||
capture_bs = list(sorted(set(capture_bs)))
|
||||
assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
|
||||
|
||||
@@ -450,12 +450,8 @@ class ServerArgs:
|
||||
# Get GPU memory capacity, which is a common dependency for several configuration steps.
|
||||
gpu_mem = get_device_memory_capacity(self.device)
|
||||
|
||||
# Handle memory-related configurations.
|
||||
self._handle_mem_fraction_static(gpu_mem)
|
||||
self._handle_chunked_prefill_size(gpu_mem)
|
||||
|
||||
# Handle CUDA graph settings.
|
||||
self._handle_cuda_graph_max_bs(gpu_mem)
|
||||
# Handle memory-related, chunked prefill, and CUDA graph batch size configurations.
|
||||
self._handle_gpu_memory_settings(gpu_mem)
|
||||
|
||||
# Handle device-specific backends.
|
||||
self._handle_hpu_backends()
|
||||
@@ -526,7 +522,12 @@ class ServerArgs:
|
||||
if self.random_seed is None:
|
||||
self.random_seed = random.randint(0, 1 << 30)
|
||||
|
||||
def _handle_mem_fraction_static(self, gpu_mem):
|
||||
def _handle_gpu_memory_settings(self, gpu_mem):
|
||||
"""
|
||||
Configure GPU memory-dependent settings including mem_fraction_static,
|
||||
chunked_prefill_size, cuda_graph_max_bs, and cuda_graph_bs.
|
||||
"""
|
||||
# Set mem fraction static
|
||||
if self.mem_fraction_static is None:
|
||||
if gpu_mem is not None:
|
||||
# GPU memory capacity = model weights + KV cache pool + activations + cuda graph buffers
|
||||
@@ -544,18 +545,18 @@ class ServerArgs:
|
||||
if gpu_mem < 20 * 1024:
|
||||
# T4, 4080. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
|
||||
reserved_mem = (2.8 + parallel_size / 10) * 1024
|
||||
elif gpu_mem < 35 * 1024:
|
||||
# A10, L40, 4090, 5090. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
|
||||
elif gpu_mem < 50 * 1024:
|
||||
# A10, L40, 4090, 5090. (chunked_prefill_size 2k, cuda_graph_max_bs 16 if tp < 4 else 80)
|
||||
reserved_mem = (2.8 + parallel_size / 10) * 1024
|
||||
elif gpu_mem < 90 * 1024:
|
||||
# H100, A100. (chunked_prefill_size 8k, cuda_graph_max_bs 160)
|
||||
reserved_mem = (9.5 + parallel_size / 2) * 1024
|
||||
# H100, A100. (chunked_prefill_size 8k, cuda_graph_max_bs 256 if tp < 4 else 512)
|
||||
reserved_mem = (12 + parallel_size / 2) * 1024
|
||||
elif gpu_mem < 100 * 1024:
|
||||
# H20. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
|
||||
reserved_mem = (12 + parallel_size / 2) * 1024
|
||||
# H20. (chunked_prefill_size 8k, cuda_graph_max_bs 512)
|
||||
reserved_mem = (15 + parallel_size / 2) * 1024
|
||||
elif gpu_mem < 160 * 1024:
|
||||
# H200. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
|
||||
reserved_mem = (12 + parallel_size / 2) * 1024
|
||||
# H200. (chunked_prefill_size 8k, cuda_graph_max_bs 512)
|
||||
reserved_mem = (15 + parallel_size / 2) * 1024
|
||||
else:
|
||||
# B200, MI300. (chunked_prefill_size 16k, cuda_graph_max_bs 512)
|
||||
reserved_mem = 32 * 1024
|
||||
@@ -575,36 +576,86 @@ class ServerArgs:
|
||||
else:
|
||||
self.mem_fraction_static = 0.88
|
||||
|
||||
# Lazy init to avoid circular import.
|
||||
# Lazy init to avoid circular import
|
||||
# Multimodal models need more memory for the image processor
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
|
||||
model_config = ModelConfig.from_server_args(self)
|
||||
if model_config.is_multimodal:
|
||||
self.adjust_mem_fraction_for_vlm(model_config)
|
||||
|
||||
def _handle_chunked_prefill_size(self, gpu_mem):
|
||||
# Set chunked prefill size, which depends on the gpu memory capacity
|
||||
if self.chunked_prefill_size is None:
|
||||
if gpu_mem is not None:
|
||||
# A10, L40, 4090
|
||||
if gpu_mem < 35 * 1024:
|
||||
if gpu_mem < 50 * 1024: # T4, 4080, A10, L40, 4090, 5090
|
||||
self.chunked_prefill_size = 2048
|
||||
# H100, H200, A100, H20
|
||||
elif gpu_mem < 160 * 1024:
|
||||
elif gpu_mem < 160 * 1024: # H100, H200, A100, H20
|
||||
self.chunked_prefill_size = 8192
|
||||
# B200, MI300
|
||||
else:
|
||||
else: # B200, MI300
|
||||
self.chunked_prefill_size = 16384
|
||||
else:
|
||||
self.chunked_prefill_size = 4096
|
||||
|
||||
def _handle_cuda_graph_max_bs(self, gpu_mem):
|
||||
# Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
|
||||
# Set cuda graph max batch size and cuda graph batch sizes
|
||||
if self.cuda_graph_max_bs is None:
|
||||
if gpu_mem is not None and gpu_mem < 35 * 1024:
|
||||
if self.tp_size < 4:
|
||||
if gpu_mem is not None:
|
||||
if gpu_mem < 20 * 1024:
|
||||
# T4, 4080
|
||||
self.cuda_graph_max_bs = 8
|
||||
elif gpu_mem < 50 * 1024:
|
||||
# A10, L40, 4090, 5090
|
||||
# Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance.
|
||||
# However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs
|
||||
# from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
|
||||
if self.tp_size < 4:
|
||||
self.cuda_graph_max_bs = 16
|
||||
else:
|
||||
self.cuda_graph_max_bs = 80
|
||||
elif gpu_mem < 90 * 1024:
|
||||
# H100, A100
|
||||
if self.tp_size < 4:
|
||||
self.cuda_graph_max_bs = 256
|
||||
else:
|
||||
self.cuda_graph_max_bs = 512
|
||||
else:
|
||||
self.cuda_graph_max_bs = 80
|
||||
# H20, H200, B200, MI300
|
||||
self.cuda_graph_max_bs = 512
|
||||
else:
|
||||
# Default fallback
|
||||
self.cuda_graph_max_bs = 160
|
||||
|
||||
if self.cuda_graph_bs is None:
|
||||
self.cuda_graph_bs = self._generate_cuda_graph_batch_sizes()
|
||||
|
||||
def _generate_cuda_graph_batch_sizes(self):
|
||||
"""
|
||||
Generate the list of batch sizes for CUDA graph capture based on cuda_graph_max_bs.
|
||||
This integrates the logic from cuda_graph_runner.py.
|
||||
"""
|
||||
# Handle disable_cuda_graph_padding as the first condition for both spec and non-spec
|
||||
if self.disable_cuda_graph_padding:
|
||||
capture_bs = list(range(1, self.cuda_graph_max_bs + 1))
|
||||
elif self.speculative_algorithm is None:
|
||||
# Normal case: [1, 2, 4, 8, 12] + list(range(16, 257, 8)) + list(range(272, 512, 16)) + list(range(512, cuda_graph_max_bs + 1))
|
||||
capture_bs = (
|
||||
[1, 2, 4, 8, 12]
|
||||
+ list(range(16, 257, 8))
|
||||
+ list(range(272, 512, 16))
|
||||
+ list(range(512, self.cuda_graph_max_bs + 1))
|
||||
)
|
||||
else:
|
||||
# Spec decoding case: list(range(1, 9, 1)) + list(range(10, 33, 2)) + list(range(40, 64, 4)) + list(range(72, 257, 8))
|
||||
capture_bs = (
|
||||
list(range(1, 9, 1))
|
||||
+ list(range(10, 33, 2))
|
||||
+ list(range(40, 64, 4))
|
||||
+ list(range(72, 257, 8))
|
||||
+ list(range(272, self.cuda_graph_max_bs + 1, 16))
|
||||
)
|
||||
|
||||
capture_bs = [bs for bs in capture_bs if bs <= self.cuda_graph_max_bs]
|
||||
|
||||
return capture_bs
|
||||
|
||||
def _handle_hpu_backends(self):
|
||||
if self.device == "hpu":
|
||||
|
||||
Reference in New Issue
Block a user