Fix torch.compile cacheing (#5259)

Co-authored-by: zhyncs <me@zhyncs.com>
This commit is contained in:
Richard Zou
2025-04-10 21:08:45 -04:00
committed by GitHub
parent a222945df2
commit a879811c4b
2 changed files with 17 additions and 1 deletions

View File

@@ -14,6 +14,7 @@
from typing import Callable, Union
import torch
from packaging import version
from torch.multiprocessing import reductions
@@ -69,3 +70,13 @@ def _device_from_maybe_uuid(device_maybe_uuid: Union[int, str]) -> int:
def _modify_tuple(t, index: int, modifier: Callable):
return *t[:index], modifier(t[index]), *t[index + 1 :]
def monkey_patch_torch_compile():
if version.parse(torch.__version__) < version.parse("2.8.0"):
# These things are cacheable by torch.compile. torch.compile just doesn't know it.
# This was fixed in PyTorch 2.8, but until then, we monkey patch.
import torch._higher_order_ops.auto_functionalize as af
af.auto_functionalized_v2._cacheable = True
af.auto_functionalized._cacheable = True