[Minor] Move torch.compile patch to a better place (#5397)
This commit is contained in:
@@ -34,6 +34,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
+from sglang.srt.patch_torch import monkey_patch_torch_compile
|
||||
from sglang.srt.utils import get_available_gpu_memory, is_hip
|
||||
|
||||
_is_hip = is_hip()
|
||||
@@ -108,6 +109,8 @@ def set_torch_compile_config():
|
||||
if hasattr(torch._dynamo.config, "cache_size_limit"):
|
||||
torch._dynamo.config.cache_size_limit = 1024
|
||||
|
||||
+monkey_patch_torch_compile()
|
||||
|
||||
|
||||
def get_batch_sizes_to_capture(model_runner: ModelRunner):
|
||||
server_args = model_runner.server_args
|
||||
|
||||
@@ -64,10 +64,7 @@ from sglang.srt.model_loader.loader import (
|
||||
)
|
||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
-from sglang.srt.patch_torch import (
-    monkey_patch_torch_compile,
-    monkey_patch_torch_reductions,
-)
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
@@ -94,8 +91,6 @@ logger = logging.getLogger(__name__)
|
||||
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
||||
|
||||
-monkey_patch_torch_compile()
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
"""ModelRunner runs the forward passes of the models."""
|
||||
|
||||
Reference in New Issue
Block a user