diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 341cce09a..8a46d2318 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -149,6 +149,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner): gpu_mem = get_device_memory_capacity() if gpu_mem is not None and gpu_mem > 96 * 1024: capture_bs += list(range(160, 257, 8)) + if gpu_mem is not None and gpu_mem > 180 * 1000: + capture_bs += list(range(256, 528, 16)) if max(capture_bs) > model_runner.req_to_token_pool.size: # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 9c843a315..b84a1ddf4 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -260,7 +260,9 @@ class ServerArgs: self.mem_fraction_static = 0.88 else: self.mem_fraction_static = 0.88 - if gpu_mem is not None and gpu_mem > 96 * 1024: + if gpu_mem is not None and gpu_mem > 180 * 1000: + self.mem_fraction_static = 0.79 + elif gpu_mem is not None and gpu_mem > 96 * 1024: mem_fraction = self.mem_fraction_static # 15 GB + additional 3GB for cuda graph reserve_mem = 1024 * 18 @@ -277,7 +279,9 @@ class ServerArgs: # Set chunked prefill size, which depends on the gpu memory capacity if self.chunked_prefill_size is None: - if gpu_mem is not None and gpu_mem < 25_000: + if gpu_mem is not None and gpu_mem > 180_000: + self.chunked_prefill_size = 16384 + elif gpu_mem is not None and gpu_mem < 25_000: self.chunked_prefill_size = 2048 elif self.disaggregation_mode != "null": self.chunked_prefill_size = 16384