diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index b7dbdb16d..1f9212834 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -31,11 +31,6 @@ _is_hip = is_hip() logger = logging.getLogger(__name__) -# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy. -logger.warning( - "The following error message 'operation scheduled before its operands' can be ignored." -) - _MIN_BLOCK_KV = 32 diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 51a083eb6..2ab5ee385 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -153,7 +153,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): gpu_mem = get_device_memory_capacity() if gpu_mem is not None: - if gpu_mem > 90 * 1024: # H200 + if gpu_mem > 90 * 1024: # H200, H20 capture_bs += list(range(160, 257, 8)) if gpu_mem > 160 * 1000: # B200, MI300 capture_bs += list(range(256, 513, 16)) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e1fc3be03..639ed96bc 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -254,51 +254,72 @@ class ServerArgs: gpu_mem = get_device_memory_capacity(self.device) - # Set mem fraction static, which depends on the tensor parallelism size + # Set mem fraction static if self.mem_fraction_static is None: - parallel_size = self.tp_size * self.pp_size - if gpu_mem is not None and gpu_mem <= 81920: - if parallel_size >= 16: - self.mem_fraction_static = 0.79 - elif parallel_size >= 8: - self.mem_fraction_static = 0.81 - elif parallel_size >= 4: - self.mem_fraction_static = 0.85 - elif parallel_size >= 2: - self.mem_fraction_static = 0.87 + if 
gpu_mem is not None: + # GPU memory capacity = model weights + KV cache pool + activations + cuda graph buffers + # mem_fraction_static = (model weights + KV cache pool) / GPU memory capacity. + + # We want mem_fraction_static to be as large as possible but still has enough room + # for activations and cuda graph buffers. We use the following heuristic to + # compute the needed size for activations and cuda graph buffers: + # - The size of the activation depends on the chunked_prefill_size and model size. + # - The size of cuda graph buffers depends on the cuda graph capture range and model size. + # For GPUs with more memory, we use a larger chunked_prefill_size and + # capture more cuda graphs, so they need to reserve more memory. + parallel_size = self.tp_size * self.pp_size + + if gpu_mem < 20 * 1024: + # T4, 4080. (chunked_prefill_size 2k, cuda_graph_max_bs 8) + reserved_mem = (2.8 + parallel_size / 10) * 1024 + elif gpu_mem < 35 * 1024: + # A10, L40, 4090, 5090. (chunked_prefill_size 2k, cuda_graph_max_bs 8) + reserved_mem = (2.8 + parallel_size / 10) * 1024 + elif gpu_mem < 90 * 1024: + # H100, A100. (chunked_prefill_size 8k, cuda_graph_max_bs 160) + reserved_mem = (9.5 + parallel_size / 2) * 1024 + elif gpu_mem < 100 * 1024: + # H20. (chunked_prefill_size 8k, cuda_graph_max_bs 256) + reserved_mem = (12 + parallel_size / 2) * 1024 + elif gpu_mem < 160 * 1024: + # H200. (chunked_prefill_size 8k, cuda_graph_max_bs 256) + reserved_mem = (12 + parallel_size / 2) * 1024 else: - self.mem_fraction_static = 0.88 + # B200, MI300. 
(chunked_prefill_size 16k, cuda_graph_max_bs 512) + reserved_mem = 32 * 1024 + + if self.speculative_algorithm is not None: + # draft model and larger cuda graph buffers + reserved_mem += 2 * 1024 + if self.enable_dp_attention: + reserved_mem += 4 * 1024 + + self.mem_fraction_static = round((gpu_mem - reserved_mem) / gpu_mem, 3) else: self.mem_fraction_static = 0.88 - if gpu_mem is not None and gpu_mem > 180 * 1000 and is_cuda(): - self.mem_fraction_static = 0.79 - elif gpu_mem is not None and gpu_mem > 96 * 1024: - mem_fraction = self.mem_fraction_static - # 15 GB + additional 3GB for cuda graph - reserve_mem = 1024 * 18 - # need reserve more memory for spec cuda graph - if self.speculative_algorithm is not None: - reserve_mem = 1024 * 20 - self.mem_fraction_static = min( - mem_fraction + 48 * 1024 * (1 - mem_fraction) / gpu_mem, - (gpu_mem - reserve_mem) / gpu_mem, - ) - else: - if self.speculative_algorithm is not None: - self.mem_fraction_static *= 0.95 # Set chunked prefill size, which depends on the gpu memory capacity if self.chunked_prefill_size is None: - if gpu_mem is not None and gpu_mem > 180_000: - self.chunked_prefill_size = 16384 - elif gpu_mem is not None and gpu_mem < 25_000: - self.chunked_prefill_size = 2048 - elif self.disaggregation_mode != "null": - self.chunked_prefill_size = 16384 + if gpu_mem is not None: + if gpu_mem < 35 * 1024: # A10, L40, 4090 + self.chunked_prefill_size = 2048 + elif gpu_mem < 160 * 1024: # H100, H200, A100, H20 + self.chunked_prefill_size = 8192 + else: # B200, MI300 + self.chunked_prefill_size = 16384 else: - self.chunked_prefill_size = 8192 + self.chunked_prefill_size = 4096 assert self.chunked_prefill_size % self.page_size == 0 + # Set cuda graph max batch size + if self.cuda_graph_max_bs is None: + # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<35G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating 
cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues. + if gpu_mem is not None and gpu_mem < 35 * 1024: + if self.tp_size < 4: + self.cuda_graph_max_bs = 8 + else: + self.cuda_graph_max_bs = 80 + assert self.moe_dense_tp_size in { 1, None, @@ -316,15 +337,6 @@ class ServerArgs: ) self.page_size = 128 - # Set cuda graph max batch size - if self.cuda_graph_max_bs is None: - # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues. - if gpu_mem is not None and gpu_mem < 25_000: - if self.tp_size < 4: - self.cuda_graph_max_bs = 8 - else: - self.cuda_graph_max_bs = 80 - # Set kernel backends for hpu device if self.device == "hpu": self.attention_backend = "torch_native"