# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from xformers.triton.k_activations import (
    gelu,
    leaky_relu,
    relu,
    smelu,
    squared_relu,
    star_relu,
)

# CREDITS: Initially inspired by the Triton tutorial on matrix multiplications


# Output tile shapes offered to the autotuner, as
# (BLOCK_M, BLOCK_N, num_stages, num_warps). Order is significant only in
# that it is the order in which the configs are handed to `triton.autotune`.
_TILE_CONFIGS = (
    (64, 32, 4, 2),
    (32, 64, 4, 2),
    (128, 64, 3, 4),
    (64, 128, 3, 4),
    (128, 128, 3, 4),
    (64, 256, 3, 4),
    (256, 64, 3, 4),
    # Fails on small GPUs, intentionally left disabled:
    # (128, 256, 3, 8),
    # (256, 128, 3, 8),
)


def get_configs(block_k):
    """Return the autotuning candidates for a given reduction tile size.

    Args:
        block_k: value used for the ``BLOCK_K`` meta-parameter of every
            returned config.

    Returns:
        A list with one ``triton.Config`` per entry of ``_TILE_CONFIGS``,
        in the same order.
    """
    return [
        triton.Config(
            {"BLOCK_M": block_m, "BLOCK_N": block_n, "BLOCK_K": block_k},
            num_stages=num_stages,
            num_warps=num_warps,
        )
        for block_m, block_n, num_stages, num_warps in _TILE_CONFIGS
    ]


# fmt: off
@triton.autotune(
    configs=[c for block_k in [32, 64] for c in get_configs(block_k)],
    key=["M", "N", "K"],
)
@triton.heuristics({
    'EVEN_N': lambda args: args["N"] % (args['BLOCK_N']) == 0,
})
@triton.jit
def kernel_fma(
    # Pointers to matrices
    OUT, ACT_INPUTS, INPUT, WEIGHT, bias,
    # Matrix dimensions
    M, N, K,
    # Row strides: how much to increase the pointer by to move one row down.
    # `stride_om` is used for both OUT and ACT_INPUTS, `stride_im` for INPUT
    # and `stride_wn` for WEIGHT. The innermost dimension of every matrix is
    # addressed with an implicit stride of 1.
    stride_om, stride_im,
    stride_wn,
    # Meta-parameters
    BLOCK_M: tl.constexpr, GROUP_M: tl.constexpr,
    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    EVEN_N: tl.constexpr,
    BIAS: tl.constexpr,
    SAVE_ACT_INPUTS: tl.constexpr,
    ACTIVATION: tl.constexpr,
    is_fp16: tl.constexpr,  # autotune
):
    # fmt: on

    """
    Kernel for computing Out = activation(A x W^T + C)

    - Input has shape (M, K)
    - Weight has shape (N, K)
    - Bias has shape (N,)
    - Output has shape (M, N)
    - ActInputs (optional) has shape (M, N)

    'ActInputs' optionally saves the A x W^T + C intermediate for backward computations

    ACTIVATION selects the fused activation: 1 relu, 2 leaky_relu, 3 gelu,
    4 squared_relu, 5 smelu, 6 star_relu; any other value leaves the
    accumulator untouched.

    `is_fp16` is not read by the kernel body.

    This kernel will consolidate over K
    """

    # programs are grouped together to improve L2 hit rate
    # the logic is that we'll consolidate over K. If the programs were not grouped,
    # then multiple cols/rows in the result would end up pulling in the same row and lines
    # from the inputs. By grouping the computation we ensure some data reuse, which the hardware
    # covers via the L2 cache
    pid = tl.program_id(axis=0)

    num_pid_m = tl.cdiv(M, BLOCK_M)  # number of program ids along the M axis
    num_pid_n = tl.cdiv(N, BLOCK_N)  # number of programs ids along the N axis
    num_pid_in_group = GROUP_M * num_pid_n  # number of programs in group
    group_id = pid // num_pid_in_group  # id of the group this program is in
    first_pid_m = group_id * GROUP_M  # row-id of the first program in the group

    # if `num_pid_m` isn't divisible by `GROUP_M`, the last group is smaller.
    # Kept in its own local so that the `GROUP_M` constexpr is not shadowed.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)

    # *within groups*, programs are ordered in a column-major order
    # row-id /col-id of the program in the *launch grid*
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # now compute the block that each program will go through
    # rm (resp. rn) denotes a range of indices
    # for rows (resp. col) of the output tile
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # the memory addresses of elements can follow numpy broadcasting
    input_ptrs = INPUT + rm[:, None] * stride_im
    weight_ptrs = WEIGHT + rn[None, :] * stride_wn

    # initialize and iteratively update accumulator
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    if BIAS:
        if EVEN_N:
            # N is a multiple of BLOCK_N: no column of this tile is out of range
            bias = tl.load(bias + rn).to(tl.float32)
        else:
            bias = tl.load(bias + rn, mask=rn < N, other=0.0).to(tl.float32)
        acc += bias[None, :]

    # block level matrix multiplication.
    # We fetch a memory block from both inputs, matmul and accumulate, then repeat
    mask_rn = rn < N
    mask_rm = rm < M

    for i in range(0, K, BLOCK_K):
        rk = tl.arange(0, BLOCK_K) + i
        # out-of-range rows/cols and the K tail are loaded as zeros so that
        # they do not contribute to the dot product
        a = tl.load(input_ptrs + rk[None, :], mask=((rk[None, :] < K) & mask_rm[:, None]), other=0.0)
        w = tl.load(weight_ptrs + rk[:, None], mask=((rk[:, None] < K) & mask_rn[None, :]), other=0.0)

        acc += tl.dot(a, w)

    # NOTE(review): rm / rn are recomputed with the same expressions as above;
    # presumably a deliberate re-materialisation to shorten their live range
    # across the K loop - do not "simplify" without benchmarking.
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    # optional: save the activation inputs
    if SAVE_ACT_INPUTS:
        act_in_ptrs = ACT_INPUTS + rm[:, None] * stride_om + rn[None, :]
        tl.store(act_in_ptrs, acc, mask=mask_rm[:, None] & mask_rn[None, :])

    # optional: fused activation (applied before the tile is written back)
    if ACTIVATION == 1:
        acc = relu(acc)
    elif ACTIVATION == 2:
        acc = leaky_relu(acc)
    elif ACTIVATION == 3:
        acc = gelu(acc)
    elif ACTIVATION == 4:
        acc = squared_relu(acc)
    elif ACTIVATION == 5:
        acc = smelu(acc)
    elif ACTIVATION == 6:
        acc = star_relu(acc)

    # write back result
    out_ptrs = OUT + rm[:, None] * stride_om + rn[None, :]
    tl.store(out_ptrs, acc, mask=mask_rm[:, None] & mask_rn[None, :])


# Activation needs to be a triton kernel
def fused_matmul(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    activation: int = 0,
    save_act_inputs: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Compute e = activation(x @ weight.T + bias).
    This wrapper kicks the `kernel_fma` Triton kernel

    Args:
        x: input whose last dimension is K. A non-contiguous input is copied
            to a contiguous tensor first; inputs with more than two
            dimensions have their leading dimensions flattened into M.
        weight: contiguous tensor of shape (N, K).
        bias: optional contiguous tensor whose first dimension is N.
        activation: integer id forwarded to the kernel as ``ACTIVATION``
            (1 relu, 2 leaky_relu, 3 gelu, 4 squared_relu, 5 smelu,
            6 star_relu, anything else: no activation).
        save_act_inputs: when True, also return the pre-activation values.

    Returns:
        A tuple ``(outputs, act_inputs)``. ``outputs`` has shape
        ``(*x.shape[:-1], N)`` and the dtype/device of ``x``. ``act_inputs``
        is the (M, N) pre-activation buffer (it is *not* reshaped back to
        the leading dimensions of ``x``) when ``save_act_inputs`` is True,
        and ``None`` otherwise.

    Raises:
        AssertionError: on incompatible shapes or non-contiguous
            ``weight`` / ``bias``.
    """

    # The kernel addresses the innermost dimension with a stride of 1
    if not x.is_contiguous():
        x = x.contiguous()

    x_ = x if x.ndim == 2 else x.flatten(0, -2)

    assert (
        x_.shape[1] == weight.shape[1]
    ), f"Incompatible dimensions in between inputs and weight, {x_.shape} - {weight.shape}"
    assert bias is None or bias.is_contiguous()
    assert (
        bias is None or bias.shape[0] == weight.shape[0]
    ), "Incompatible dimensions in between weight and bias"
    assert weight.is_contiguous()

    M, _ = x_.shape
    N, K = weight.shape

    outputs = torch.empty((M, N), device=x.device, dtype=x.dtype)

    # When the activation inputs are not requested, `x` is only passed as a
    # placeholder pointer: the kernel does not touch ACT_INPUTS in that case.
    act_inputs = torch.empty_like(outputs) if save_act_inputs else x

    # 1D launch kernel where each block gets its own program.
    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),)

    # fmt: off
    kernel_fma[grid](
        outputs, act_inputs, x_, weight,            # data ptrs
        bias if bias is not None else x,            # auto skip bias if not present
        M, N, K,                                    # shapes
        outputs.stride(0), x_.stride(0),            # strides
        weight.stride(0),
        ACTIVATION=activation,                      # optional fused activation
        BIAS=bias is not None,                      # optional fused bias
        GROUP_M=8,                                  # speed optimization: group the programs
        SAVE_ACT_INPUTS=save_act_inputs,
        is_fp16=x_.dtype == torch.float16
    )
    # fmt: on

    # Restore the leading dimensions of the input on the result
    outputs = outputs if x.ndim == 2 else outputs.reshape(*x.shape[:-1], N)

    return outputs, act_inputs if save_act_inputs else None