From 0769b14bf946811d93122f8023e45b24b50a25e9 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 15 Apr 2025 18:37:07 -0700 Subject: [PATCH] [Minor] Move torch.compile patch to a better place (#5397) --- python/sglang/srt/model_executor/cuda_graph_runner.py | 3 +++ python/sglang/srt/model_executor/model_runner.py | 7 +------ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index bc2c4abb0..c71cae07a 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -34,6 +34,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, ) +from sglang.srt.patch_torch import monkey_patch_torch_compile from sglang.srt.utils import get_available_gpu_memory, is_hip _is_hip = is_hip() @@ -108,6 +109,8 @@ def set_torch_compile_config(): if hasattr(torch._dynamo.config, "cache_size_limit"): torch._dynamo.config.cache_size_limit = 1024 + monkey_patch_torch_compile() + def get_batch_sizes_to_capture(model_runner: ModelRunner): server_args = model_runner.server_args diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 41a245a10..fd84339d7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -64,10 +64,7 @@ from sglang.srt.model_loader.loader import ( ) from sglang.srt.model_loader.utils import set_default_torch_dtype from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.patch_torch import ( - monkey_patch_torch_compile, - monkey_patch_torch_reductions, -) +from sglang.srt.patch_torch import monkey_patch_torch_reductions from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm @@ -94,8 +91,6 @@ logger = logging.getLogger(__name__) SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None) UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300 -monkey_patch_torch_compile() - class ModelRunner: """ModelRunner runs the forward passes of the models."""