# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
import triton
import triton.language as tl

from xformers.triton.k_activations import (
    gelu_grad,
    leaky_relu_grad,
    relu_grad,
    smelu_grad,
    squared_relu_grad,
    star_relu_grad,
)


# fmt: off
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 64}, num_stages=4, num_warps=2),
        triton.Config({"BLOCK_N": 128}, num_stages=3, num_warps=2),
        triton.Config({"BLOCK_N": 256}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_N": 512}, num_stages=3, num_warps=4),
        triton.Config({"BLOCK_N": 1024}, num_stages=3, num_warps=4),
    ],
    key=["N"],
)
@triton.heuristics({
    'EVEN_N': lambda args: args["N"] % (args['BLOCK_N']) == 0,
})
@triton.jit
def kernel_bw(
    # Pointers to matrices
    GRAD_ACT, GRAD_OUT, ACT_INPUTS,
    # Matrix dimensions
    N,
    # Row strides: how much to increase a pointer by to move one row down.
    # `stride_gom` is applied to both GRAD_OUT and GRAD_ACT,
    # `stride_aim` is applied to ACT_INPUTS.
    # Columns are addressed with a unit stride for all three buffers.
    stride_gom, stride_aim,
    # Meta-parameters
    BLOCK_N: tl.constexpr,
    EVEN_N: tl.constexpr,
    ACTIVATION_GRAD: tl.constexpr,
):
    # fmt: on

    """
    Go over all the activation inputs, compute the corresponding gradient
    and multiply it by the incoming gradient.

    For each element this writes `activation_grad(ACT_INPUTS) * GRAD_OUT`
    into GRAD_ACT. ACTIVATION_GRAD selects the derivative
    (1: relu, 2: leaky_relu, 3: gelu, 4: squared_relu, 5: smelu, 6: star_relu);
    for any other value the activation input itself is used as the factor.
    """

    # this kernel is relatively simple in terms of scheduling:
    # - per row (pid_m)
    # - each program a given chunk on the col axis,
    # since it's more effective memory and occupancy wise
    pid_m, pid_n = tl.program_id(axis=0), tl.program_id(axis=1)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # addresses of the BLOCK_N activation inputs handled by this program:
    # row offset from the row stride, column offset with a unit stride
    act_input_ptrs = ACT_INPUTS + pid_m * stride_aim + rn

    # compute the gradient which is related to this activation
    if EVEN_N:
        # N is a multiple of BLOCK_N, no lane can fall past the end of the row
        act_in = tl.load(act_input_ptrs)
    else:
        act_in = tl.load(act_input_ptrs, mask=rn < N, other=0.0)

    if ACTIVATION_GRAD == 1:
        grad_act = relu_grad(act_in)
    elif ACTIVATION_GRAD == 2:
        grad_act = leaky_relu_grad(act_in)
    elif ACTIVATION_GRAD == 3:
        grad_act = gelu_grad(act_in)
    elif ACTIVATION_GRAD == 4:
        grad_act = squared_relu_grad(act_in)
    elif ACTIVATION_GRAD == 5:
        grad_act = smelu_grad(act_in)
    elif ACTIVATION_GRAD == 6:
        grad_act = star_relu_grad(act_in)
    else:
        grad_act = act_in

    # now read the incoming gradient, the backpropagated one is the multiple of both
    grad_out_ptrs = GRAD_OUT + pid_m * stride_gom + rn
    if EVEN_N:
        grad_out = tl.load(grad_out_ptrs)
    else:
        # lanes past N are never written back, the store below is masked
        grad_out = tl.load(grad_out_ptrs, mask=rn < N)

    grad_act *= grad_out

    # write back result, GRAD_ACT is addressed with the same row stride as GRAD_OUT
    grad_act_ptrs = GRAD_ACT + pid_m * stride_gom + rn
    tl.store(grad_act_ptrs, grad_act, mask=rn < N)


def fused_matmul_backward(
    grad_out: torch.Tensor,
    inputs: torch.Tensor,
    act_in: Optional[torch.Tensor],
    weight: torch.Tensor,
    trainable_weight: bool,
    trainable_bias: bool,
    activation_grad: int = 0,
):
    """
    Compute the gradients of a fused (matmul + bias + activation) layer.

    With `grad_out_` being `grad_out` flattened to 2D and, if `activation_grad > 0`,
    multiplied element-wise by the activation derivative evaluated on `act_in`:

    - grad_in = grad_out_ @ weight, reshaped like `inputs`
    - grad_weight = grad_out_.T @ inputs (flattened to 2D), or None if not `trainable_weight`
    - grad_bias = sum of grad_out_ over its first dimension, or None if not `trainable_bias`

    Args:
        grad_out: incoming gradient, its last dimension must match `weight.shape[0]`
        inputs: the layer inputs, only used for `grad_weight` and for the shape of `grad_in`
        act_in: the activation inputs. If None, `grad_out` is used in their place
        weight: the layer weight
        trainable_weight: whether `grad_weight` should be computed
        trainable_bias: whether `grad_bias` should be computed
        activation_grad: selector forwarded to the Triton kernel as `ACTIVATION_GRAD`,
            0 (default) skips the activation gradient altogether

    Returns:
        the `(grad_in, grad_weight, grad_bias)` tuple

    .. note: Activation gradient needs to be a Triton kernel
    """

    # Make sure that we don't have to handle the stride over cols
    if not grad_out.is_contiguous():
        grad_out = grad_out.contiguous()

    grad_out_ = grad_out if grad_out.ndim == 2 else grad_out.flatten(0, -2)
    inputs_ = inputs if inputs.ndim == 2 else inputs.flatten(0, -2)

    assert grad_out_.shape[1] == weight.shape[0], "Incompatible dimensions in between grad_out and weight"

    # N is also weight.shape[0], as checked above
    M, N = grad_out_.shape

    # Compute the gradient for the activation
    if activation_grad > 0:
        grad_act = torch.empty_like(grad_out_)

        # Some activations do not require their inputs to
        # know of their grad, the downstream grad is enough
        if act_in is None:
            act_in_ = grad_out_
        else:
            # The kernel addresses rows through stride(0) and columns with a unit stride,
            # so the activation inputs get the same 2D treatment as grad_out and inputs
            act_in_ = act_in if act_in.ndim == 2 else act_in.flatten(0, -2)
            if act_in_.stride(-1) != 1:
                act_in_ = act_in_.contiguous()

        grid = lambda META: (M, triton.cdiv(N, META["BLOCK_N"]))  # noqa

        # fmt: off
        kernel_bw[grid](
            grad_act, grad_out_, act_in_,           # data ptrs
            N,                                      # shapes
            grad_act.stride(0), act_in_.stride(0),  # strides
            ACTIVATION_GRAD=activation_grad,        # optional fused activation
        )
        # fmt: on

        # Backpropagation going up, the reference gradient is now
        # just before the activation
        grad_out_ = grad_act

    # The following ops can also be handled by pytorch
    grad_in = triton.ops.matmul(grad_out_, weight)
    grad_weight = grad_out_.transpose(1, 0) @ inputs_ if trainable_weight else None
    grad_bias = torch.sum(grad_out_, dim=0) if trainable_bias else None

    return grad_in.reshape_as(inputs), grad_weight, grad_bias