First commit
This commit is contained in:
738
pkgs/xformers/ops/fmha/triton_splitk.py
Normal file
738
pkgs/xformers/ops/fmha/triton_splitk.py
Normal file
@@ -0,0 +1,738 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from ..common import _has_triton21, register_operator
|
||||
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
|
||||
from .common import AttentionFwOpBase, Context, Inputs, check_lastdim_alignment_stride1
|
||||
|
||||
|
||||
def _strides(x: torch.Tensor, *stride_names: str):
|
||||
assert x.ndim == len(stride_names)
|
||||
return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)}
|
||||
|
||||
|
||||
if TYPE_CHECKING or _has_triton21():
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from xformers.triton.vararg_kernel import VAR_ARGS_ARRAY, unroll_varargs
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_splitK(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
sm_scale,
|
||||
Out_splitK, # [B, H, split_k, Mq, K]
|
||||
Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
|
||||
Seq_len,
|
||||
stride_qz,
|
||||
stride_qm,
|
||||
stride_qg,
|
||||
stride_qh,
|
||||
stride_qk,
|
||||
stride_kz,
|
||||
stride_kn,
|
||||
stride_kg,
|
||||
stride_kh,
|
||||
stride_kk,
|
||||
stride_vz,
|
||||
stride_vn,
|
||||
stride_vg,
|
||||
stride_vh,
|
||||
stride_vk,
|
||||
stride_osk_zhg,
|
||||
stride_osk_s,
|
||||
stride_osk_m,
|
||||
stride_osk_k,
|
||||
stride_mzhg,
|
||||
stride_m2,
|
||||
stride_ms,
|
||||
stride_mm,
|
||||
Z,
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
BLOCK_N_PER_SPLIT,
|
||||
H: tl.constexpr,
|
||||
G: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BOUNDS_CHECKS_N: tl.constexpr,
|
||||
USE_SEQ_LEN: tl.constexpr,
|
||||
PACKED_PER_VAL: tl.constexpr = 1,
|
||||
N_GROUPS: tl.constexpr = 1,
|
||||
):
|
||||
"""This kernel can accept non-quantized or int4-quantized keys/values.
|
||||
PACKED_PER_VAL determines the quantization type:
|
||||
- PACKED_PER_VAL == 1 means no quantization
|
||||
- PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32)
|
||||
For the quantized case K/V should be int32 tensors.
|
||||
Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8.
|
||||
Quantization coefficients are stored at the beginning of the row along the last dimension of K/V
|
||||
So K[B, H, M, :] has a form
|
||||
[ quant_coef0, quant_coef1, ...|
|
||||
group0_quant_value0, group0_quant_value1,... |
|
||||
group1_quant_value0, group1_quant_value1,...]
|
||||
where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset.
|
||||
|
||||
Note: this kernel needs to be processed by xformers.triton.vararg_kernel.unroll_varargs
|
||||
before compilation. That will unroll variables marked with "VAR_ARGS_ARRAY" into lists.
|
||||
See how FwOp.apply does it below.
|
||||
"""
|
||||
tl.static_assert(
|
||||
(PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32))
|
||||
or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)),
|
||||
f"Only 4-bit quantization is supported, K/V should have dtype int32 in "
|
||||
f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}",
|
||||
)
|
||||
tl.static_assert(
|
||||
(((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8),
|
||||
"Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.",
|
||||
)
|
||||
|
||||
QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1
|
||||
PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS
|
||||
D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS
|
||||
|
||||
start_m = tl.program_id(0)
|
||||
off_zhg = tl.program_id(1)
|
||||
off_z = off_zhg // (H * G)
|
||||
off_h = (off_zhg // G) % H
|
||||
off_g = off_zhg % G
|
||||
splitk_idx = tl.program_id(2)
|
||||
|
||||
lo = splitk_idx * BLOCK_N_PER_SPLIT
|
||||
if USE_SEQ_LEN:
|
||||
kv_len = tl.load(Seq_len + off_z)
|
||||
else:
|
||||
kv_len = N_CTX_K
|
||||
hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)
|
||||
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg,
|
||||
shape=(N_CTX_Q, D_PER_GROUP),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, D_PER_GROUP),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg
|
||||
# Additional shift by 1 along the last dimension in the quantized case, since
|
||||
# the first element along that dim contains packed quantization coefficients.
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=k_base + stride_kk * QUANTIZED * N_GROUPS,
|
||||
shape=(PACKED_D_PER_GROUP, hi),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, lo),
|
||||
block_shape=(PACKED_D_PER_GROUP, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=v_base + stride_vk * QUANTIZED * N_GROUPS,
|
||||
shape=(hi, PACKED_D_PER_GROUP),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(lo, 0),
|
||||
block_shape=(BLOCK_N, PACKED_D_PER_GROUP),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
if QUANTIZED:
|
||||
# Pointers to quantization coefficients. Even those they are 1D,
|
||||
# we have to use block pointers, since usual pointers
|
||||
# don't support boundary checks
|
||||
K_scale_shift_block_ptr = tl.make_block_ptr(
|
||||
base=k_base,
|
||||
shape=(1, hi),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, lo),
|
||||
block_shape=(1, BLOCK_N),
|
||||
order=(0, 1),
|
||||
)
|
||||
V_scale_shift_block_ptr = tl.make_block_ptr(
|
||||
base=v_base,
|
||||
shape=(hi, 1),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(lo, 0),
|
||||
block_shape=(BLOCK_N, 1),
|
||||
order=(1, 0),
|
||||
)
|
||||
else:
|
||||
K_scale_shift_block_ptr = None
|
||||
V_scale_shift_block_ptr = None
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
|
||||
# Before compilation, this kernel will be processed by xformers.triton.vararg_kernel.unroll_varargs.
|
||||
# That turns tensors annotated as the one below into lists of tensors of length N_GROUPS.
|
||||
# This is a solution for Triton native lack of support for lists of tensors.
|
||||
acc: "VAR_ARGS_ARRAY" # noqa: F821
|
||||
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32) # noqa: F821
|
||||
# scale sm_scale by log_2(e) and use
|
||||
# 2^x instead of exp in the loop because CSE and LICM
|
||||
# don't work as expected with `exp` in the loop
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
# load q: it will stay in SRAM throughout
|
||||
q: "VAR_ARGS_ARRAY" # noqa: F821
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
q[i] = tl.load( # noqa: F821
|
||||
tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,)
|
||||
)
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in range(lo, hi, BLOCK_N):
|
||||
k: "VAR_ARGS_ARRAY" # noqa: F821
|
||||
v: "VAR_ARGS_ARRAY" # noqa: F821
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
k[i], v[i] = load_dequantize_k_v_group( # noqa: F821
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
K_scale_shift_block_ptr,
|
||||
V_scale_shift_block_ptr,
|
||||
BOUNDS_CHECKS_N,
|
||||
PACKED_PER_VAL,
|
||||
PACKED_D_PER_GROUP,
|
||||
Q.dtype.element_ty,
|
||||
i,
|
||||
)
|
||||
|
||||
# -- compute qk ---
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
qk += tl.dot(q[i], k[i]) # noqa: F821
|
||||
qk *= qk_scale
|
||||
|
||||
# TODO: This is slow, and only needed at the last iteration.
|
||||
# Maybe we can unroll the last iteration instead?
|
||||
if BOUNDS_CHECKS_N:
|
||||
qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
|
||||
# -- compute scaling constant ---
|
||||
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
|
||||
alpha = tl.math.exp2(m_i - m_i_new)
|
||||
p = tl.math.exp2(qk - m_i_new[:, None])
|
||||
|
||||
# -- update m_i and l_i --
|
||||
l_i = l_i * alpha + tl.sum(p, 1)
|
||||
m_i = m_i_new
|
||||
p = p.to(Q.dtype.element_ty)
|
||||
|
||||
# -- scale and update acc --
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
acc[i] *= alpha[:, None] # noqa: F821
|
||||
acc[i] += tl.dot(p, v[i]) # noqa: F821
|
||||
# update pointers
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||
if PACKED_PER_VAL > 1:
|
||||
K_scale_shift_block_ptr = tl.advance(
|
||||
K_scale_shift_block_ptr, (0, BLOCK_N)
|
||||
)
|
||||
V_scale_shift_block_ptr = tl.advance(
|
||||
V_scale_shift_block_ptr, (BLOCK_N, 0)
|
||||
)
|
||||
|
||||
# write back O
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
|
||||
shape=(N_CTX_Q, D_PER_GROUP),
|
||||
strides=(stride_osk_m, 1),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, D_PER_GROUP),
|
||||
order=(1, 0),
|
||||
)
|
||||
for i in range(len(acc)): # noqa: F821
|
||||
tl.store(
|
||||
tl.advance(O_block_ptr, (0, i * D_PER_GROUP)),
|
||||
acc[i], # noqa: F821
|
||||
boundary_check=(0,),
|
||||
)
|
||||
# Write metadata for split-K reduction
|
||||
Metadata_ptr = (
|
||||
Metadata
|
||||
+ off_zhg * stride_mzhg
|
||||
+ splitk_idx * stride_ms
|
||||
+ start_m * BLOCK_M
|
||||
+ tl.arange(0, BLOCK_M)
|
||||
)
|
||||
tl.store(Metadata_ptr, m_i)
|
||||
tl.store(Metadata_ptr + stride_m2, l_i)
|
||||
|
||||
@triton.jit
|
||||
def load_dequantize_k_v_group(
|
||||
K_block_ptr,
|
||||
V_block_ptr,
|
||||
K_scale_shift_block_ptr,
|
||||
V_scale_shift_block_ptr,
|
||||
BOUNDS_CHECKS_N: tl.constexpr,
|
||||
PACKED_PER_VAL: tl.constexpr,
|
||||
PACKED_D_PER_GROUP: tl.constexpr,
|
||||
dtype: tl.constexpr,
|
||||
group_id: tl.constexpr,
|
||||
):
|
||||
"""Load K/V for a given block. In case of int4-quantized K/V, dequantize them after loading.
|
||||
If quantization is group-wise, use group_id to advance the pointers to the current group.
|
||||
"""
|
||||
# Advance to the current quantization group
|
||||
K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id))
|
||||
|
||||
# -- load k, v --
|
||||
k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ())
|
||||
v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ())
|
||||
|
||||
if PACKED_PER_VAL > 1:
|
||||
# K/V are quantized, load quantization coefficients and dequantize
|
||||
|
||||
K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0))
|
||||
V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id))
|
||||
|
||||
k_scale_shift = tl.load(
|
||||
K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()
|
||||
)
|
||||
v_scale_shift = tl.load(
|
||||
V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()
|
||||
)
|
||||
|
||||
k_scale, k_shift = cast_uint32_to_half2(k_scale_shift)
|
||||
v_scale, v_shift = cast_uint32_to_half2(v_scale_shift)
|
||||
v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype)
|
||||
k_t = dequantize(
|
||||
tl.trans(k),
|
||||
tl.trans(k_scale),
|
||||
tl.trans(k_shift),
|
||||
PACKED_PER_VAL,
|
||||
).to(dtype)
|
||||
k = tl.trans(k_t)
|
||||
return k, v
|
||||
|
||||
@triton.jit
|
||||
def cast_uint32_to_half2(scale_shift):
|
||||
"""Extract two float16 packed into one int32"""
|
||||
scale = scale_shift & 0xFFFF
|
||||
shift = scale_shift >> 16
|
||||
scale = scale.to(tl.uint16).to(tl.float16, bitcast=True)
|
||||
shift = shift.to(tl.uint16).to(tl.float16, bitcast=True)
|
||||
return scale, shift
|
||||
|
||||
@triton.jit
|
||||
def dequantize(
|
||||
x_,
|
||||
scale,
|
||||
shift,
|
||||
PACKED_PER_VAL: tl.constexpr = 8,
|
||||
):
|
||||
"""PACKED_PER_VAL is the number of values packed into each element x_.
|
||||
For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8.
|
||||
"""
|
||||
|
||||
# Axis along which offsets are applied matters here
|
||||
# It would be natural to have offsets in shape (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL)
|
||||
# and expand K/V to that shape before applying offsets
|
||||
# However, Triton for some reason considers dim=1 as contiguous when doing tl.view below, and not dim=2
|
||||
# Note that tl.view doesn't guarantee the order of elements in the result - thus the code below depends
|
||||
# on the implementation details which might change in the future.
|
||||
# Ideally we would like to use tl.reshape, but it's not implemented yet.
|
||||
# See https://github.com/openai/triton/blob/9055af1a5dadc576804b38dd77ee91dc42af0bf7/python/triton/language/semantic.py#L541 # noqa: E501
|
||||
|
||||
# x_ : (BLOCK_N, D // PACKED_PER_VAL)
|
||||
# scale: (BLOCK_N, 1)
|
||||
# offsets: (PACKED_PER_VAL,)
|
||||
BLOCK_N: tl.constexpr = x_.shape[0]
|
||||
BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
|
||||
offsets = tl.arange(0, PACKED_PER_VAL) * 4
|
||||
quant_offset = (
|
||||
x_[:, None, :] >> offsets[None, :, None]
|
||||
) # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL)
|
||||
|
||||
quant_offset = tl.view(
|
||||
quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)
|
||||
)
|
||||
# Trick - instead of converting int4 to float16 we view it as float16
|
||||
# and then multiply by 32768 * 512 == 2**24
|
||||
quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
|
||||
quant_offset = (quant_offset * 32768.0).to(tl.float16)
|
||||
scale_512 = scale * 512
|
||||
|
||||
dequant = quant_offset * scale_512 + shift
|
||||
return dequant
|
||||
|
||||
@triton.jit
|
||||
def _splitK_reduce(
|
||||
Out_splitK, # [B, H, split_k, Mq, K]
|
||||
Metadata, # [B, H, 2, split_k, M_ceil] contains [mi, li]
|
||||
Out, # [B, H, M, K]
|
||||
LSE, # [B, H, M]
|
||||
split_k,
|
||||
stride_osk_zhg,
|
||||
stride_osk_s,
|
||||
stride_osk_m,
|
||||
stride_osk_k,
|
||||
stride_mzhg,
|
||||
stride_m2,
|
||||
stride_ms,
|
||||
stride_mm,
|
||||
stride_oz,
|
||||
stride_oh,
|
||||
stride_og,
|
||||
stride_om,
|
||||
stride_ok,
|
||||
stride_lse_zhg,
|
||||
stride_lse_m,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
G: tl.constexpr,
|
||||
):
|
||||
off_zhg = tl.program_id(0)
|
||||
off_z = off_zhg // (H * G)
|
||||
off_h = (off_zhg // G) % H
|
||||
off_g = off_zhg % G
|
||||
off_m = tl.program_id(1)
|
||||
|
||||
Out_splitK_ptr = (
|
||||
Out_splitK
|
||||
+ stride_osk_zhg * off_zhg
|
||||
+ stride_osk_m * off_m
|
||||
+ tl.arange(0, BLOCK_SIZE)
|
||||
)
|
||||
Metadata_ptr = Metadata + stride_mzhg * off_zhg + off_m
|
||||
m = tl.load(Metadata_ptr)
|
||||
l_sum = tl.load(Metadata_ptr + stride_m2)
|
||||
acc = tl.load(Out_splitK_ptr)
|
||||
|
||||
for split_k_idx in range(1, split_k):
|
||||
Metadata_ptr = Metadata_ptr + stride_ms
|
||||
Out_splitK_ptr = Out_splitK_ptr + stride_osk_s
|
||||
|
||||
m_k = tl.load(Metadata_ptr)
|
||||
l_k = tl.load(Metadata_ptr + stride_m2)
|
||||
acc_k = tl.load(Out_splitK_ptr)
|
||||
|
||||
m_new = tl.maximum(m, m_k)
|
||||
if m_k < m:
|
||||
# Scale incoming values
|
||||
alpha = tl.math.exp2(m_k - m_new)
|
||||
acc_k = acc_k * alpha
|
||||
l_k = l_k * alpha
|
||||
else:
|
||||
# Scale our values
|
||||
alpha = tl.math.exp2(m - m_new)
|
||||
acc = acc * alpha
|
||||
l_sum = l_sum * alpha
|
||||
|
||||
m = m_new
|
||||
l_sum = l_sum + l_k
|
||||
acc = acc + acc_k
|
||||
|
||||
acc = acc / l_sum
|
||||
Out_ptr = (
|
||||
Out
|
||||
+ stride_oz * off_z
|
||||
+ stride_oh * off_h
|
||||
+ stride_og * off_g
|
||||
+ stride_om * off_m
|
||||
+ tl.arange(0, BLOCK_SIZE)
|
||||
)
|
||||
tl.store(Out_ptr, acc)
|
||||
|
||||
l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m
|
||||
tl.store(l_ptrs, (m + tl.math.log2(l_sum)) / 1.44269504)
|
||||
|
||||
else:
|
||||
_fwd_kernel_splitK = None
|
||||
_splitK_reduce = None
|
||||
|
||||
|
||||
@register_operator
|
||||
class FwOp(AttentionFwOpBase):
|
||||
"""Flash-Attention with Split-K. Supports fused int-4 K/V quantization.
|
||||
Quantized path will be taken if input K/V have type int32.
|
||||
|
||||
Quantization can be row-wise or group-wise (when cls.NUM_GROUPS > 1) along
|
||||
the last dimension of K and V. Currently 1, 2, 4, or 8 groups per row are supported.
|
||||
Quantization coefficients (scale and shift) are represented as two
|
||||
float16 constants per group, packed into int32. Quantization coefficients of
|
||||
all groups are placed at the beginning of the row. So, if unquantized K/V have head
|
||||
dimension D, the quantized versions have head dimension D // 8 + NUM_GROUPS
|
||||
and dtype int32.
|
||||
Pseudocode for dequantizing one row can look like:
|
||||
group_size = D // 8
|
||||
for i in range(NUM_GROUPS):
|
||||
group_start = NUM_GROUPS + i * group_size
|
||||
group_quant = K[..., group_start: group_start + group_size]
|
||||
scale, shift = unpack_int32_into_float16x2(group_quant[0])
|
||||
group_dequant = group_quant[..., 1:] * scale + shift
|
||||
...
|
||||
|
||||
"""
|
||||
|
||||
OPERATOR = _fwd_kernel_splitK
|
||||
SUPPORTED_DEVICES = {"cuda"}
|
||||
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
|
||||
SUPPORTED_DTYPES = {
|
||||
torch.half,
|
||||
torch.bfloat16,
|
||||
} # Those are dtypes of Q. In the quantized case K/V has dtype int32
|
||||
SUPPORTED_MAX_K = 128
|
||||
SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
|
||||
type(None),
|
||||
BlockDiagonalCausalWithOffsetPaddedKeysMask,
|
||||
}
|
||||
SUPPORTS_DROPOUT = False
|
||||
SUPPORTS_CUSTOM_SCALE = True
|
||||
SUPPORTS_BMGHK = True
|
||||
NAME = "triton_splitKF"
|
||||
|
||||
SPLIT_K: Optional[int] = None
|
||||
BLOCK_M = 16
|
||||
BLOCK_N = 64
|
||||
|
||||
NUM_GROUPS = 1 # Default quantization is row-wise
|
||||
|
||||
@classmethod
|
||||
def shape_not_supported_reasons(
|
||||
cls, Mq: int, Mkv: int, K: int, Kv: int
|
||||
) -> List[str]:
|
||||
reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
|
||||
if K not in {16, 32, 64, 128}:
|
||||
reasons.append(f"Embed dim {K} not supported")
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def not_supported_reasons(cls, d: Inputs) -> List[str]:
|
||||
reasons = super(FwOp, cls).not_supported_reasons(d)
|
||||
check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
|
||||
if d.key.dtype != torch.int32:
|
||||
check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
|
||||
check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
|
||||
if cls.OPERATOR is None:
|
||||
reasons.append("triton is not available")
|
||||
if d.device.type == "cuda":
|
||||
# Has only been tested on 8.0 / 9.0.
|
||||
if torch.cuda.get_device_capability(d.device) < (8, 0):
|
||||
reasons.append(
|
||||
"requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
|
||||
)
|
||||
|
||||
q_len = d.query.shape[1]
|
||||
if isinstance(d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
|
||||
seqinfo = d.attn_bias.q_seqinfo
|
||||
if q_len != seqinfo.seqstart_py[-1]:
|
||||
reasons.append(
|
||||
f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}"
|
||||
)
|
||||
q_len = seqinfo.min_seqlen
|
||||
if q_len != seqinfo.max_seqlen:
|
||||
reasons.append(
|
||||
"Variable query len is not supported in the presence of causal mask."
|
||||
)
|
||||
|
||||
if d.key.ndim in [4, 5] and d.key.shape[-2] != 1:
|
||||
if d.key.stride(-2) == 0 and d.value.stride(-2) == 0 and q_len > 1:
|
||||
reasons.append("multiquery is only supported with query seqlen=1")
|
||||
|
||||
if d.attn_bias is not None and q_len > 1:
|
||||
reasons.append(
|
||||
"query with seqlen > 1 is not supported in the presence of causal mask"
|
||||
)
|
||||
return reasons
|
||||
|
||||
@classmethod
|
||||
def get_split_k(cls, B: int, H: int, Mk: int) -> int:
|
||||
"""Heuristic for the number of splits"""
|
||||
bh = B * H
|
||||
split_k = max(Mk, 1024) // bh
|
||||
max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
|
||||
while split_k > 0 and Mk / split_k < max_chunk_size:
|
||||
split_k = split_k // 2
|
||||
split_k = min(split_k, 64)
|
||||
split_k = max(split_k, 1)
|
||||
return split_k
|
||||
|
||||
@classmethod
|
||||
def apply(
|
||||
cls, inp: Inputs, needs_gradient: bool
|
||||
) -> Tuple[torch.Tensor, Optional[Context]]:
|
||||
attn_bias = inp.attn_bias
|
||||
seq_len = None
|
||||
q, k, v = inp.get_qkv_in_bmghk()
|
||||
|
||||
if attn_bias is not None:
|
||||
assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
|
||||
# TODO: do we really need to do this cast? seems fishy but
|
||||
# I just copied it from the decoder.py
|
||||
attn_bias.k_seqinfo.to(inp.query.device)
|
||||
attn_bias.q_seqinfo.to(inp.query.device)
|
||||
seq_len = attn_bias.k_seqinfo.seqlen
|
||||
B = len(seq_len)
|
||||
G, H, Kq = q.shape[-3:]
|
||||
Kkv = v.shape[-1]
|
||||
|
||||
# assume kv has been padded
|
||||
q = q.reshape(B, -1, G, H, Kq)
|
||||
k = k.reshape(B, -1, G, H, Kkv)
|
||||
v = v.reshape(B, -1, G, H, Kkv)
|
||||
|
||||
# Transpose in the case of MQA/GQA
|
||||
mqa_swap_seqlen_head = False
|
||||
if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0:
|
||||
mqa_swap_seqlen_head = True
|
||||
assert q.shape[1] == 1
|
||||
q = q.transpose(1, 3)
|
||||
k = k[:, :, :, :1]
|
||||
v = v[:, :, :, :1]
|
||||
|
||||
if k.dtype == torch.int32:
|
||||
# Quantized K/V
|
||||
PACKED_PER_VAL = 8
|
||||
Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8
|
||||
else:
|
||||
Lk = k.shape[-1]
|
||||
PACKED_PER_VAL = 1
|
||||
|
||||
B, Mk, G, H, Kkv = k.shape
|
||||
B, M, G, H, Kq = q.shape
|
||||
assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}"
|
||||
|
||||
BLOCK_M = cls.BLOCK_M
|
||||
BLOCK_N = cls.BLOCK_N
|
||||
if cls.SPLIT_K is not None:
|
||||
split_k = cls.SPLIT_K
|
||||
else:
|
||||
# Use heuristics
|
||||
split_k = cls.get_split_k(B, H, Mk)
|
||||
|
||||
M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M
|
||||
o_splitk = torch.empty(
|
||||
[B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device
|
||||
)
|
||||
metadata = torch.empty(
|
||||
[B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device
|
||||
)
|
||||
lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32)
|
||||
grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k)
|
||||
|
||||
num_warps = 2
|
||||
split_size = (Mk + split_k - 1) // split_k
|
||||
use_seq_len = seq_len is not None
|
||||
_fwd_kernel_splitK_unrolled = unroll_varargs(
|
||||
_fwd_kernel_splitK, N=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1
|
||||
)
|
||||
|
||||
_fwd_kernel_splitK_unrolled[grid](
|
||||
Q=q,
|
||||
K=k,
|
||||
V=v,
|
||||
sm_scale=inp.scale_float,
|
||||
Out_splitK=o_splitk,
|
||||
Metadata=metadata,
|
||||
Seq_len=seq_len,
|
||||
**_strides(q, "qz", "qm", "qg", "qh", "qk"),
|
||||
**_strides(k, "kz", "kn", "kg", "kh", "kk"),
|
||||
**_strides(v, "vz", "vn", "vg", "vh", "vk"),
|
||||
**_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
|
||||
**_strides(metadata, "mzhg", "m2", "ms", "mm"),
|
||||
Z=B,
|
||||
H=H,
|
||||
G=G,
|
||||
N_CTX_Q=M,
|
||||
N_CTX_K=Mk,
|
||||
BLOCK_N_PER_SPLIT=split_size,
|
||||
BLOCK_M=BLOCK_M,
|
||||
BLOCK_N=BLOCK_N,
|
||||
BLOCK_DMODEL=Lk,
|
||||
BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len,
|
||||
USE_SEQ_LEN=use_seq_len,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
PACKED_PER_VAL=PACKED_PER_VAL,
|
||||
N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1,
|
||||
)
|
||||
|
||||
if mqa_swap_seqlen_head:
|
||||
out = torch.empty(
|
||||
(B, H, G, M, Kq), device=q.device, dtype=q.dtype
|
||||
).transpose(1, 3)
|
||||
else:
|
||||
out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype)
|
||||
|
||||
# Merge together
|
||||
grid = (B * G * H, M, 1)
|
||||
_splitK_reduce[grid](
|
||||
o_splitk,
|
||||
metadata,
|
||||
out,
|
||||
lse,
|
||||
split_k=split_k,
|
||||
**_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
|
||||
**_strides(metadata, "mzhg", "m2", "ms", "mm"),
|
||||
**_strides(out, "oz", "om", "og", "oh", "ok"),
|
||||
**_strides(lse, "lse_zhg", "lse_m"),
|
||||
BLOCK_SIZE=out.shape[-1],
|
||||
G=G,
|
||||
H=H,
|
||||
# TODO: Tune num_warps
|
||||
)
|
||||
lse = lse.reshape([B, G, H, M])
|
||||
if mqa_swap_seqlen_head:
|
||||
# H/M dimensions have been swapped
|
||||
out = out.transpose(1, 3)
|
||||
lse = lse.transpose(2, 3)
|
||||
if inp.query.ndim == 4:
|
||||
# BMGHK -> BMHK
|
||||
assert G == 1
|
||||
out = out[:, :, 0]
|
||||
lse = lse[:, 0]
|
||||
|
||||
return out, Context(out=out, lse=lse)
|
||||
|
||||
|
||||
class FwOp_S1(FwOp):
|
||||
SPLIT_K = 1
|
||||
NAME = "triton_splitK1"
|
||||
|
||||
|
||||
class FwOp_S2(FwOp):
|
||||
SPLIT_K = 2
|
||||
NAME = "triton_splitK2"
|
||||
|
||||
|
||||
class FwOp_S4(FwOp):
|
||||
SPLIT_K = 4
|
||||
NAME = "triton_splitK4"
|
||||
|
||||
|
||||
class FwOp_S8(FwOp):
|
||||
SPLIT_K = 8
|
||||
NAME = "triton_splitK8"
|
||||
|
||||
|
||||
class FwOp_S16(FwOp):
|
||||
SPLIT_K = 16
|
||||
NAME = "triton_splitK16"
|
||||
|
||||
|
||||
class FwOp_S32(FwOp):
|
||||
SPLIT_K = 32
|
||||
NAME = "triton_splitK32"
|
||||
|
||||
|
||||
class FwOp_S64(FwOp):
|
||||
SPLIT_K = 64
|
||||
NAME = "triton_splitK64"
|
||||
|
||||
|
||||
class FwOp_S128(FwOp):
|
||||
SPLIT_K = 128
|
||||
NAME = "triton_splitK128"
|
||||
Reference in New Issue
Block a user