diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index eef7fba14..fb703255b 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import bisect
+import gc
 import inspect
 import logging
 import os
@@ -75,6 +76,24 @@ def model_capture_mode():
     is_capture_mode = False
 
 
+@contextmanager
+def freeze_gc(enable_cudagraph_gc: bool):
+    """
+    Optimize garbage collection during CUDA graph capture.
+    Clean up, then freeze all remaining objects to exclude them
+    from future collections if GC is disabled during capture.
+    """
+    gc.collect()
+    should_freeze = not enable_cudagraph_gc
+    if should_freeze:
+        gc.freeze()
+    try:
+        yield
+    finally:
+        if should_freeze:
+            gc.unfreeze()
+
+
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
@@ -423,7 +442,12 @@
                 record_shapes=True,
             )
-        with graph_capture() as graph_capture_context:
+        # Trigger CUDA graph capture for specific shapes.
+        # Capture the large shapes first so that the smaller shapes
+        # can reuse the memory pool allocated for the large shapes.
+        with freeze_gc(
+            self.model_runner.server_args.enable_cudagraph_gc
+        ), graph_capture() as graph_capture_context:
             with profile_context as prof:
                 self.stream = graph_capture_context.stream
                 avail_mem = get_available_gpu_memory(
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 24ec434fb..4ba08973d 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -215,6 +215,7 @@
     disable_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
     enable_profile_cuda_graph: bool = False
+    enable_cudagraph_gc: bool = False
     enable_nccl_nvls: bool = False
     enable_tokenizer_batch_encode: bool = False
     disable_outlines_disk_cache: bool = False
@@ -1545,6 +1546,11 @@
         action="store_true",
         help="Enable profiling of cuda graph capture.",
     )
+    parser.add_argument(
+        "--enable-cudagraph-gc",
+        action="store_true",
+        help="Enable garbage collection during CUDA graph capture. If disabled (default), GC is frozen during capture to speed up the process.",
+    )
     parser.add_argument(
         "--enable-nccl-nvls",
         action="store_true",