# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Sequence

import torch

from .common import BaseOperator, get_xformers_operator, register_operator


@register_operator
class ScaledIndexAddFw(BaseOperator):
    OPERATOR = get_xformers_operator("scaled_index_addF")
    OPERATOR_CATEGORY = "indexing"
    NAME = "scaled_index_addF"


@register_operator
class ScaledIndexAddBw(BaseOperator):
    OPERATOR = get_xformers_operator("scaled_index_addB")
    OPERATOR_CATEGORY = "indexing"
    NAME = "scaled_index_addB"


@register_operator
class IndexSelect(BaseOperator):
    OPERATOR = get_xformers_operator("index_select")
    OPERATOR_CATEGORY = "indexing"
    NAME = "index_select"


class _ScaledIndexAdd(torch.autograd.Function):
    @staticmethod
    # type: ignore
    def forward(
        ctx,
        input: torch.Tensor,
        index: torch.Tensor,
        source: torch.Tensor,
        scaling: Optional[torch.Tensor],
        alpha: float,
    ) -> torch.Tensor:
        ScaledIndexAddFw.OPERATOR(
            output=input,  # in-place
            input=input,
            source=source,
            index=index,
            source_scaling=scaling,
            alpha=alpha,
        )
        ctx.mark_dirty(input)
        ctx.save_for_backward(index, scaling, source)
        ctx.source_shape = source.shape
        ctx.alpha = alpha
        return input

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        index, scaling, source = ctx.saved_tensors
        grad_source = torch.empty_like(grad_output[: index.shape[0]])
        grad_source_scaling = (
            torch.empty(
                ctx.source_shape,
                dtype=scaling.dtype,
                device=scaling.device,
            )
            if scaling is not None
            else None
        )
        ScaledIndexAddBw.OPERATOR(
            grad_source=grad_source,
            grad_source_scaling=grad_source_scaling,
            grad_output=grad_output,
            source=source,
            index=index,
            source_scaling=scaling,
            alpha=ctx.alpha,
        )
        if grad_source_scaling is not None:
            # Reduce the per-element gradient ([Bi, M, D]) down to the
            # shape of ``scaling`` ([D])
            grad_source_scaling = grad_source_scaling.sum((0, 1))
        return (
            grad_output,  # input
            None,  # index
            grad_source,  # source
            grad_source_scaling,  # scaling
            None,  # alpha
        )
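

# For reference, the gradients computed by ScaledIndexAddBw can be written in
# pure PyTorch (a sketch derived from the forward semantics, assuming unique
# indices and a non-None ``scaling`` of shape [D]):
#
#   grad_input   = grad_output                                        # identity
#   grad_source  = alpha * scaling * grad_output[index]               # [Bi, M, D]
#   grad_scaling = alpha * (grad_output[index] * source).sum((0, 1))  # [D]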


def scaled_index_add(
    input: torch.Tensor,  # [B, M, D]
    index: torch.Tensor,  # [Bi] - int64
    source: torch.Tensor,  # [Bi, M, D]
    scaling: Optional[torch.Tensor] = None,  # [D]
    alpha: float = 1.0,
) -> torch.Tensor:
"""
In-place scaling+index_add
Indices in ``index`` are assumed to be unique
:Note:
The FW pass is done in-place (``input`` is modified)
:Note:
This is experimental and has only been optimized for a few shapes
:Equivalent pytorch code:
.. code-block:: python
return torch.index_add(inp, dim=0, source=scaling * src, index=indices, alpha=alpha)
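
    :Example:

    A minimal usage sketch (tensor names and the CUDA/fp16 setting are
    illustrative, not requirements):

    .. code-block:: python

        # Writes into ``x`` in-place and returns it
        x = torch.randn([8, 16, 32], device="cuda", dtype=torch.float16)
        src = torch.randn([2, 16, 32], device="cuda", dtype=torch.float16)
        idx = torch.tensor([0, 5], dtype=torch.int64, device="cuda")
        gamma = torch.randn([32], device="cuda", dtype=torch.float16)
        out = scaled_index_add(x, idx, src, scaling=gamma, alpha=0.5)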
"""
    return _ScaledIndexAdd.apply(
        input,
        index,
        source,
        scaling,
        alpha,
    )


class _IndexSelectCat(torch.autograd.Function):
    @staticmethod
    # type: ignore
    def forward(
        ctx,
        *args: torch.Tensor,
    ) -> torch.Tensor:
        assert len(args) % 2 == 0
        sources = args[: len(args) // 2]
        indices = args[len(args) // 2 :]

        # First pass: count the output and source elements
        output_shape = 0
        total_source_elements = 0
        for source, index in zip(sources, indices):
            output_shape += index.shape[0] * source.shape[1]
            total_source_elements += source.shape[0] * source.shape[1]

        # Second pass: gather ``source[index]`` for each pair into the
        # corresponding flat slice of ``output``
        output = torch.empty(
            [output_shape], dtype=sources[0].dtype, device=sources[0].device
        )
        output_i = 0
        for source, index in zip(sources, indices):
            elements_here = index.shape[0] * source.shape[1]
            IndexSelect.OPERATOR(
                output=output[output_i : output_i + elements_here].view(
                    [index.shape[0], source.shape[1]]
                ),
                source=source,
                index=index,
            )
            output_i += elements_here

        ctx.save_for_backward(*indices)
        ctx.total_source_elements = total_source_elements
        ctx.source_shapes = [s.shape for s in sources]
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        indices = ctx.saved_tensors
        grad_sources = torch.zeros(
            [ctx.total_source_elements],
            dtype=grad_output.dtype,
            device=grad_output.device,
        )
        grad_sources_i = 0
        grad_output_i = 0
        gradients = []
        for source_shape, index in zip(ctx.source_shapes, indices):
            grad_output_slice = grad_output[
                grad_output_i : grad_output_i + index.shape[0] * source_shape[1]
            ].reshape([index.shape[0], source_shape[1]])
            grad_output_i += index.shape[0] * source_shape[1]

            gradient_source = grad_sources[
                grad_sources_i : grad_sources_i + source_shape[0] * source_shape[1]
            ].reshape(source_shape)
            grad_sources_i += source_shape[0] * source_shape[1]

            # Scatter the gradient slice back to the rows selected in the
            # forward pass; rows that were not selected keep a zero gradient
            ScaledIndexAddFw.OPERATOR(
                output=gradient_source.unsqueeze(1),
                input=None,
                source=grad_output_slice.unsqueeze(1),
                index=index,
                source_scaling=None,
                alpha=1.0,
            )
            gradients.append(gradient_source)
        return (*gradients, *([None] * len(gradients)))
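

# For reference, the backward above matches this pure-PyTorch sketch (a
# restatement assuming unique indices), applied per (source, index) pair:
#
#   grad_source = torch.zeros(source_shape)  # same dtype/device as grad_output
#   grad_source[index.long()] = grad_output_slice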


def index_select_cat(
    sources: Sequence[torch.Tensor], indices: Sequence[torch.Tensor]
) -> torch.Tensor:
"""
Indices in ``index`` are assumed to be unique
:Note:
This is experimental and has only been optimized for a few shapes
:Example:
Given:
- ``sources[0]`` of shape ``[S0, D0]``
- ``indices[0]`` of shape ``[I0]``
- ``sources[1]`` of shape ``[S1, D1]``
- ``indices[1]`` of shape ``[I1]``
returns a ``torch.Tensor`` of shape ``[I0 * D0 + I1 * D1]``
:Equivalent pytorch code:
.. code-block:: python
return torch.cat([s[i.long()].flatten() for s, i in zip(sources, indices)], dim=0)
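
    :Usage example:

    A minimal sketch (tensor names and the CUDA/fp16 setting are illustrative):

    .. code-block:: python

        s0 = torch.randn([10, 4], device="cuda", dtype=torch.float16)
        s1 = torch.randn([6, 8], device="cuda", dtype=torch.float16)
        i0 = torch.tensor([1, 3, 7], dtype=torch.int64, device="cuda")
        i1 = torch.tensor([0, 5], dtype=torch.int64, device="cuda")
        out = index_select_cat([s0, s1], [i0, i1])  # shape [3 * 4 + 2 * 8] = [28]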
"""
return _IndexSelectCat.apply(*sources, *indices)