739 lines
26 KiB
Python
739 lines
26 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple
|
|
|
|
import torch
|
|
|
|
from ..common import _has_triton21, register_operator
|
|
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
|
|
from .common import AttentionFwOpBase, Context, Inputs, check_lastdim_alignment_stride1
|
|
|
|
|
|
def _strides(x: torch.Tensor, *stride_names: str):
|
|
assert x.ndim == len(stride_names)
|
|
return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)}
|
|
|
|
|
|
if TYPE_CHECKING or _has_triton21():
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
from xformers.triton.vararg_kernel import VAR_ARGS_ARRAY, unroll_varargs
|
|
|
|
@triton.jit
def _fwd_kernel_splitK(
    Q,  # queries, addressed as [z, m, g, h, k] through stride_q*
    K,  # keys, addressed as [z, n, g, h, k] through stride_k*; int32 if quantized
    V,  # values, addressed as [z, n, g, h, k] through stride_v*; int32 if quantized
    sm_scale,
    Out_splitK,  # [B * G * H, split_k, M_ceil, K] float32 partial (unnormalized) outputs
    Metadata,  # [B * G * H, 2, split_k, M_ceil] float32: index 0 of dim 1 holds m_i, index 1 holds l_i
    Seq_len,  # per-batch key lengths; only read when USE_SEQ_LEN
    stride_qz,
    stride_qm,
    stride_qg,
    stride_qh,
    stride_qk,
    stride_kz,
    stride_kn,
    stride_kg,
    stride_kh,
    stride_kk,
    stride_vz,
    stride_vn,
    stride_vg,
    stride_vh,
    stride_vk,
    stride_osk_zhg,
    stride_osk_s,
    stride_osk_m,
    stride_osk_k,  # not read: the output is assumed contiguous along K
    stride_mzhg,
    stride_m2,
    stride_ms,
    stride_mm,  # not read: Metadata is assumed contiguous along M
    Z,  # batch size (not read by the kernel body)
    N_CTX_Q,  # number of queries
    N_CTX_K,  # number of keys, used when USE_SEQ_LEN is False
    BLOCK_N_PER_SPLIT,  # number of keys assigned to each split
    H: tl.constexpr,
    G: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BOUNDS_CHECKS_N: tl.constexpr,
    USE_SEQ_LEN: tl.constexpr,
    PACKED_PER_VAL: tl.constexpr = 1,
    N_GROUPS: tl.constexpr = 1,
):
    """Compute one split-K partial result of (optionally int4-quantized) attention.

    Grid: program_id(0) selects a block of BLOCK_M queries, program_id(1) a
    flattened (batch, head, group) index ``off_zhg`` and program_id(2) the key
    split. Each program accumulates the *unnormalized* softmax(Q K^T) V over
    the keys [lo, hi) of its split into Out_splitK, and stores the running row
    maximum ``m_i`` and row sum ``l_i`` (both in log2 units, since scores are
    scaled by log2(e) and exponentiated with exp2) into Metadata.
    ``_splitK_reduce`` merges the splits afterwards.

    When BOUNDS_CHECKS_N is False, K/V loads are not bounds-checked and the
    scores are not masked, so every split's key range [lo, hi) must then be a
    whole number of BLOCK_N blocks.

    This kernel can accept non-quantized or int4-quantized keys/values.
    PACKED_PER_VAL determines the quantization type:
        - PACKED_PER_VAL == 1 means no quantization
        - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32)
    For the quantized case K/V should be int32 tensors.
    Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8.
    Quantization coefficients are stored at the beginning of the row along the last dimension of K/V
    So K[B, H, M, :] has a form
    [   quant_coef0, quant_coef1, ...|
        group0_quant_value0, group0_quant_value1,... |
        group1_quant_value0, group1_quant_value1,...]
    where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset.

    Note: this kernel needs to be processed by xformers.triton.vararg_kernel.unroll_varargs
    before compilation. That will unroll variables marked with "VAR_ARGS_ARRAY" into lists.
    See how FwOp.apply does it below.
    """
    tl.static_assert(
        (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32))
        or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)),
        f"Only 4-bit quantization is supported, K/V should have dtype int32 in "
        f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}",
    )
    tl.static_assert(
        (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8),
        "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.",
    )

    QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1
    # Number of (packed) K/V columns holding one group's values.
    PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS
    # Number of head-dim elements per group after unpacking.
    D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS

    start_m = tl.program_id(0)
    # off_zhg enumerates (batch, head, group) as (off_z * H + off_h) * G + off_g.
    off_zhg = tl.program_id(1)
    off_z = off_zhg // (H * G)
    off_h = (off_zhg // G) % H
    off_g = off_zhg % G
    splitk_idx = tl.program_id(2)

    # This program covers keys [lo, hi): its own BLOCK_N_PER_SPLIT-wide slice,
    # clipped to the valid key length. The key loop below is empty if lo >= hi.
    lo = splitk_idx * BLOCK_N_PER_SPLIT
    if USE_SEQ_LEN:
        kv_len = tl.load(Seq_len + off_z)
    else:
        kv_len = N_CTX_K
    hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)

    # The declared width is a single group; later groups are reached with
    # tl.advance, and only the query (row) dimension is bounds-checked.
    Q_block_ptr = tl.make_block_ptr(
        base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg,
        shape=(N_CTX_Q, D_PER_GROUP),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, D_PER_GROUP),
        order=(1, 0),
    )

    k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg
    # In the quantized case, skip the N_GROUPS leading int32 columns of each
    # row, which hold the packed quantization coefficients (one per group).
    # K is read transposed, as (head dim, keys) blocks, so that tl.dot(q, k)
    # yields (BLOCK_M, BLOCK_N) scores; the key extent is `hi`, so
    # boundary_check=(1,) guards keys >= hi.
    K_block_ptr = tl.make_block_ptr(
        base=k_base + stride_kk * QUANTIZED * N_GROUPS,
        shape=(PACKED_D_PER_GROUP, hi),
        strides=(stride_kk, stride_kn),
        offsets=(0, lo),
        block_shape=(PACKED_D_PER_GROUP, BLOCK_N),
        order=(0, 1),
    )
    v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg
    V_block_ptr = tl.make_block_ptr(
        base=v_base + stride_vk * QUANTIZED * N_GROUPS,
        shape=(hi, PACKED_D_PER_GROUP),
        strides=(stride_vn, stride_vk),
        offsets=(lo, 0),
        block_shape=(BLOCK_N, PACKED_D_PER_GROUP),
        order=(1, 0),
    )

    if QUANTIZED:
        # Pointers to quantization coefficients. Even those they are 1D,
        # we have to use block pointers, since usual pointers
        # don't support boundary checks
        K_scale_shift_block_ptr = tl.make_block_ptr(
            base=k_base,
            shape=(1, hi),
            strides=(stride_kk, stride_kn),
            offsets=(0, lo),
            block_shape=(1, BLOCK_N),
            order=(0, 1),
        )
        V_scale_shift_block_ptr = tl.make_block_ptr(
            base=v_base,
            shape=(hi, 1),
            strides=(stride_vn, stride_vk),
            offsets=(lo, 0),
            block_shape=(BLOCK_N, 1),
            order=(1, 0),
        )
    else:
        K_scale_shift_block_ptr = None
        V_scale_shift_block_ptr = None

    # initialize pointer to m and l
    # m_i: running row max of the scaled scores; l_i: running row sum of exp2(qk - m_i).
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)

    # Before compilation, this kernel will be processed by xformers.triton.vararg_kernel.unroll_varargs.
    # That turns tensors annotated as the one below into lists of tensors of length N_GROUPS.
    # This is a solution for Triton native lack of support for lists of tensors.
    acc: "VAR_ARGS_ARRAY"  # noqa: F821

    for i in range(len(acc)):  # noqa: F821
        acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32)  # noqa: F821
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    # q[i] holds query columns [i * D_PER_GROUP, (i + 1) * D_PER_GROUP).
    q: "VAR_ARGS_ARRAY"  # noqa: F821
    for i in range(len(acc)):  # noqa: F821
        q[i] = tl.load(  # noqa: F821
            tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,)
        )
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        k: "VAR_ARGS_ARRAY"  # noqa: F821
        v: "VAR_ARGS_ARRAY"  # noqa: F821
        for i in range(len(acc)):  # noqa: F821
            k[i], v[i] = load_dequantize_k_v_group(  # noqa: F821
                K_block_ptr,
                V_block_ptr,
                K_scale_shift_block_ptr,
                V_scale_shift_block_ptr,
                BOUNDS_CHECKS_N,
                PACKED_PER_VAL,
                PACKED_D_PER_GROUP,
                Q.dtype.element_ty,
                i,
            )

        # -- compute qk ---
        # Scores are the sum of per-group partial dot products.
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for i in range(len(acc)):  # noqa: F821
            qk += tl.dot(q[i], k[i])  # noqa: F821
        qk *= qk_scale

        # TODO: This is slow, and only needed at the last iteration.
        # Maybe we can unroll the last iteration instead?
        # Mask keys at or beyond `hi` in a partial block with -inf.
        if BOUNDS_CHECKS_N:
            qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
        # -- compute scaling constant ---
        # Online softmax: alpha rescales the previous accumulators to the new max.
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])

        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        p = p.to(Q.dtype.element_ty)

        # -- scale and update acc --
        for i in range(len(acc)):  # noqa: F821
            acc[i] *= alpha[:, None]  # noqa: F821
            acc[i] += tl.dot(p, v[i])  # noqa: F821
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        if PACKED_PER_VAL > 1:
            K_scale_shift_block_ptr = tl.advance(
                K_scale_shift_block_ptr, (0, BLOCK_N)
            )
            V_scale_shift_block_ptr = tl.advance(
                V_scale_shift_block_ptr, (BLOCK_N, 0)
            )

    # write back O
    # The last-dim stride is hard-coded to 1, so Out_splitK must be contiguous along K.
    O_block_ptr = tl.make_block_ptr(
        base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
        shape=(N_CTX_Q, D_PER_GROUP),
        strides=(stride_osk_m, 1),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, D_PER_GROUP),
        order=(1, 0),
    )
    for i in range(len(acc)):  # noqa: F821
        tl.store(
            tl.advance(O_block_ptr, (0, i * D_PER_GROUP)),
            acc[i],  # noqa: F821
            boundary_check=(0,),
        )
    # Write metadata for split-K reduction
    # Unmasked store of all BLOCK_M rows with an implicit unit stride along M,
    # so the last Metadata dimension must be padded to a multiple of BLOCK_M.
    Metadata_ptr = (
        Metadata
        + off_zhg * stride_mzhg
        + splitk_idx * stride_ms
        + start_m * BLOCK_M
        + tl.arange(0, BLOCK_M)
    )
    tl.store(Metadata_ptr, m_i)
    tl.store(Metadata_ptr + stride_m2, l_i)
|
|
|
|
@triton.jit
def load_dequantize_k_v_group(
    K_block_ptr,
    V_block_ptr,
    K_scale_shift_block_ptr,
    V_scale_shift_block_ptr,
    BOUNDS_CHECKS_N: tl.constexpr,
    PACKED_PER_VAL: tl.constexpr,
    PACKED_D_PER_GROUP: tl.constexpr,
    dtype: tl.constexpr,
    group_id: tl.constexpr,
):
    """Load K/V for a given block. In case of int4-quantized K/V, dequantize them after loading.
    If quantization is group-wise, use group_id to advance the pointers to the current group.

    Args:
        K_block_ptr / V_block_ptr: block pointers to group 0 of the current key
            block. K is laid out (head dim, keys) and V (keys, head dim).
        K_scale_shift_block_ptr / V_scale_shift_block_ptr: block pointers to the
            packed (scale, shift) coefficient of group 0; only used when
            PACKED_PER_VAL > 1.
        BOUNDS_CHECKS_N: if set, bounds-check the key dimension of every load.
        PACKED_PER_VAL: values packed per element (1 = not quantized, 8 = int4).
        PACKED_D_PER_GROUP: number of (packed) columns per group.
        dtype: target dtype of the dequantized values.
        group_id: index of the quantization group to load.

    Returns:
        (k, v). In the quantized case both are dequantized and cast to ``dtype``,
        and k is transposed back to the (head dim, keys) layout of K_block_ptr.
        Otherwise they are returned as loaded.
    """
    # Advance to the current quantization group
    K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0))
    V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id))

    # -- load k, v --
    # Only the key axis (dim 1 for K, dim 0 for V) is bounds-checked.
    k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ())
    v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ())

    if PACKED_PER_VAL > 1:
        # K/V are quantized, load quantization coefficients and dequantize

        # Coefficient of group `group_id` is the group_id-th int32 of the row.
        K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0))
        V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id))

        k_scale_shift = tl.load(
            K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()
        )
        v_scale_shift = tl.load(
            V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()
        )

        k_scale, k_shift = cast_uint32_to_half2(k_scale_shift)
        v_scale, v_shift = cast_uint32_to_half2(v_scale_shift)
        v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype)
        # dequantize expects keys along dim 0, so transpose K (and its
        # coefficients) in, and transpose the result back out.
        k_t = dequantize(
            tl.trans(k),
            tl.trans(k_scale),
            tl.trans(k_shift),
            PACKED_PER_VAL,
        ).to(dtype)
        k = tl.trans(k_t)
    return k, v
|
|
|
|
@triton.jit
def cast_uint32_to_half2(scale_shift):
    """Extract two float16 packed into one int32.

    Works element-wise: the low 16 bits of each element are reinterpreted as
    the float16 scale and the high 16 bits as the float16 shift.

    Returns:
        (scale, shift), float16 tensors with the shape of ``scale_shift``.
    """
    scale = scale_shift & 0xFFFF
    # ">>" may sign-extend a signed int32, but the truncating cast to uint16
    # below keeps only the low 16 bits, i.e. the original high half.
    shift = scale_shift >> 16
    scale = scale.to(tl.uint16).to(tl.float16, bitcast=True)
    shift = shift.to(tl.uint16).to(tl.float16, bitcast=True)
    return scale, shift
|
|
|
|
@triton.jit
def dequantize(
    x_,
    scale,
    shift,
    PACKED_PER_VAL: tl.constexpr = 8,
):
    """PACKED_PER_VAL is the number of values packed into each element x_.
    For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8.

    Unpacks every 4-bit field of ``x_`` (as an unsigned value q in [0, 15]) and
    returns ``q * scale + shift`` with shape
    (x_.shape[0], x_.shape[1] * PACKED_PER_VAL). ``scale`` and ``shift`` are
    broadcast against that result, one value per row of ``x_``.
    """

    # Axis along which offsets are applied matters here
    # It would be natural to have offsets in shape (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL)
    # and expand K/V to that shape before applying offsets
    # However, Triton for some reason considers dim=1 as contiguous when doing tl.view below, and not dim=2
    # Note that tl.view doesn't guarantee the order of elements in the result - thus the code below depends
    # on the implementation details which might change in the future.
    # Ideally we would like to use tl.reshape, but it's not implemented yet.
    # See https://github.com/openai/triton/blob/9055af1a5dadc576804b38dd77ee91dc42af0bf7/python/triton/language/semantic.py#L541 # noqa: E501

    # x_ : (BLOCK_N, D // PACKED_PER_VAL)
    # scale: (BLOCK_N, 1)
    # offsets: (PACKED_PER_VAL,)
    BLOCK_N: tl.constexpr = x_.shape[0]
    BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
    # Bit offset of each 4-bit field inside one packed element.
    offsets = tl.arange(0, PACKED_PER_VAL) * 4
    quant_offset = (
        x_[:, None, :] >> offsets[None, :, None]
    )  # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL)

    quant_offset = tl.view(
        quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)
    )
    # Trick - instead of converting int4 to float16 we view it as float16
    # and then multiply by 32768 * 512 == 2**24
    # The 16-bit pattern q (0..15) read as float16 is the subnormal q * 2**-24.
    # Multiplying by 32768 (2**15) gives q * 2**-9, and folding the remaining
    # 512 (2**9) into the scale yields q * scale without an int->float convert.
    quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
    quant_offset = (quant_offset * 32768.0).to(tl.float16)
    scale_512 = scale * 512

    dequant = quant_offset * scale_512 + shift
    return dequant
|
|
|
|
@triton.jit
def _splitK_reduce(
    Out_splitK,  # [B * G * H, split_k, M_ceil, K] partial outputs from _fwd_kernel_splitK
    Metadata,  # [B * G * H, 2, split_k, M_ceil]: m_i then l_i per split
    Out,  # addressed as [z, m, g, h, k] through stride_o*
    LSE,  # [B * G * H, M]
    split_k,
    stride_osk_zhg,
    stride_osk_s,
    stride_osk_m,
    stride_osk_k,  # not read: Out_splitK assumed contiguous along K
    stride_mzhg,
    stride_m2,
    stride_ms,
    stride_mm,  # not read: Metadata assumed contiguous along M
    stride_oz,
    stride_oh,
    stride_og,
    stride_om,
    stride_ok,  # not read: Out assumed contiguous along K
    stride_lse_zhg,
    stride_lse_m,  # not read: LSE assumed contiguous along M
    BLOCK_SIZE: tl.constexpr,
    H: tl.constexpr,
    G: tl.constexpr,
):
    """Merge the split-K partial results of one query row.

    Grid: program_id(0) is the flattened (batch, head, group) index and
    program_id(1) the query index. The per-split (acc, m, l) triples are
    combined with log-sum-exp rescaling in base 2, then the normalized output
    ``acc / l_sum`` is written to Out. The natural-log LSE
    ``(m + log2(l_sum)) / log2(e)`` is written to LSE.
    BLOCK_SIZE is the head dimension and is used directly as the tl.arange width.
    """
    off_zhg = tl.program_id(0)
    off_z = off_zhg // (H * G)
    off_h = (off_zhg // G) % H
    off_g = off_zhg % G
    off_m = tl.program_id(1)

    Out_splitK_ptr = (
        Out_splitK
        + stride_osk_zhg * off_zhg
        + stride_osk_m * off_m
        + tl.arange(0, BLOCK_SIZE)
    )
    Metadata_ptr = Metadata + stride_mzhg * off_zhg + off_m
    # Start from split 0, then fold in splits 1..split_k-1.
    m = tl.load(Metadata_ptr)
    l_sum = tl.load(Metadata_ptr + stride_m2)
    acc = tl.load(Out_splitK_ptr)

    for split_k_idx in range(1, split_k):
        Metadata_ptr = Metadata_ptr + stride_ms
        Out_splitK_ptr = Out_splitK_ptr + stride_osk_s

        m_k = tl.load(Metadata_ptr)
        l_k = tl.load(Metadata_ptr + stride_m2)
        acc_k = tl.load(Out_splitK_ptr)

        # Rescale whichever side has the smaller max to the common max m_new.
        m_new = tl.maximum(m, m_k)
        if m_k < m:
            # Scale incoming values
            alpha = tl.math.exp2(m_k - m_new)
            acc_k = acc_k * alpha
            l_k = l_k * alpha
        else:
            # Scale our values
            alpha = tl.math.exp2(m - m_new)
            acc = acc * alpha
            l_sum = l_sum * alpha

        m = m_new
        l_sum = l_sum + l_k
        acc = acc + acc_k

    # NOTE(review): if no split saw any key (m == -inf, l_sum == 0) this
    # produces NaN; callers are expected to supply at least one key per row.
    acc = acc / l_sum
    Out_ptr = (
        Out
        + stride_oz * off_z
        + stride_oh * off_h
        + stride_og * off_g
        + stride_om * off_m
        + tl.arange(0, BLOCK_SIZE)
    )
    tl.store(Out_ptr, acc)

    l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m
    # m and log2(l_sum) are in log2 units; dividing by log2(e) converts to ln.
    tl.store(l_ptrs, (m + tl.math.log2(l_sum)) / 1.44269504)
|
|
|
|
else:
|
|
_fwd_kernel_splitK = None
|
|
_splitK_reduce = None
|
|
|
|
|
|
@register_operator
class FwOp(AttentionFwOpBase):
    """Flash-Attention with Split-K. Supports fused int-4 K/V quantization.

    Quantized path will be taken if input K/V have type int32.

    Quantization can be row-wise or group-wise (when cls.NUM_GROUPS > 1) along
    the last dimension of K and V. Currently 1, 2, 4, or 8 groups per row are supported.
    Quantization coefficients (scale and shift) are represented as two
    float16 constants per group, packed into int32. Quantization coefficients of
    all groups are placed at the beginning of the row. So, if unquantized K/V have head
    dimension D, the quantized versions have head dimension D // 8 + NUM_GROUPS
    and dtype int32.
    Pseudocode for dequantizing one row can look like:
    packed_group_size = D // 8 // NUM_GROUPS  # int32 columns per group
    for i in range(NUM_GROUPS):
        # Coefficients of group i are the i-th int32 of the row
        # (low 16 bits: float16 scale, high 16 bits: float16 shift).
        scale, shift = unpack_int32_into_float16x2(K[..., i])
        group_start = NUM_GROUPS + i * packed_group_size
        group_packed = K[..., group_start: group_start + packed_group_size]
        # Each int32 holds 8 unsigned 4-bit values.
        group_dequant = unpack_int4(group_packed) * scale + shift
        ...

    The key axis is split into ``SPLIT_K`` chunks when that is set, otherwise
    into the number of chunks returned by ``get_split_k``.
    """

    OPERATOR = _fwd_kernel_splitK
    SUPPORTED_DEVICES = {"cuda"}
    CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
    SUPPORTED_DTYPES = {
        torch.half,
        torch.bfloat16,
    }  # Those are dtypes of Q. In the quantized case K/V has dtype int32
    SUPPORTED_MAX_K = 128
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
        type(None),
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
    }
    SUPPORTS_DROPOUT = False
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_BMGHK = True
    NAME = "triton_splitKF"

    # Number of key splits; None selects the get_split_k heuristic.
    SPLIT_K: Optional[int] = None
    # Tile sizes along the query (M) and key (N) axes.
    BLOCK_M = 16
    BLOCK_N = 64

    NUM_GROUPS = 1  # Default quantization is row-wise

    @classmethod
    def shape_not_supported_reasons(
        cls, Mq: int, Mkv: int, K: int, Kv: int
    ) -> List[str]:
        """Extend the base shape checks: the head dim K must be 16, 32, 64 or 128."""
        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
        if K not in {16, 32, 64, 128}:
            reasons.append(f"Embed dim {K} not supported")
        return reasons

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Collect the reasons why this operator cannot run on inputs ``d``."""
        reasons = super().not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        if d.key.dtype != torch.int32:
            # Alignment is only required for unquantized K/V; packed int32
            # rows have D // 8 + NUM_GROUPS columns.
            check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
            check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
        if cls.OPERATOR is None:
            reasons.append("triton is not available")
        if d.device.type == "cuda":
            # Has only been tested on 8.0 / 9.0.
            if torch.cuda.get_device_capability(d.device) < (8, 0):
                reasons.append(
                    "requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
                )

        q_len = d.query.shape[1]
        if isinstance(d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
            seqinfo = d.attn_bias.q_seqinfo
            if q_len != seqinfo.seqstart_py[-1]:
                reasons.append(
                    f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}"
                )
            q_len = seqinfo.min_seqlen
            if q_len != seqinfo.max_seqlen:
                reasons.append(
                    "Variable query len is not supported in the presence of causal mask."
                )

        # A zero stride on the head dim means K/V are broadcast across heads
        # (multiquery); apply() then swaps heads and queries, which needs q_len == 1.
        if d.key.ndim in [4, 5] and d.key.shape[-2] != 1:
            if d.key.stride(-2) == 0 and d.value.stride(-2) == 0 and q_len > 1:
                reasons.append("multiquery is only supported with query seqlen=1")

        if d.attn_bias is not None and q_len > 1:
            reasons.append(
                "query with seqlen > 1 is not supported in the presence of causal mask"
            )
        return reasons

    @classmethod
    def get_split_k(cls, B: int, H: int, Mk: int) -> int:
        """Heuristic for the number of splits.

        Starts from max(Mk, 1024) // (B * H) and halves it until each split
        covers at least 64 keys (128 when Mk > 512 or B * H > 64). The result
        is clamped to [1, 64].
        """
        bh = B * H
        split_k = max(Mk, 1024) // bh
        max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
        while split_k > 0 and Mk / split_k < max_chunk_size:
            split_k = split_k // 2
        split_k = min(split_k, 64)
        split_k = max(split_k, 1)
        return split_k

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run the split-K attention forward pass on ``inp``.

        ``_fwd_kernel_splitK`` computes ``split_k`` partial results along the
        key axis, and ``_splitK_reduce`` then merges them into the output and
        the per-query log-sum-exp.

        Returns:
            ``(out, Context(out=out, lse=lse))``. ``out`` is BMHK when the query
            is 4-D and BMGHK otherwise. ``lse`` is [B, H, M] or [B, G, H, M]
            respectively.
        """
        attn_bias = inp.attn_bias
        seq_len = None
        q, k, v = inp.get_qkv_in_bmghk()

        if attn_bias is not None:
            assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
            # TODO: do we really need to do this cast? seems fishy but
            # I just copied it from the decoder.py
            attn_bias.k_seqinfo.to(inp.query.device)
            attn_bias.q_seqinfo.to(inp.query.device)
            seq_len = attn_bias.k_seqinfo.seqlen
            B = len(seq_len)
            G, H, Kq = q.shape[-3:]
            Kkv = v.shape[-1]

            # assume kv has been padded
            q = q.reshape(B, -1, G, H, Kq)
            k = k.reshape(B, -1, G, H, Kkv)
            v = v.reshape(B, -1, G, H, Kkv)

        # Transpose in the case of MQA/GQA
        # K/V broadcast over heads (zero stride): treat the heads of the single
        # query as query rows of a single head, so K/V are read only once.
        mqa_swap_seqlen_head = False
        if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0:
            mqa_swap_seqlen_head = True
            assert q.shape[1] == 1
            q = q.transpose(1, 3)
            k = k[:, :, :, :1]
            v = v[:, :, :, :1]

        if k.dtype == torch.int32:
            # Quantized K/V: NUM_GROUPS coefficient columns, then 8 int4 values per int32.
            PACKED_PER_VAL = 8
            Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8
        else:
            Lk = k.shape[-1]
            PACKED_PER_VAL = 1

        B, Mk, G, H, _ = k.shape
        B, M, G, H, Kq = q.shape
        assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}"

        BLOCK_M = cls.BLOCK_M
        BLOCK_N = cls.BLOCK_N
        if cls.SPLIT_K is not None:
            split_k = cls.SPLIT_K
        else:
            # Use heuristics
            split_k = cls.get_split_k(B, H, Mk)

        # The forward kernel writes whole BLOCK_M tiles of metadata, so pad M.
        M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M
        o_splitk = torch.empty(
            [B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device
        )
        metadata = torch.empty(
            [B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device
        )
        lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32)
        grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k)

        num_warps = 2
        split_size = (Mk + split_k - 1) // split_k
        use_seq_len = seq_len is not None
        # Key-axis bounds checks may only be skipped when every non-empty split
        # covers whole BLOCK_N blocks. That needs both the split size and the
        # total key length (which clips the last split) to be multiples of
        # BLOCK_N; e.g. Mk=255 with split_k=2 gives split_size=128, but the
        # last split ends at 255. Per-batch lengths always need checks.
        bounds_checks_n = (
            use_seq_len or (split_size % BLOCK_N) > 0 or (Mk % BLOCK_N) > 0
        )
        _fwd_kernel_splitK_unrolled = unroll_varargs(
            _fwd_kernel_splitK, N=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1
        )

        _fwd_kernel_splitK_unrolled[grid](
            Q=q,
            K=k,
            V=v,
            sm_scale=inp.scale_float,
            Out_splitK=o_splitk,
            Metadata=metadata,
            Seq_len=seq_len,
            **_strides(q, "qz", "qm", "qg", "qh", "qk"),
            **_strides(k, "kz", "kn", "kg", "kh", "kk"),
            **_strides(v, "vz", "vn", "vg", "vh", "vk"),
            **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
            **_strides(metadata, "mzhg", "m2", "ms", "mm"),
            Z=B,
            H=H,
            G=G,
            N_CTX_Q=M,
            N_CTX_K=Mk,
            BLOCK_N_PER_SPLIT=split_size,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_DMODEL=Lk,
            BOUNDS_CHECKS_N=bounds_checks_n,
            USE_SEQ_LEN=use_seq_len,
            num_warps=num_warps,
            num_stages=1,
            PACKED_PER_VAL=PACKED_PER_VAL,
            N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1,
        )

        if mqa_swap_seqlen_head:
            out = torch.empty(
                (B, H, G, M, Kq), device=q.device, dtype=q.dtype
            ).transpose(1, 3)
        else:
            out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype)

        # Merge together
        grid = (B * G * H, M, 1)
        _splitK_reduce[grid](
            o_splitk,
            metadata,
            out,
            lse,
            split_k=split_k,
            **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
            **_strides(metadata, "mzhg", "m2", "ms", "mm"),
            **_strides(out, "oz", "om", "og", "oh", "ok"),
            **_strides(lse, "lse_zhg", "lse_m"),
            BLOCK_SIZE=out.shape[-1],
            G=G,
            H=H,
            # TODO: Tune num_warps
        )
        lse = lse.reshape([B, G, H, M])
        if mqa_swap_seqlen_head:
            # H/M dimensions have been swapped
            out = out.transpose(1, 3)
            lse = lse.transpose(2, 3)
        if inp.query.ndim == 4:
            # BMGHK -> BMHK
            assert G == 1
            out = out[:, :, 0]
            lse = lse[:, 0]

        return out, Context(out=out, lse=lse)
|
|
|
|
|
|
class FwOp_S1(FwOp):
    """FwOp with the key axis processed as one split (get_split_k bypassed)."""

    NAME = "triton_splitK1"
    SPLIT_K = 1
|
|
|
|
|
|
class FwOp_S2(FwOp):
    """FwOp with the key axis always split into 2 chunks (get_split_k bypassed)."""

    NAME = "triton_splitK2"
    SPLIT_K = 2
|
|
|
|
|
|
class FwOp_S4(FwOp):
    """FwOp with the key axis always split into 4 chunks (get_split_k bypassed)."""

    NAME = "triton_splitK4"
    SPLIT_K = 4
|
|
|
|
|
|
class FwOp_S8(FwOp):
    """FwOp with the key axis always split into 8 chunks (get_split_k bypassed)."""

    NAME = "triton_splitK8"
    SPLIT_K = 8
|
|
|
|
|
|
class FwOp_S16(FwOp):
    """FwOp with the key axis always split into 16 chunks (get_split_k bypassed)."""

    NAME = "triton_splitK16"
    SPLIT_K = 16
|
|
|
|
|
|
class FwOp_S32(FwOp):
    """FwOp with the key axis always split into 32 chunks (get_split_k bypassed)."""

    NAME = "triton_splitK32"
    SPLIT_K = 32
|
|
|
|
|
|
class FwOp_S64(FwOp):
    """FwOp with the key axis always split into 64 chunks (get_split_k bypassed)."""

    NAME = "triton_splitK64"
    SPLIT_K = 64
|
|
|
|
|
|
class FwOp_S128(FwOp):
    """FwOp with the key axis always split into 128 chunks (get_split_k bypassed)."""

    NAME = "triton_splitK128"
    SPLIT_K = 128
|