First commit
This commit is contained in:
233
pkgs/xformers/checkpoint.py
Normal file
233
pkgs/xformers/checkpoint.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from torch.utils._pytree import tree_map
|
||||
from torch.utils.checkpoint import get_device_states, set_device_states
|
||||
|
||||
|
||||
def _detach(x):
    """Return ``x.detach()`` for tensors; hand back any other object untouched."""
    return x.detach() if isinstance(x, torch.Tensor) else x
|
||||
|
||||
|
||||
class CachingTorchDispatchMode(TorchDispatchMode):
    """Dispatch mode that records the outputs of operators selected by a policy.

    Every intercepted operator is executed normally. When
    ``policy_fn(func, *args, **kwargs)`` is truthy, a detached copy of the
    output is additionally appended to ``storage[func]``.
    """

    def __init__(self, policy_fn, storage):
        self.policy_fn, self.storage = policy_fn, storage

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        call_kwargs = {} if kwargs is None else kwargs
        # The policy is consulted before the operator runs, so a failing
        # policy prevents the call altogether.
        should_cache = bool(self.policy_fn(func, *args, **call_kwargs))
        result = func(*args, **call_kwargs)
        if should_cache:
            # Store detached tensors only; the caller still receives the
            # original (attached) result.
            self.storage[func].append(tree_map(_detach, result))
        return result
|
||||
|
||||
|
||||
class CachedTorchDispatchMode(TorchDispatchMode):
    """Dispatch mode that replays previously stored operator outputs.

    For operators selected by ``policy_fn`` the oldest entry of
    ``storage[func]`` is returned instead of running the operator. Anything
    else, or a selected operator with nothing left to replay, is executed
    normally.
    """

    def __init__(self, policy_fn, storage):
        self.policy_fn, self.storage = policy_fn, storage

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        call_kwargs = {} if kwargs is None else kwargs
        if self.policy_fn(func, *args, **call_kwargs):
            pending = self.storage[func]
            # Basic guard against extra ops showing up in the second forward
            # (`detach` is one example when the policy is too loose): with
            # nothing left to replay we fall through and run the op for real.
            if pending:
                return pending.pop(0)
        return func(*args, **call_kwargs)
|
||||
|
||||
|
||||
def _get_default_policy(allow_list=None):
    """Build a policy function that selects operators from an allow list.

    Args:
        allow_list: optional iterable describing the operators whose outputs
            should be stored. Each entry is normalised with ``str()``, so an
            entry may be the operator's string name (for example
            ``"aten.mm.default"``) or any object whose ``str()`` yields that
            name. When ``None``, a built-in default list is used. An empty
            iterable selects nothing.

    Returns:
        A callable ``(func, *args, **kwargs) -> bool`` returning ``True``
        when ``str(func)`` is one of the allowed names. The extra positional
        and keyword arguments are accepted but ignored.
    """
    if allow_list is None:
        allow_list = (
            "xformers.efficient_attention_forward_cutlass.default",
            "xformers_flash.flash_fwd.default",
            "aten.addmm.default",
            "aten.mm.default",
        )

    # Normalise once: comparing `str(func)` against raw operator objects would
    # never match, and a frozenset gives constant-time membership tests on a
    # path that runs for every dispatched operator.
    allowed_names = frozenset(str(entry) for entry in allow_list)

    def _default_policy(func, *args, **kwargs):
        return str(func) in allowed_names

    return _default_policy
|
||||
|
||||
|
||||
class VerboseTorchDispatchMode(TorchDispatchMode):
    """Dispatch mode that logs every intercepted operator, in call order.

    The operators are appended to ``self.operators`` and then executed
    unchanged.
    """

    def __init__(self):
        self.operators = []

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        call_kwargs = {} if kwargs is None else kwargs
        # Record first so the operator is listed even if running it fails.
        self.operators.append(func)
        return func(*args, **call_kwargs)
|
||||
|
||||
|
||||
def list_operators(function, *args, **kwargs):
    """
    Run ``function(*args, **kwargs)`` under a recording dispatch mode and
    return the operators that were dispatched, in call order.

    The return value of ``function`` itself is discarded.
    """
    recorder = VerboseTorchDispatchMode()
    with recorder:
        function(*args, **kwargs)
    return recorder.operators
|
||||
|
||||
|
||||
def checkpoint(function, *args, preserve_rng_state=True, policy_fn=None, **kwargs):
    """Checkpointing with a custom policy function for selectively deciding
    what to store and what to recompute.

    ``function`` is run once immediately. Tensors that autograd would save for
    backward are replaced by placeholders; the first time one of them is
    needed, ``function`` is run a second time to regenerate them. Outputs of
    operators selected by ``policy_fn`` are cached during the first run and
    replayed during the second one instead of being recomputed.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): if ``True``, the RNG state is
            stashed before the forward pass and restored for the
            recomputation; if ``False``, this stashing and restoring is
            omitted.
            Default: ``True``
        policy_fn(Union[List[Op], callable]): policy for deciding what to
            store (instead of recompute). If it's a function, it should
            be of form (func, *args, **kwargs) -> bool which indicates
            if func outputs with *args and **kwargs should be stored or not.
            Additionally, a list is also supported for easier cases; it is
            forwarded to ``_get_default_policy`` and each entry is matched
            against ``str(func)`` (for example ``"aten.mm.default"``).
            The operators used by a function can be obtained with
            `list_operators`. If ``None``, the default policy is used.
        *args: Arguments to pass in to the given ``function``.
        **kwargs: Keyword arguments to pass into the given ``function``.

    Returns:
        Whatever ``function(*args, **kwargs)`` returns.

    Raises:
        AssertionError: if ``policy_fn`` is not ``None``, a list or a callable.
        RuntimeError: if CUDA gets initialized during the forward pass while
            ``preserve_rng_state`` is ``True`` and CUDA was not initialized
            beforehand.
    """

    # Requires PyTorch 1.13 at least
    from torch.utils.checkpoint import _get_autocast_kwargs

    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    gpu_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs()

    if preserve_rng_state:
        # Snapshot the CPU RNG state so that `unpack` can replay it for the
        # recomputation.
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_cuda_in_fwd = False
        if torch.cuda._initialized:
            had_cuda_in_fwd = True
            fwd_gpu_devices, fwd_gpu_states = get_device_states(*args)

    # Custom class to be able to take weak references
    class Holder:
        pass

    # The Holder object for each of the saved object is saved directly on the
    # SavedVariable and is cleared when reset_data() is called on it. We MUST make
    # sure that this is the only object having an owning reference to ensure that
    # the Tensor stored in storage is deleted as soon as the corresponding SavedVariable
    # data is cleared.
    storage: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
    # Weak references to the holders, in the order `pack` created them.
    weak_holder_list = []

    # Normalise `policy_fn` into a callable: None selects the default policy,
    # a list is turned into an allow-list policy.
    # NOTE(review): the assert below is skipped under `python -O`; consider
    # raising an explicit exception for invalid input instead.
    if policy_fn is None:
        policy_fn = _get_default_policy()
    elif isinstance(policy_fn, list):
        policy_fn = _get_default_policy(policy_fn)
    else:
        assert callable(policy_fn), "policy_fn should be None, list or a callable"

    # Shared between the two dispatch modes: `caching_mode` appends operator
    # outputs to it during the first forward, `cached_mode` pops them again
    # during the recomputation.
    temp_storage = defaultdict(list)
    # assumption: grad_mode doesn't change inside function
    if torch.is_grad_enabled():
        caching_mode = CachingTorchDispatchMode(policy_fn, temp_storage)
    else:
        # With grad disabled nothing is cached during the forward pass.
        caching_mode = nullcontext()
    cached_mode = CachedTorchDispatchMode(policy_fn, temp_storage)

    def pack(x):
        # Saved-tensor pack hook: the tensor `x` itself is dropped and an empty
        # Holder is saved in its place; only a weak reference to the holder is
        # kept on our side.
        # TODO(varal7): Instead of returning abstract object, we can return things metadata (such as
        # size, device, ...) to catch certain cases of undeterministic behavior of the forward
        res = Holder()
        weak_holder_list.append(weakref.ref(res))
        return res

    def unpack(x):
        # Saved-tensor unpack hook: `x` is a Holder produced by `pack`. An
        # empty `storage` triggers a recomputation of `function` that refills
        # it before the lookup at the end.
        unpack_counter = 0
        if len(storage) == 0:

            def inner_pack(inner):
                # Pack hook used during the recomputation. The n-th tensor
                # saved here is stored under the n-th holder created by `pack`
                # (this relies on the recomputation saving tensors in the same
                # order as the original forward).
                nonlocal unpack_counter
                unpack_counter += 1
                # If the holder went out of scope, the SavedVariable is dead and so
                # the value will never be read from the storage. Skip filling it.
                if weak_holder_list[unpack_counter - 1]() is None:
                    return
                # Use detach here to ensure we don't keep the temporary autograd
                # graph created during the second forward
                storage[weak_holder_list[unpack_counter - 1]()] = inner.detach()
                return

            def inner_unpack(packed):
                # The graph built during the recomputation is not meant to be
                # backpropagated through.
                raise RuntimeError(
                    "You are calling backwards on a tensor that is never exposed. Please open an issue."
                )

            # Stash the surrounding rng state, and mimic the state that was
            # present at this time during forward. Restore the surrounding state
            # when we're done.
            rng_devices = []
            if preserve_rng_state and had_cuda_in_fwd:
                rng_devices = fwd_gpu_devices
            with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
                if preserve_rng_state:
                    torch.set_rng_state(fwd_cpu_state)
                    if had_cuda_in_fwd:
                        set_device_states(fwd_gpu_devices, fwd_gpu_states)

                # Re-run the forward with grad enabled and the autocast
                # settings captured at call time; `cached_mode` replays the
                # operator outputs stored during the first forward.
                with cached_mode, torch.enable_grad(), torch.cuda.amp.autocast(
                    **gpu_autocast_kwargs
                ), torch.cpu.amp.autocast(
                    **cpu_autocast_kwargs
                ), torch.autograd.graph.saved_tensors_hooks(
                    inner_pack, inner_unpack
                ):
                    _unused = function(*args, **kwargs)  # noqa: F841

        if x not in storage:
            raise RuntimeError(
                "Attempt to retrieve a tensor saved by autograd multiple times without checkpoint"
                " recomputation being triggered in between, this is not currently supported. Please"
                " open an issue with details on your use case so that we can prioritize adding this."
            )

        return storage[x]

    # First forward: saved tensors are swapped for holders and, when grad is
    # enabled, operator outputs selected by the policy are cached.
    with caching_mode, torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = function(*args, **kwargs)
        if torch.cuda._initialized and preserve_rng_state and not had_cuda_in_fwd:
            # Cuda was not initialized before running the forward, so we didn't
            # stash the CUDA state.
            raise RuntimeError(
                "PyTorch's CUDA state was initialized in the forward pass "
                "of a Checkpoint, which is not allowed. Please open an issue "
                "if you need this feature."
            )

    return output
|
||||
Reference in New Issue
Block a user