Fix torch.compile caching (#5259)
Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
@@ -64,7 +64,10 @@ from sglang.srt.model_loader.loader import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
from sglang.srt.model_loader.utils import set_default_torch_dtype
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.patch_torch import monkey_patch_torch_reductions
|
from sglang.srt.patch_torch import (
|
||||||
|
monkey_patch_torch_compile,
|
||||||
|
monkey_patch_torch_reductions,
|
||||||
|
)
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
@@ -88,6 +91,8 @@ logger = logging.getLogger(__name__)
|
|||||||
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
|
||||||
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
|
||||||
|
|
||||||
|
monkey_patch_torch_compile()
|
||||||
|
|
||||||
|
|
||||||
class ModelRunner:
|
class ModelRunner:
|
||||||
"""ModelRunner runs the forward passes of the models."""
|
"""ModelRunner runs the forward passes of the models."""
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
from typing import Callable, Union
|
from typing import Callable, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from packaging import version
|
||||||
from torch.multiprocessing import reductions
|
from torch.multiprocessing import reductions
|
||||||
|
|
||||||
|
|
||||||
@@ -69,3 +70,13 @@ def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
|
|||||||
|
|
||||||
def _modify_tuple(t, index: int, modifier: Callable):
|
def _modify_tuple(t, index: int, modifier: Callable):
|
||||||
return *t[:index], modifier(t[index]), *t[index + 1 :]
|
return *t[:index], modifier(t[index]), *t[index + 1 :]
|
||||||
|
|
||||||
|
|
||||||
|
def monkey_patch_torch_compile():
    """Mark auto-functionalized higher-order ops as cacheable for torch.compile.

    These ops are safe for torch.compile's cache, but releases before
    PyTorch 2.8 do not flag them as such; 2.8 fixed this upstream, so the
    patch is skipped on newer versions.
    """
    if version.parse(torch.__version__) >= version.parse("2.8.0"):
        return
    # Imported lazily: only needed on the patched (pre-2.8) code path.
    import torch._higher_order_ops.auto_functionalize as af

    for op in (af.auto_functionalized, af.auto_functionalized_v2):
        op._cacheable = True
|
||||||
|
|||||||
Reference in New Issue
Block a user