# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import weakref
from collections import defaultdict
from contextlib import nullcontext

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
from torch.utils.checkpoint import get_device_states, set_device_states


def _detach(x):
    # Detach tensors from the autograd graph; leave non-tensor values as-is.
    if isinstance(x, torch.Tensor):
        return x.detach()
    return x


class CachingTorchDispatchMode(TorchDispatchMode):
    """During the first forward pass, stores the outputs of every op selected
    by ``policy_fn`` in ``storage`` so that the recomputation pass can reuse
    them instead of recomputing them."""

    def __init__(self, policy_fn, storage):
        self.policy_fn = policy_fn
        self.storage = storage

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if self.policy_fn(func, *args, **kwargs):
            out = func(*args, **kwargs)
            # Store a detached copy so we don't keep the autograd graph alive.
            out_detached = tree_map(_detach, out)
            self.storage[func].append(out_detached)
            return out
        return func(*args, **kwargs)


class CachedTorchDispatchMode(TorchDispatchMode):
    """During the recomputation pass, replays the outputs previously stored by
    :class:`CachingTorchDispatchMode` instead of recomputing them."""

    def __init__(self, policy_fn, storage):
        self.policy_fn = policy_fn
        self.storage = storage

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if self.policy_fn(func, *args, **kwargs):
            # Basic guard against ops that appear only in the second forward,
            # e.g. `detach`, which can show up if our policy is too loose: if
            # nothing was stored for `func`, fall through and recompute it.
            if self.storage[func]:
                out = self.storage[func].pop(0)
                return out
        return func(*args, **kwargs)


def _get_default_policy(allow_list=None):
    # By default, store the outputs of attention and matmul ops, which are
    # the most expensive to recompute.
    _default_allow_list = [
        "xformers.efficient_attention_forward_cutlass.default",
        "xformers_flash.flash_fwd.default",
        "aten.addmm.default",
        "aten.mm.default",
    ]
    if allow_list is None:
        allow_list = _default_allow_list

    def _default_policy(func, *args, **kwargs):
        return str(func) in allow_list

    return _default_policy
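

# A minimal sketch of how the policy and the two dispatch modes above
# cooperate (the tensors and op below are illustrative assumptions, not part
# of this file):
#
#     storage = defaultdict(list)
#     policy = _get_default_policy(["aten.mm.default"])
#     a, b = torch.randn(4, 4), torch.randn(4, 4)
#     with CachingTorchDispatchMode(policy, storage):
#         out1 = torch.mm(a, b)  # computed; a detached copy is stashed
#     with CachedTorchDispatchMode(policy, storage):
#         out2 = torch.mm(a, b)  # popped from storage instead of recomputed
#     assert torch.equal(out1, out2)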


class VerboseTorchDispatchMode(TorchDispatchMode):
    """Records every operator dispatched while the mode is active."""

    def __init__(self):
        self.operators = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        self.operators.append(func)
        return func(*args, **kwargs)


def list_operators(function, *args, **kwargs):
    """
    Returns the list of operators called while executing
    ``function(*args, **kwargs)``.
    """
    verbose_mode = VerboseTorchDispatchMode()
    with verbose_mode:
        function(*args, **kwargs)
    return verbose_mode.operators
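

# Example usage of `list_operators` (the module and input below are
# illustrative assumptions, not part of this file):
#
#     model = torch.nn.Linear(8, 8)
#     x = torch.randn(2, 8)
#     for op in list_operators(model, x):
#         print(op)  # e.g. aten.t.default, aten.addmm.default, ...
#
# The `str()` of each returned op (e.g. "aten.addmm.default") is what a
# list-based `policy_fn` passed to `checkpoint` is matched against.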


def checkpoint(function, *args, preserve_rng_state=True, policy_fn=None, **kwargs):
    """Checkpointing with a custom policy function for selectively deciding
    what to store and what to recompute.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as a tuple. For example, in an LSTM, if the user passes
            ``(activation, hidden)``, :attr:`function` should correctly use
            the first input as ``activation`` and the second input as
            ``hidden``.
        preserve_rng_state (bool, optional): if ``True``, stash and restore
            the RNG state during each checkpoint so that the recomputed
            forward pass sees the same random numbers as the original one.
            Default: ``True``
        policy_fn (Union[List[str], Callable]): policy for deciding what to
            store (instead of recompute). If it's a function, it should have
            the form ``(func, *args, **kwargs) -> bool``, indicating whether
            the outputs of ``func`` called with ``*args`` and ``**kwargs``
            should be stored. For simpler cases, a list of operator names
            (e.g. ``"aten.mm.default"``) is also supported; these names can
            be obtained with :func:`list_operators`.
        *args: Arguments to pass to the given ``function``.
        **kwargs: Keyword arguments to pass to the given ``function``.
    """

    # Requires at least PyTorch 1.13.
    from torch.utils.checkpoint import _get_autocast_kwargs

    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    # Custom class so that we can take weak references to the saved objects.
    class Holder:
        pass

    # The Holder object for each saved object is stored directly on the
    # SavedVariable and is cleared when reset_data() is called on it. We MUST
    # make sure that this is the only object with an owning reference, so that
    # the Tensor kept in storage is deleted as soon as the corresponding
    # SavedVariable data is cleared.
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    weak_holder_list = []

    if policy_fn is None:
        policy_fn = _get_default_policy()
    elif isinstance(policy_fn, list):
        policy_fn = _get_default_policy(policy_fn)
    else:
        assert callable(policy_fn), "policy_fn should be None, list or a callable"

    temp_storage = defaultdict(list)
    # Assumption: grad mode doesn't change inside `function`.
    if torch.is_grad_enabled():
        caching_mode = CachingTorchDispatchMode(policy_fn, temp_storage)
    else:
        # No backward pass will happen, so there is nothing worth storing.
        caching_mode = nullcontext()
    cached_mode = CachedTorchDispatchMode(policy_fn, temp_storage)

    def pack(x):
        # TODO(varal7): Instead of returning an abstract object, we could
        # return metadata (such as size, device, ...) to catch certain cases
        # of nondeterministic behavior in the forward.
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        # On the first call, rerun the forward pass (replaying stored ops
        # through `cached_mode`) to repopulate `storage`.
        unpack_counter = 0
        if len(storage) == 0:

            def inner_pack(inner):
                nonlocal unpack_counter
                unpack_counter += 1
                # If the holder went out of scope, the SavedVariable is dead and so
                # the value will never be read from the storage. Skip filling it.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return
                # Use detach here to ensure we don't keep the temporary autograd
                # graph created during the second forward.
                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
                return

            def inner_unpack(packed):
                raise RuntimeError(
                    "You are calling backwards on a tensor that is never exposed. Please open an issue."
                )

            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during the forward. Restore the surrounding
            # state when we're done.
            rng_devices = []
            if preserve_rng_state and had_cuda_in_fwd:
                rng_devices = fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
                if preserve_rng_state:
                    torch.set_rng_state(fwd_cpu_state)
                    if had_cuda_in_fwd:
                        set_device_states(fwd_gpu_devices, fwd_gpu_states)

                with cached_mode, torch.enable_grad(), torch.cuda.amp.autocast(
                    **gpu_autocast_kwargs
                ), torch.cpu.amp.autocast(
                    **cpu_autocast_kwargs
                ), torch.autograd.graph.saved_tensors_hooks(
                    inner_pack, inner_unpack
                ):
                    _unused = function(*args, **kwargs)  # noqa: F841

        if x not in storage:
            raise RuntimeError(
                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                " recomputation being triggered in between, this is not currently supported. Please"
                " open an issue with details on your use case so that we can prioritize adding this."
            )

        return storage[x]

    with caching_mode, torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args, **kwargs)
        if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
            # CUDA was not initialized before running the forward, so we
            # didn't stash the CUDA state.
            raise RuntimeError(
                "PyTorch's CUDA state was initialized in the forward pass "
                "of a Checkpoint, which is not allowed. Please open an issue "
                "if you need this feature."
            )

    return output
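

if __name__ == "__main__":
    # Illustrative usage sketch: the module, shapes, and policy below are
    # assumptions for demonstration, not part of the library. We checkpoint a
    # small MLP, storing only `addmm` outputs and recomputing everything else
    # during backward.
    model = torch.nn.Sequential(
        torch.nn.Linear(8, 8),
        torch.nn.ReLU(),
        torch.nn.Linear(8, 8),
    )
    x = torch.randn(2, 8, requires_grad=True)

    # Inspect which operators the forward dispatches, to help build a policy.
    print([str(op) for op in list_operators(model, x)])

    # Store only the linear-layer outputs; ReLU is recomputed in backward.
    out = checkpoint(model, x, policy_fn=["aten.addmm.default"])
    out.sum().backward()
    print("grad shape:", x.grad.shape)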