# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import List, Optional, Sequence, Tuple, Union

import torch

from .common import _get_storage_base


def _normalize_stack_dim(dim: int, ndim: int) -> Optional[int]:
    """
    Map a possibly-negative stacking dimension to a non-negative index.

    Stacking tensors of rank :code:`ndim` yields a tensor of rank
    :code:`ndim + 1`, so (as for :attr:`torch.stack`) valid values of
    :code:`dim` are in ``[-(ndim + 1), ndim]``.

    Returns:
        The equivalent index in ``[0, ndim]``, or :code:`None` when
        :code:`dim` is out of range.
    """
    if dim < 0:
        dim += ndim + 1
    if not 0 <= dim <= ndim:
        return None
    return dim


def get_stack_strides(
    tensors: Sequence[torch.Tensor], dim: int
) -> Optional[Tuple[int, ...]]:
    """
    If the tensors are already stacked on dimension :code:`dim`,
    returns the strides of the stacked tensors.
    Otherwise returns :code:`None`.

    "Already stacked" means: at least two tensors, all with the same shape,
    strides and dtype, sharing the same storage base (as reported by
    :code:`_get_storage_base`), and with storage offsets forming an
    arithmetic progression. The common offset difference becomes the stride
    of the new dimension :code:`dim`. Negative :code:`dim` values are
    interpreted as in :attr:`torch.stack`.
    """
    if len(tensors) <= 1:
        return None
    first = tensors[0]
    norm_dim = _normalize_stack_dim(dim, first.ndim)
    if norm_dim is None:
        return None

    # Distance (in elements) between consecutive tensors in the storage:
    # this is the stride of the new stacked dimension.
    step = tensors[1].storage_offset() - first.storage_offset()
    storage_base = _get_storage_base(first)
    for i, x in enumerate(tensors[1:], start=1):
        # Sanity checks: identical layout (and element type) as the first one
        if (
            x.shape != first.shape
            or x.stride() != first.stride()
            or x.dtype != first.dtype
        ):
            return None
        # ... starting exactly `i` steps after the first tensor ...
        if x.storage_offset() != first.storage_offset() + i * step:
            return None
        # ... and living in the very same storage (actual storage check).
        if _get_storage_base(x) != storage_base:
            return None

    final_stride = list(first.stride())
    final_stride.insert(norm_dim, step)
    return tuple(final_stride)


def _stack_or_none_fw(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
) -> Optional[torch.Tensor]:
    """
    Return a zero-copy view equal to :code:`torch.stack(tensors, dim=dim)`
    when the tensors are already laid out that way in a common storage,
    otherwise :code:`None`.
    """
    strides = get_stack_strides(tensors, dim)
    if strides is None:
        return None
    if any(s < 0 for s in strides):
        # The tensors are stacked in reverse storage order; `as_strided`
        # does not accept negative strides, so report "not stackable"
        # rather than raising (callers can then fall back to a copy).
        return None
    # `get_stack_strides` succeeded, so `tensors` is non-empty and `dim` is
    # within the valid range; only the negative form needs normalizing.
    if dim < 0:
        dim += tensors[0].ndim + 1
    output_shape = list(tensors[0].shape)
    output_shape.insert(dim, len(tensors))
    return tensors[0].as_strided(output_shape, strides)


def _stack_fw(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
) -> torch.Tensor:
    """
    Stack :code:`tensors` along :code:`dim`, returning a view when possible
    and falling back to a copying :attr:`torch.stack` otherwise.
    """
    out = _stack_or_none_fw(tensors, dim)
    if out is None:
        out = torch.stack(tensors, dim=dim)
    return out


class _Unbind(torch.autograd.Function):
    """
    See function `unbind`
    """

    @staticmethod
    # type: ignore
    def forward(ctx, x: torch.Tensor, dim: int):
        ctx.dim = dim
        return x.unbind(dim)

    @classmethod
    # type: ignore
    def backward(cls, ctx, *tensors: torch.Tensor):
        # Gradient w.r.t. `x` is the stack of the output gradients; `dim`
        # is a non-tensor argument and gets no gradient.
        return _stack_fw(tensors, ctx.dim), None


class _StackOrNone(torch.autograd.Function):
    """
    See function `stack_or_none`
    """

    @staticmethod
    # type: ignore
    def forward(ctx, dim: int, *tensors: torch.Tensor):
        ctx.dim = dim
        return _stack_or_none_fw(tensors, dim=dim)

    @classmethod
    # type: ignore
    def backward(cls, ctx, grad: torch.Tensor):
        # No gradient for `dim`; one slice of `grad` per input tensor.
        return (None, *grad.unbind(dim=ctx.dim))


def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]:
    """
    Does exactly the same as :attr:`torch.unbind` for the forward.
    In backward, avoids a :attr:`torch.stack` if the gradients
    are already multiple views of the same storage.
    """
    return _Unbind.apply(x, dim)


def stack_or_none(
    tensors: Sequence[torch.Tensor], dim: int
) -> Optional[torch.Tensor]:
    """
    Does exactly the same as :attr:`torch.stack` if the tensors can be
    stacked without any memory operation (i.e. they are already equally
    spaced views of a common storage). Otherwise returns None.
    In backward, the incoming gradient is unbound along :code:`dim`.
    """
    return _StackOrNone.apply(dim, *tensors)