[Optimization][Perf] Disable the GC during CUDA graph capture to speed up by up to 3x (#8577)
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
import gc
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
@@ -75,6 +76,24 @@ def model_capture_mode():
|
||||
is_capture_mode = False
|
||||
|
||||
|
||||
@contextmanager
def freeze_gc(enable_cudagraph_gc: bool):
    """Context manager that tunes the garbage collector around a code region.

    A collection pass is always run on entry. When ``enable_cudagraph_gc``
    is falsy, the objects that survive that pass are frozen (``gc.freeze()``)
    for the duration of the ``with`` body so later collections ignore them,
    and they are unfrozen again on exit, including when the body raises.
    When ``enable_cudagraph_gc`` is truthy, nothing is frozen and the
    collector keeps its normal behavior inside the body.

    Args:
        enable_cudagraph_gc: If truthy, leave the collector untouched after
            the initial collection; if falsy, freeze for the body's duration.
    """
    # Reclaim whatever is already dead first, so that a subsequent freeze
    # only pins objects that are actually still alive.
    gc.collect()

    if enable_cudagraph_gc:
        # The caller wants the GC to stay active: nothing to freeze or undo.
        yield
        return

    gc.freeze()
    try:
        yield
    finally:
        # Always undo the freeze, even if the managed body raised.
        gc.unfreeze()
|
||||
|
||||
|
||||
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
|
||||
for sub in model._modules.values():
|
||||
if isinstance(sub, CustomOp):
|
||||
@@ -423,7 +442,12 @@ class CudaGraphRunner:
|
||||
record_shapes=True,
|
||||
)
|
||||
|
||||
with graph_capture() as graph_capture_context:
|
||||
# Trigger CUDA graph capture for specific shapes.
|
||||
# Capture the large shapes first so that the smaller shapes
|
||||
# can reuse the memory pool allocated for the large shapes.
|
||||
with freeze_gc(
|
||||
self.model_runner.server_args.enable_cudagraph_gc
|
||||
), graph_capture() as graph_capture_context:
|
||||
with profile_context as prof:
|
||||
self.stream = graph_capture_context.stream
|
||||
avail_mem = get_available_gpu_memory(
|
||||
|
||||
@@ -215,6 +215,7 @@ class ServerArgs:
|
||||
disable_cuda_graph: bool = False
|
||||
disable_cuda_graph_padding: bool = False
|
||||
enable_profile_cuda_graph: bool = False
|
||||
enable_cudagraph_gc: bool = False
|
||||
enable_nccl_nvls: bool = False
|
||||
enable_tokenizer_batch_encode: bool = False
|
||||
disable_outlines_disk_cache: bool = False
|
||||
@@ -1545,6 +1546,11 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enable profiling of cuda graph capture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-cudagraph-gc",
|
||||
action="store_true",
|
||||
help="Enable garbage collection during CUDA graph capture. If disabled (default), GC is frozen during capture to speed up the process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-nccl-nvls",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user