init
This commit is contained in:
144
torch_vacc/vacc/amp/autocast_mode.py
Normal file
144
torch_vacc/vacc/amp/autocast_mode.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import collections
|
||||
import functools
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
|
||||
HAS_NUMPY = True
|
||||
except ModuleNotFoundError:
|
||||
np = None # type: ignore[assignment]
|
||||
|
||||
__all__ = ["autocast", "custom_fwd", "custom_bwd"]
|
||||
|
||||
|
||||
class autocast(torch.amp.autocast_mode.autocast):
|
||||
r"""See :class:`torch.autocast`.
|
||||
|
||||
``torch.vacc.amp.autocast(args...)`` is equivalent to ``torch.autocast("vacc", args...)``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled: bool = True,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
cache_enabled: bool = True,
|
||||
):
|
||||
if torch._jit_internal.is_scripting():
|
||||
self._enabled = enabled
|
||||
self.device = "vacc"
|
||||
self.fast_dtype = dtype
|
||||
return
|
||||
super().__init__(
|
||||
"vacc", enabled=enabled, dtype=dtype, cache_enabled=cache_enabled
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
if torch._jit_internal.is_scripting():
|
||||
return self
|
||||
return super().__enter__()
|
||||
|
||||
# TODO: discuss a unified TorchScript-friendly API for autocast
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override]
|
||||
if torch._jit_internal.is_scripting():
|
||||
return
|
||||
return super().__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def __call__(self, func):
|
||||
if torch._jit_internal.is_scripting():
|
||||
return func
|
||||
return super().__call__(func)
|
||||
|
||||
|
||||
# Casts Tensors and containers of Tensors. Special-cases passthroughs for strings and np.ndarrays, which
|
||||
# may be falsely detected as "Iterables."
|
||||
def _cast(value, dtype):
|
||||
if isinstance(value, torch.Tensor):
|
||||
is_eligible = (
|
||||
value.is_floating_point()
|
||||
and value.is_vacc
|
||||
and (value.dtype is not torch.float64)
|
||||
)
|
||||
return value.to(dtype) if is_eligible else value
|
||||
elif isinstance(value, (str, bytes)):
|
||||
return value
|
||||
elif HAS_NUMPY and isinstance(value, np.ndarray):
|
||||
return value
|
||||
elif isinstance(value, collections.abc.Mapping):
|
||||
return {_cast(k, dtype): _cast(v, dtype) for k, v in value.items()}
|
||||
elif isinstance(value, collections.abc.Iterable):
|
||||
iterable = (_cast(v, dtype) for v in value)
|
||||
if isinstance(value, (list, tuple)):
|
||||
return type(value)(iterable)
|
||||
else:
|
||||
return iterable
|
||||
else:
|
||||
return value
|
||||
|
||||
|
||||
# custom_fwd is a decorator that may or may not be used with arguments, following
|
||||
# https://github.com/dabeaz/python-cookbook/tree/master/src/9/defining_a_decorator_that_takes_an_optional_argument.
|
||||
# this works:
|
||||
# @custom_fwd
|
||||
# def forward(...):
|
||||
# this also works:
|
||||
# @custom_fwd(cast_inputs=torch.float)
|
||||
# def forward(...):
|
||||
def custom_fwd(fwd=None, *, cast_inputs=None):
|
||||
"""
|
||||
Create a helper decorator for ``forward`` methods of custom autograd functions.
|
||||
|
||||
Autograd functions are subclasses of :class:`torch.autograd.Function`.
|
||||
See the :ref:`example page<amp-custom-examples>` for more detail.
|
||||
|
||||
Args:
|
||||
cast_inputs (:class:`torch.dtype` or None, optional, default=None): If not ``None``,
|
||||
when ``forward`` runs in an autocast-enabled region, casts incoming
|
||||
floating-point VACC Tensors to the target dtype (non-floating-point Tensors are not affected),
|
||||
then executes ``forward`` with autocast disabled.
|
||||
If ``None``, ``forward``'s internal ops execute with the current autocast state.
|
||||
|
||||
.. note::
|
||||
If the decorated ``forward`` is called outside an autocast-enabled region,
|
||||
:func:`custom_fwd<custom_fwd>` is a no-op and ``cast_inputs`` has no effect.
|
||||
"""
|
||||
if fwd is None:
|
||||
return functools.partial(custom_fwd, cast_inputs=cast_inputs)
|
||||
|
||||
@functools.wraps(fwd)
|
||||
def decorate_fwd(*args, **kwargs):
|
||||
args[0]._dtype = torch.get_autocast_gpu_dtype()
|
||||
if cast_inputs is None:
|
||||
args[0]._fwd_used_autocast = torch.is_autocast_enabled()
|
||||
return fwd(*args, **kwargs)
|
||||
else:
|
||||
autocast_context = torch.is_autocast_enabled()
|
||||
args[0]._fwd_used_autocast = False
|
||||
if autocast_context:
|
||||
with autocast(enabled=False):
|
||||
return fwd(*_cast(args, cast_inputs), **_cast(kwargs, cast_inputs))
|
||||
else:
|
||||
return fwd(*args, **kwargs)
|
||||
|
||||
return decorate_fwd
|
||||
|
||||
|
||||
# Autograd ensures incoming gradients are the same type as forward outputs. Allowing a separate
|
||||
# cast_inputs argument on custom_bwd is unnecessary and could cause errors if it doesn't match
|
||||
# cast_inputs supplied to custom_fwd.
|
||||
def custom_bwd(bwd):
|
||||
"""Create a helper decorator for backward methods of custom autograd functions.
|
||||
|
||||
Autograd functions are subclasses of :class:`torch.autograd.Function`.
|
||||
Ensures that ``backward`` executes with the same autocast state as ``forward``.
|
||||
See the :ref:`example page<amp-custom-examples>` for more detail.
|
||||
"""
|
||||
|
||||
@functools.wraps(bwd)
|
||||
def decorate_bwd(*args, **kwargs):
|
||||
with autocast(enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype):
|
||||
return bwd(*args, **kwargs)
|
||||
|
||||
return decorate_bwd
|
||||
Reference in New Issue
Block a user