[Minor] Move torch.compile patch to a better place (#5397)

This commit is contained in:
Lianmin Zheng
2025-04-15 18:37:07 -07:00
committed by GitHub
parent b64b88e738
commit 0769b14bf9
2 changed files with 4 additions and 6 deletions

View File

@@ -34,6 +34,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.utils import get_available_gpu_memory, is_hip
_is_hip = is_hip()
@@ -108,6 +109,8 @@ def set_torch_compile_config():
if hasattr(torch._dynamo.config, "cache_size_limit"):
torch._dynamo.config.cache_size_limit = 1024
monkey_patch_torch_compile()
def get_batch_sizes_to_capture(model_runner: ModelRunner):
server_args = model_runner.server_args

View File

@@ -64,10 +64,7 @@ from sglang.srt.model_loader.loader import (
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.patch_torch import (
monkey_patch_torch_compile,
monkey_patch_torch_reductions,
)
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -94,8 +91,6 @@ logger = logging.getLogger(__name__)
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
monkey_patch_torch_compile()
class ModelRunner:
"""ModelRunner runs the forward passes of the models."""