# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional, Sequence

import torch

from .common import BaseOperator, get_xformers_operator, register_operator


@register_operator
class ScaledIndexAddFw(BaseOperator):
    OPERATOR = get_xformers_operator("scaled_index_addF")
    OPERATOR_CATEGORY = "indexing"
    NAME = "scaled_index_addF"


@register_operator
class ScaledIndexAddBw(BaseOperator):
    OPERATOR = get_xformers_operator("scaled_index_addB")
    OPERATOR_CATEGORY = "indexing"
    NAME = "scaled_index_addB"


@register_operator
class IndexSelect(BaseOperator):
    OPERATOR = get_xformers_operator("index_select")
    OPERATOR_CATEGORY = "indexing"
    NAME = "index_select"


class _ScaledIndexAdd(torch.autograd.Function):
    @staticmethod
    # type: ignore
    def forward(
        ctx,
        input: torch.Tensor,
        index: torch.Tensor,
        source: torch.Tensor,
        scaling: Optional[torch.Tensor],
        alpha: float,
    ) -> torch.Tensor:
        ScaledIndexAddFw.OPERATOR(
            output=input,  # in-place
            input=input,
            source=source,
            index=index,
            source_scaling=scaling,
            alpha=alpha,
        )
        # ``input`` was modified in-place, so autograd has to be notified
        ctx.mark_dirty(input)
        ctx.save_for_backward(index, scaling, source)
        ctx.source_shape = source.shape
        ctx.alpha = alpha
        return input

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        index, scaling, source = ctx.saved_tensors
        grad_source = torch.empty_like(grad_output[: index.shape[0]])
        grad_source_scaling = (
            torch.empty(
                ctx.source_shape,
                dtype=scaling.dtype,
                device=scaling.device,
            )
            if scaling is not None
            else None
        )
        ScaledIndexAddBw.OPERATOR(
            grad_source=grad_source,
            grad_source_scaling=grad_source_scaling,
            grad_output=grad_output,
            source=source,
            index=index,
            source_scaling=scaling,
            alpha=ctx.alpha,
        )
        if grad_source_scaling is not None:
            # Reduce the per-element product over the batch and sequence
            # dims to obtain the [D]-shaped gradient for ``scaling``
            grad_source_scaling = grad_source_scaling.sum((0, 1))
        return (
            grad_output,  # input
            None,  # index
            grad_source,  # source
            grad_source_scaling,  # scaling
            None,  # alpha
        )
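
# For reference: the forward computes the equivalent of
# ``input.index_add_(0, index, alpha * scaling * source)``, which is the
# identity in ``input``, so (a hedged sketch, assuming the kernels follow
# this convention):
#   grad_input   = grad_output
#   grad_source  = alpha * scaling * grad_output[index]
#   grad_scaling = alpha * (grad_output[index] * source).sum(dim=(0, 1))  # [D]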


def scaled_index_add(
    input: torch.Tensor,  # [B, M, D]
    index: torch.Tensor,  # [Bi] - int64
    source: torch.Tensor,  # [Bi, M, D]
    scaling: Optional[torch.Tensor] = None,  # [D]
    alpha: float = 1.0,
) -> torch.Tensor:
    """
    In-place scaling + ``index_add``

    Indices in ``index`` are assumed to be unique

    :Note:

        The forward pass is done in-place (``input`` is modified)

    :Note:

        This is experimental and has only been optimized for a few shapes

    :Equivalent pytorch code:

    .. code-block:: python

        return torch.index_add(
            input, dim=0, source=scaling * source, index=index, alpha=alpha
        )
    """
    return _ScaledIndexAdd.apply(
        input,
        index,
        source,
        scaling,
        alpha,
    )
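

def _example_scaled_index_add() -> None:
    # Hypothetical usage sketch added for illustration (not part of the
    # original module); assumes a CUDA build of xformers that ships the
    # "scaled_index_addF" / "scaled_index_addB" kernels.
    x = torch.randn(8, 16, 32, device="cuda")
    src = torch.randn(4, 16, 32, device="cuda")
    idx = torch.tensor([0, 2, 5, 7], dtype=torch.int64, device="cuda")
    gamma = torch.randn(32, device="cuda")
    # Reference result, computed before ``x`` is modified in-place
    ref = torch.index_add(x, dim=0, source=gamma * src, index=idx, alpha=0.5)
    out = scaled_index_add(x, idx, src, scaling=gamma, alpha=0.5)  # modifies x
    assert torch.allclose(out, ref, atol=1e-5)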


class _IndexSelectCat(torch.autograd.Function):
    @staticmethod
    # type: ignore
    def forward(
        ctx,
        *args: torch.Tensor,
    ) -> torch.Tensor:
        # ``args`` packs the inputs as (*sources, *indices)
        assert len(args) % 2 == 0
        sources = args[: len(args) // 2]
        indices = args[len(args) // 2 :]
        output_shape = 0
        total_source_elements = 0
        for source, index in zip(sources, indices):
            output_shape += index.shape[0] * source.shape[1]
            total_source_elements += source.shape[0] * source.shape[1]
        output = torch.empty(
            [output_shape], dtype=sources[0].dtype, device=sources[0].device
        )
        # Gather the selected rows of each source into one flat output buffer
        output_i = 0
        for source, index in zip(sources, indices):
            elements_here = index.shape[0] * source.shape[1]
            IndexSelect.OPERATOR(
                output=output[output_i : output_i + elements_here].view(
                    [index.shape[0], source.shape[1]]
                ),
                source=source,
                index=index,
            )
            output_i += elements_here
        ctx.save_for_backward(*indices)
        ctx.total_source_elements = total_source_elements
        ctx.source_shapes = [s.shape for s in sources]
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        indices = ctx.saved_tensors
        grad_sources = torch.zeros(
            [ctx.total_source_elements],
            dtype=grad_output.dtype,
            device=grad_output.device,
        )
        grad_sources_i = 0
        grad_output_i = 0
        gradients = []
        for source_shape, index in zip(ctx.source_shapes, indices):
            grad_output_slice = grad_output[
                grad_output_i : grad_output_i + index.shape[0] * source_shape[1]
            ].reshape([index.shape[0], source_shape[1]])
            grad_output_i += index.shape[0] * source_shape[1]

            gradient_source = grad_sources[
                grad_sources_i : grad_sources_i + source_shape[0] * source_shape[1]
            ].reshape(source_shape)
            grad_sources_i += source_shape[0] * source_shape[1]

            # Scatter the gradient slice into the zero-initialized source
            # gradient: gradient_source[index] += grad_output_slice
            ScaledIndexAddFw.OPERATOR(
                output=gradient_source.unsqueeze(1),
                input=None,
                source=grad_output_slice.unsqueeze(1),
                index=index,
                source_scaling=None,
                alpha=1.0,
            )
            gradients.append(gradient_source)
        return (*gradients, *([None] * len(gradients)))


def index_select_cat(
    sources: Sequence[torch.Tensor], indices: Sequence[torch.Tensor]
) -> torch.Tensor:
    """
    Indices in each element of ``indices`` are assumed to be unique

    :Note:

        This is experimental and has only been optimized for a few shapes

    :Example:

    Given:

    - ``sources[0]`` of shape ``[S0, D0]``
    - ``indices[0]`` of shape ``[I0]``
    - ``sources[1]`` of shape ``[S1, D1]``
    - ``indices[1]`` of shape ``[I1]``

    returns a ``torch.Tensor`` of shape ``[I0 * D0 + I1 * D1]``

    :Equivalent pytorch code:

    .. code-block:: python

        return torch.cat(
            [s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0
        )
    """
    return _IndexSelectCat.apply(*sources, *indices)
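

def _example_index_select_cat() -> None:
    # Hypothetical usage sketch added for illustration (not part of the
    # original module); assumes a CUDA build of xformers that ships the
    # "index_select" kernel.
    s0 = torch.randn(10, 4, device="cuda")
    s1 = torch.randn(6, 8, device="cuda")
    i0 = torch.tensor([1, 3, 7], dtype=torch.int64, device="cuda")
    i1 = torch.tensor([0, 5], dtype=torch.int64, device="cuda")
    out = index_select_cat([s0, s1], [i0, i1])  # shape: [3 * 4 + 2 * 8] == [28]
    ref = torch.cat([s0[i0].flatten(), s1[i1].flatten()], dim=0)
    assert torch.allclose(out, ref)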