Fix torch.compile cacheing (#5259)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
Richard Zou
2025-04-10 21:08:45 -04:00
committed by GitHub
parent a222945df2
commit a879811c4b
2 changed files with 17 additions and 1 deletions

View File

@@ -64,7 +64,10 @@ from sglang.srt.model_loader.loader import (
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.patch_torch import (
monkey_patch_torch_compile,
monkey_patch_torch_reductions,
)
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -88,6 +91,8 @@ logger = logging.getLogger(__name__)
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
monkey_patch_torch_compile()
class ModelRunner:
"""ModelRunner runs the forward passes of the models."""