First commit
This commit is contained in:
223
pkgs/xformers/ops/indexing.py
Normal file
223
pkgs/xformers/ops/indexing.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from .common import BaseOperator, get_xformers_operator, register_operator
|
||||
|
||||
|
||||
@register_operator
|
||||
class ScaledIndexAddFw(BaseOperator):
|
||||
OPERATOR = get_xformers_operator("scaled_index_addF")
|
||||
OPERATOR_CATEGORY = "indexing"
|
||||
NAME = "scaled_index_addF"
|
||||
|
||||
|
||||
@register_operator
|
||||
class ScaledIndexAddBw(BaseOperator):
|
||||
OPERATOR = get_xformers_operator("scaled_index_addB")
|
||||
OPERATOR_CATEGORY = "indexing"
|
||||
NAME = "scaled_index_addB"
|
||||
|
||||
|
||||
@register_operator
|
||||
class IndexSelect(BaseOperator):
|
||||
OPERATOR = get_xformers_operator("index_select")
|
||||
OPERATOR_CATEGORY = "indexing"
|
||||
NAME = "index_select"
|
||||
|
||||
|
||||
class _ScaledIndexAdd(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(
|
||||
ctx,
|
||||
input: torch.Tensor,
|
||||
index: torch.Tensor,
|
||||
source: torch.Tensor,
|
||||
scaling: Optional[torch.Tensor],
|
||||
alpha: float,
|
||||
) -> torch.Tensor:
|
||||
ScaledIndexAddFw.OPERATOR(
|
||||
output=input, # in-place
|
||||
input=input,
|
||||
source=source,
|
||||
index=index,
|
||||
source_scaling=scaling,
|
||||
alpha=alpha,
|
||||
)
|
||||
ctx.mark_dirty(input)
|
||||
ctx.save_for_backward(index, scaling, source)
|
||||
ctx.source_shape = source.shape
|
||||
ctx.alpha = alpha
|
||||
return input
|
||||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
index, scaling, source = ctx.saved_tensors
|
||||
grad_source = torch.empty_like(grad_output[: index.shape[0]])
|
||||
grad_source_scaling = (
|
||||
torch.empty(
|
||||
ctx.source_shape,
|
||||
dtype=scaling.dtype,
|
||||
device=scaling.device,
|
||||
)
|
||||
if scaling is not None
|
||||
else None
|
||||
)
|
||||
ScaledIndexAddBw.OPERATOR(
|
||||
grad_source=grad_source,
|
||||
grad_source_scaling=grad_source_scaling,
|
||||
grad_output=grad_output,
|
||||
source=source,
|
||||
index=index,
|
||||
source_scaling=scaling,
|
||||
alpha=ctx.alpha,
|
||||
)
|
||||
if grad_source_scaling is not None:
|
||||
grad_source_scaling = grad_source_scaling.sum((0, 1))
|
||||
return (
|
||||
grad_output, # input
|
||||
None, # index
|
||||
grad_source, # source
|
||||
grad_source_scaling, # scaling
|
||||
None, # alpha
|
||||
)
|
||||
|
||||
|
||||
def scaled_index_add(
|
||||
input: torch.Tensor, # [B, M, D]
|
||||
index: torch.Tensor, # [Bi] - int64
|
||||
source: torch.Tensor, # [Bi, M, D]
|
||||
scaling: Optional[torch.Tensor] = None, # [D]
|
||||
alpha: float = 1.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
In-place scaling+index_add
|
||||
|
||||
Indices in ``index`` are assumed to be unique
|
||||
|
||||
:Note:
|
||||
|
||||
The FW pass is done in-place (``input`` is modified)
|
||||
|
||||
:Note:
|
||||
|
||||
This is experimental and has only been optimized for a few shapes
|
||||
|
||||
:Equivalent pytorch code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
return torch.index_add(inp, dim=0, source=scaling * src, index=indices, alpha=alpha)
|
||||
"""
|
||||
return _ScaledIndexAdd.apply(
|
||||
input,
|
||||
index,
|
||||
source,
|
||||
scaling,
|
||||
alpha,
|
||||
)
|
||||
|
||||
|
||||
class _IndexSelectCat(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# type: ignore
|
||||
def forward(
|
||||
ctx,
|
||||
*args: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
assert len(args) % 2 == 0
|
||||
sources = args[: len(args) // 2]
|
||||
indices = args[len(args) // 2 :]
|
||||
output_shape = 0
|
||||
total_source_elements = 0
|
||||
for source, index in zip(sources, indices):
|
||||
output_shape += index.shape[0] * source.shape[1]
|
||||
total_source_elements += source.shape[0] * source.shape[1]
|
||||
output = torch.empty(
|
||||
[output_shape], dtype=sources[0].dtype, device=sources[0].device
|
||||
)
|
||||
output_i = 0
|
||||
for source, index in zip(sources, indices):
|
||||
elements_here = index.shape[0] * source.shape[1]
|
||||
IndexSelect.OPERATOR(
|
||||
output=output[output_i : output_i + elements_here].view(
|
||||
[index.shape[0], source.shape[1]]
|
||||
),
|
||||
source=source,
|
||||
index=index,
|
||||
)
|
||||
output_i += elements_here
|
||||
ctx.save_for_backward(*indices)
|
||||
ctx.total_source_elements = total_source_elements
|
||||
ctx.source_shapes = [s.shape for s in sources]
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
indices = ctx.saved_tensors
|
||||
grad_sources = torch.zeros(
|
||||
[ctx.total_source_elements],
|
||||
dtype=grad_output.dtype,
|
||||
device=grad_output.device,
|
||||
)
|
||||
grad_sources_i = 0
|
||||
grad_output_i = 0
|
||||
gradients = []
|
||||
for source_shape, index in zip(ctx.source_shapes, indices):
|
||||
grad_output_slice = grad_output[
|
||||
grad_output_i : grad_output_i + index.shape[0] * source_shape[1]
|
||||
].reshape([index.shape[0], source_shape[1]])
|
||||
grad_output_i += index.shape[0] * source_shape[1]
|
||||
|
||||
gradient_source = grad_sources[
|
||||
grad_sources_i : grad_sources_i + source_shape[0] * source_shape[1]
|
||||
].reshape(source_shape)
|
||||
grad_sources_i += source_shape[0] * source_shape[1]
|
||||
|
||||
ScaledIndexAddFw.OPERATOR(
|
||||
output=gradient_source.unsqueeze(1),
|
||||
input=None,
|
||||
source=grad_output_slice.unsqueeze(1),
|
||||
index=index,
|
||||
source_scaling=None,
|
||||
alpha=1.0,
|
||||
)
|
||||
gradients.append(gradient_source)
|
||||
return (*gradients, *([None] * len(gradients)))
|
||||
|
||||
|
||||
def index_select_cat(
|
||||
sources: Sequence[torch.Tensor], indices: Sequence[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Indices in ``index`` are assumed to be unique
|
||||
|
||||
:Note:
|
||||
|
||||
This is experimental and has only been optimized for a few shapes
|
||||
|
||||
:Example:
|
||||
|
||||
Given:
|
||||
- ``sources[0]`` of shape ``[S0, D0]``
|
||||
- ``indices[0]`` of shape ``[I0]``
|
||||
- ``sources[1]`` of shape ``[S1, D1]``
|
||||
- ``indices[1]`` of shape ``[I1]``
|
||||
returns a ``torch.Tensor`` of shape ``[I0 * D0 + I1 * D1]``
|
||||
|
||||
:Equivalent pytorch code:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)
|
||||
"""
|
||||
return _IndexSelectCat.apply(*sources, *indices)
|
||||
Reference in New Issue
Block a user