Files
enginex-bi_series-vllm/pkgs/xformers/ops/fmha/triton_splitk.py
2025-08-05 19:02:46 +08:00

739 lines
26 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple
import torch
from ..common import _has_triton21, register_operator
from .attn_bias import BlockDiagonalCausalWithOffsetPaddedKeysMask
from .common import AttentionFwOpBase, Context, Inputs, check_lastdim_alignment_stride1
def _strides(x: torch.Tensor, *stride_names: str):
assert x.ndim == len(stride_names)
return {f"stride_{s}": x.stride(i) for i, s in enumerate(stride_names)}
if TYPE_CHECKING or _has_triton21():
import triton
import triton.language as tl
from xformers.triton.vararg_kernel import VAR_ARGS_ARRAY, unroll_varargs
@triton.jit
def _fwd_kernel_splitK(
    Q,
    K,
    V,
    sm_scale,
    Out_splitK,  # [B, H, split_k, Mq, K]
    Metadata,  # [B, H, 2, split_k, M_ceil] contains [mi, li]
    Seq_len,
    stride_qz,
    stride_qm,
    stride_qg,
    stride_qh,
    stride_qk,
    stride_kz,
    stride_kn,
    stride_kg,
    stride_kh,
    stride_kk,
    stride_vz,
    stride_vn,
    stride_vg,
    stride_vh,
    stride_vk,
    stride_osk_zhg,
    stride_osk_s,
    stride_osk_m,
    stride_osk_k,
    stride_mzhg,
    stride_m2,
    stride_ms,
    stride_mm,
    Z,
    N_CTX_Q,
    N_CTX_K,
    BLOCK_N_PER_SPLIT,
    H: tl.constexpr,
    G: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BOUNDS_CHECKS_N: tl.constexpr,
    USE_SEQ_LEN: tl.constexpr,
    PACKED_PER_VAL: tl.constexpr = 1,
    N_GROUPS: tl.constexpr = 1,
):
    """This kernel can accept non-quantized or int4-quantized keys/values.
    PACKED_PER_VAL determines the quantization type:
        - PACKED_PER_VAL == 1 means no quantization
        - PACKED_PER_VAL == 8 means 4-bit quantization (8 packed quantized values inside one int32)
    For the quantized case K/V should be int32 tensors.
    Quantization can be row-wise (when N_GROUPS = 1) or group-wise with N_GROUPS = 2, 4, or 8.
    Quantization coefficients are stored at the beginning of the row along the last dimension of K/V
    So K[B, H, M, :] has a form
    [ quant_coef0, quant_coef1, ...|
      group0_quant_value0, group0_quant_value1,... |
      group1_quant_value0, group1_quant_value1,...]
    where each quant_coef is an int32 which should be interpreted as 2 packed float16: scale and offset.

    Work decomposition: program_id(0) selects a block of BLOCK_M query rows,
    program_id(1) selects a flattened (batch, head, group) index and
    program_id(2) selects one split of the key/value sequence covering
    BLOCK_N_PER_SPLIT positions.

    Outputs: the accumulator is written to Out_splitK WITHOUT being divided by
    the softmax denominator; the running row maximum (m_i) and the running
    denominator (l_i) are written to Metadata. Both are expressed in base 2,
    because the scores are pre-multiplied by log2(e) and exp2 is used.
    FwOp.apply allocates Out_splitK as [B * G * H, split_k, M_ceil, K] and
    Metadata as [B * G * H, 2, split_k, M_ceil].

    Z, stride_osk_k and stride_mm are accepted but not read in this body; the
    output/metadata addressing assumes a unit stride along the last dimension.

    Note: this kernel needs to be processed by xformers.triton.vararg_kernel.unroll_varargs
    before compilation. That will unroll variables marked with "VAR_ARGS_ARRAY" into lists.
    See how FwOp.apply does it below.
    """
    tl.static_assert(
        (PACKED_PER_VAL == 1 and tl.constexpr(K.dtype.element_ty != tl.int32))
        or (PACKED_PER_VAL == 8 and tl.constexpr(K.dtype.element_ty == tl.int32)),
        f"Only 4-bit quantization is supported, K/V should have dtype int32 in "
        f"the quantized case: {PACKED_PER_VAL=} {tl.constexpr(K.dtype)=} {tl.constexpr(K.dtype.element_ty)=}",
    )
    tl.static_assert(
        (((N_GROUPS == 1 or N_GROUPS == 2) or N_GROUPS == 4) or N_GROUPS == 8),
        "Number of quantization groups can be 1 (row-wise quantization), 2, 4, or 8.",
    )

    QUANTIZED: tl.constexpr = PACKED_PER_VAL > 1
    # Number of packed (int32) elements per quantization group, and the number
    # of unpacked head-dim elements per group.
    PACKED_D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // PACKED_PER_VAL // N_GROUPS
    D_PER_GROUP: tl.constexpr = BLOCK_DMODEL // N_GROUPS

    start_m = tl.program_id(0)
    # Flattened (batch, head, group) index; the group index varies fastest.
    off_zhg = tl.program_id(1)
    off_z = off_zhg // (H * G)
    off_h = (off_zhg // G) % H
    off_g = off_zhg % G
    splitk_idx = tl.program_id(2)

    # [lo, hi) is the range of key/value positions handled by this split.
    lo = splitk_idx * BLOCK_N_PER_SPLIT
    if USE_SEQ_LEN:
        # Per-batch key length, indexed by the batch offset.
        kv_len = tl.load(Seq_len + off_z)
    else:
        kv_len = N_CTX_K
    # Clamp the end of the split to the key length; hi can end up <= lo, in
    # which case the main loop below does not execute.
    hi = tl.minimum((splitk_idx + 1) * BLOCK_N_PER_SPLIT, kv_len)

    Q_block_ptr = tl.make_block_ptr(
        base=Q + off_h * stride_qh + off_z * stride_qz + off_g * stride_qg,
        shape=(N_CTX_Q, D_PER_GROUP),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, D_PER_GROUP),
        order=(1, 0),
    )

    k_base = K + off_h * stride_kh + off_z * stride_kz + off_g * stride_kg
    # Additional shift by N_GROUPS along the last dimension in the quantized case, since
    # the first N_GROUPS elements along that dim contain the packed quantization
    # coefficients (one per group). QUANTIZED is 0 otherwise, so no shift is applied.
    K_block_ptr = tl.make_block_ptr(
        base=k_base + stride_kk * QUANTIZED * N_GROUPS,
        shape=(PACKED_D_PER_GROUP, hi),
        strides=(stride_kk, stride_kn),
        offsets=(0, lo),
        block_shape=(PACKED_D_PER_GROUP, BLOCK_N),
        order=(0, 1),
    )
    v_base = V + off_h * stride_vh + off_z * stride_vz + off_g * stride_vg
    V_block_ptr = tl.make_block_ptr(
        base=v_base + stride_vk * QUANTIZED * N_GROUPS,
        shape=(hi, PACKED_D_PER_GROUP),
        strides=(stride_vn, stride_vk),
        offsets=(lo, 0),
        block_shape=(BLOCK_N, PACKED_D_PER_GROUP),
        order=(1, 0),
    )

    if QUANTIZED:
        # Pointers to quantization coefficients. Even though they are 1D,
        # we have to use block pointers, since usual pointers
        # don't support boundary checks
        K_scale_shift_block_ptr = tl.make_block_ptr(
            base=k_base,
            shape=(1, hi),
            strides=(stride_kk, stride_kn),
            offsets=(0, lo),
            block_shape=(1, BLOCK_N),
            order=(0, 1),
        )
        V_scale_shift_block_ptr = tl.make_block_ptr(
            base=v_base,
            shape=(hi, 1),
            strides=(stride_vn, stride_vk),
            offsets=(lo, 0),
            block_shape=(BLOCK_N, 1),
            order=(1, 0),
        )
    else:
        K_scale_shift_block_ptr = None
        V_scale_shift_block_ptr = None

    # Initialize the running row maximum (m_i) and the running softmax
    # denominator (l_i); these are values, not pointers.
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)

    # Before compilation, this kernel will be processed by xformers.triton.vararg_kernel.unroll_varargs.
    # That turns tensors annotated as the one below into lists of tensors of length N_GROUPS.
    # This is a solution for Triton native lack of support for lists of tensors.
    acc: "VAR_ARGS_ARRAY"  # noqa: F821

    for i in range(len(acc)):  # noqa: F821
        acc[i] = tl.zeros([BLOCK_M, D_PER_GROUP], dtype=tl.float32)  # noqa: F821
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout
    q: "VAR_ARGS_ARRAY"  # noqa: F821
    for i in range(len(acc)):  # noqa: F821
        q[i] = tl.load(  # noqa: F821
            tl.advance(Q_block_ptr, (0, i * D_PER_GROUP)), boundary_check=(0,)
        )
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        k: "VAR_ARGS_ARRAY"  # noqa: F821
        v: "VAR_ARGS_ARRAY"  # noqa: F821
        for i in range(len(acc)):  # noqa: F821
            k[i], v[i] = load_dequantize_k_v_group(  # noqa: F821
                K_block_ptr,
                V_block_ptr,
                K_scale_shift_block_ptr,
                V_scale_shift_block_ptr,
                BOUNDS_CHECKS_N,
                PACKED_PER_VAL,
                PACKED_D_PER_GROUP,
                Q.dtype.element_ty,
                i,
            )

        # -- compute qk ---
        # The full dot product over the head dimension is the sum of the
        # per-group partial dot products.
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for i in range(len(acc)):  # noqa: F821
            qk += tl.dot(q[i], k[i])  # noqa: F821
        qk *= qk_scale

        # TODO: This is slow, and only needed at the last iteration.
        # Maybe we can unroll the last iteration instead?
        if BOUNDS_CHECKS_N:
            # Mask out columns past the end of this split so they contribute
            # exp2(-inf) == 0 to the softmax.
            qk = tl.where(tl.arange(0, BLOCK_N) < hi - start_n, qk, float("-inf"))
        # -- compute scaling constant ---
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        # alpha rescales everything accumulated so far to the new maximum.
        alpha = tl.math.exp2(m_i - m_i_new)
        p = tl.math.exp2(qk - m_i_new[:, None])

        # -- update m_i and l_i --
        l_i = l_i * alpha + tl.sum(p, 1)
        m_i = m_i_new
        # Cast the probabilities to the query dtype before the second matmul.
        p = p.to(Q.dtype.element_ty)

        # -- scale and update acc --
        for i in range(len(acc)):  # noqa: F821
            acc[i] *= alpha[:, None]  # noqa: F821
            acc[i] += tl.dot(p, v[i])  # noqa: F821
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        if PACKED_PER_VAL > 1:
            K_scale_shift_block_ptr = tl.advance(
                K_scale_shift_block_ptr, (0, BLOCK_N)
            )
            V_scale_shift_block_ptr = tl.advance(
                V_scale_shift_block_ptr, (BLOCK_N, 0)
            )

    # write back O
    # The last-dimension stride is hard-coded to 1 here (stride_osk_k is not used).
    O_block_ptr = tl.make_block_ptr(
        base=Out_splitK + off_zhg * stride_osk_zhg + splitk_idx * stride_osk_s,
        shape=(N_CTX_Q, D_PER_GROUP),
        strides=(stride_osk_m, 1),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, D_PER_GROUP),
        order=(1, 0),
    )
    for i in range(len(acc)):  # noqa: F821
        tl.store(
            tl.advance(O_block_ptr, (0, i * D_PER_GROUP)),
            acc[i],  # noqa: F821
            boundary_check=(0,),
        )
    # Write metadata for split-K reduction
    # These stores are unmasked over a full BLOCK_M rows; FwOp.apply pads the
    # row dimension of Metadata to a multiple of BLOCK_M (M_ceil).
    Metadata_ptr = (
        Metadata
        + off_zhg * stride_mzhg
        + splitk_idx * stride_ms
        + start_m * BLOCK_M
        + tl.arange(0, BLOCK_M)
    )
    tl.store(Metadata_ptr, m_i)
    tl.store(Metadata_ptr + stride_m2, l_i)
@triton.jit
def load_dequantize_k_v_group(
    K_block_ptr,
    V_block_ptr,
    K_scale_shift_block_ptr,
    V_scale_shift_block_ptr,
    BOUNDS_CHECKS_N: tl.constexpr,
    PACKED_PER_VAL: tl.constexpr,
    PACKED_D_PER_GROUP: tl.constexpr,
    dtype: tl.constexpr,
    group_id: tl.constexpr,
):
    """Load K/V for a given block. In case of int4-quantized K/V, dequantize them after loading.
    If quantization is group-wise, use group_id to advance the pointers to the current group.

    Args:
        K_block_ptr, V_block_ptr: block pointers to the (possibly packed) values
            of group 0 for the current key/value block.
        K_scale_shift_block_ptr, V_scale_shift_block_ptr: block pointers to the
            packed scale/shift coefficients of group 0; only dereferenced when
            PACKED_PER_VAL > 1.
        BOUNDS_CHECKS_N: when true, loads are boundary-checked along the
            key/value sequence dimension.
        PACKED_PER_VAL: number of quantized values packed per loaded element;
            values greater than 1 select the dequantization path.
        PACKED_D_PER_GROUP: number of packed elements per quantization group.
        dtype: element type the dequantized values are cast to.
        group_id: index of the quantization group to load.

    Returns:
        The tuple (k, v) for the requested group. In the quantized case both
        are dequantized and cast to dtype; otherwise they are returned as loaded.
    """
    # Advance to the current quantization group
    # K carries the head dimension on axis 0, V carries it on axis 1.
    K_block_ptr = tl.advance(K_block_ptr, (PACKED_D_PER_GROUP * group_id, 0))
    V_block_ptr = tl.advance(V_block_ptr, (0, PACKED_D_PER_GROUP * group_id))

    # -- load k, v --
    k = tl.load(K_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ())
    v = tl.load(V_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ())

    if PACKED_PER_VAL > 1:
        # K/V are quantized, load quantization coefficients and dequantize
        # The coefficient of group `group_id` is the group_id-th element along
        # the head dimension of the coefficient pointers.
        K_scale_shift_block_ptr = tl.advance(K_scale_shift_block_ptr, (group_id, 0))
        V_scale_shift_block_ptr = tl.advance(V_scale_shift_block_ptr, (0, group_id))

        k_scale_shift = tl.load(
            K_scale_shift_block_ptr, boundary_check=(1,) if BOUNDS_CHECKS_N else ()
        )
        v_scale_shift = tl.load(
            V_scale_shift_block_ptr, boundary_check=(0,) if BOUNDS_CHECKS_N else ()
        )

        k_scale, k_shift = cast_uint32_to_half2(k_scale_shift)
        v_scale, v_shift = cast_uint32_to_half2(v_scale_shift)
        v = dequantize(v, v_scale, v_shift, PACKED_PER_VAL).to(dtype)
        # k and its coefficients are transposed on the way into dequantize and
        # the result is transposed back, so that K and V share one
        # dequantization routine with the same axis layout as v.
        k_t = dequantize(
            tl.trans(k),
            tl.trans(k_scale),
            tl.trans(k_shift),
            PACKED_PER_VAL,
        ).to(dtype)
        k = tl.trans(k_t)
    return k, v
@triton.jit
def cast_uint32_to_half2(scale_shift):
    """Unpack one 32-bit integer into the two float16 values stored in it.

    The low 16 bits are reinterpreted (bitcast, no numeric conversion) as the
    scale and the high 16 bits as the shift. Returns ``(scale, shift)``.
    """
    # Narrowing to uint16 keeps only the low 16 bits of each half.
    high_bits = (scale_shift >> 16).to(tl.uint16)
    low_bits = (scale_shift & 0xFFFF).to(tl.uint16)
    shift = high_bits.to(tl.float16, bitcast=True)
    scale = low_bits.to(tl.float16, bitcast=True)
    return scale, shift
@triton.jit
def dequantize(
    x_,
    scale,
    shift,
    PACKED_PER_VAL: tl.constexpr = 8,
):
    """PACKED_PER_VAL is the number of values packed into each element x_.
    For example, for int4 quantization and x_ of type int32, PACKED_PER_VAL is 8.

    Each packed element of x_ is expanded into PACKED_PER_VAL 4-bit values,
    and every value q is mapped to ``q * scale + shift``. The result has
    x_.shape[0] rows and x_.shape[1] * PACKED_PER_VAL columns.
    """
    # Axis along which offsets are applied matters here
    # It would be natural to have offsets in shape (BLOCK_N, D // PACKED_PER_VAL, PACKED_PER_VAL)
    # and expand K/V to that shape before applying offsets
    # However, Triton for some reason considers dim=1 as contiguous when doing tl.view below, and not dim=2
    # Note that tl.view doesn't guarantee the order of elements in the result - thus the code below depends
    # on the implementation details which might change in the future.
    # Ideally we would like to use tl.reshape, but it's not implemented yet.
    # See https://github.com/openai/triton/blob/9055af1a5dadc576804b38dd77ee91dc42af0bf7/python/triton/language/semantic.py#L541 # noqa: E501

    # x_ : (BLOCK_N, D // PACKED_PER_VAL)
    # scale: (BLOCK_N, 1)
    # offsets: (PACKED_PER_VAL,)
    BLOCK_N: tl.constexpr = x_.shape[0]
    BLOCK_DMODEL_PACKED: tl.constexpr = x_.shape[1]
    # Bit offset of each 4-bit value inside a packed element: 0, 4, 8, ...
    offsets = tl.arange(0, PACKED_PER_VAL) * 4
    quant_offset = (
        x_[:, None, :] >> offsets[None, :, None]
    )  # (BLOCK_N, PACKED_PER_VAL, D // PACKED_PER_VAL)

    quant_offset = tl.view(
        quant_offset, (BLOCK_N, BLOCK_DMODEL_PACKED * PACKED_PER_VAL)
    )
    # Trick - instead of converting int4 to float16 we view it as float16
    # and then multiply by 32768 * 512 == 2**24
    # `& 0xF` keeps only the 4-bit value that was shifted into the low bits.
    quant_offset = (quant_offset & 0xF).to(tl.uint16).to(tl.float16, bitcast=True)
    # The 2**24 factor is applied in two steps: 32768 here and 512 folded into
    # the scale below (presumably because 2**24 is outside the float16 range).
    quant_offset = (quant_offset * 32768.0).to(tl.float16)
    scale_512 = scale * 512

    dequant = quant_offset * scale_512 + shift
    return dequant
@triton.jit
def _splitK_reduce(
    Out_splitK,  # [B, H, split_k, Mq, K]
    Metadata,  # [B, H, 2, split_k, M_ceil] contains [mi, li]
    Out,  # [B, H, M, K]
    LSE,  # [B, H, M]
    split_k,
    stride_osk_zhg,
    stride_osk_s,
    stride_osk_m,
    stride_osk_k,
    stride_mzhg,
    stride_m2,
    stride_ms,
    stride_mm,
    stride_oz,
    stride_oh,
    stride_og,
    stride_om,
    stride_ok,
    stride_lse_zhg,
    stride_lse_m,
    BLOCK_SIZE: tl.constexpr,
    H: tl.constexpr,
    G: tl.constexpr,
):
    """Merge the per-split partial attention results of one query row.

    program_id(0) selects the flattened (batch, head, group) index and
    program_id(1) selects the query row. For that row the kernel folds the
    split_k partial accumulators from Out_splitK together, rescaling them to a
    common row maximum using the (max, sum) pairs in Metadata, then writes the
    normalized row to Out and its log-sum-exp to LSE.

    BLOCK_SIZE is the number of elements of one output row handled per
    program. stride_osk_k, stride_mm, stride_ok and stride_lse_m are accepted
    but not read: addressing assumes a unit stride along the last dimension of
    every tensor.
    """
    off_zhg = tl.program_id(0)
    # Recover batch / head / group from the flattened index (group is fastest).
    off_z = off_zhg // (H * G)
    off_h = (off_zhg // G) % H
    off_g = off_zhg % G
    off_m = tl.program_id(1)

    # Start at split 0; both pointers are advanced by one split per iteration.
    Out_splitK_ptr = (
        Out_splitK
        + stride_osk_zhg * off_zhg
        + stride_osk_m * off_m
        + tl.arange(0, BLOCK_SIZE)
    )
    Metadata_ptr = Metadata + stride_mzhg * off_zhg + off_m
    # Running maximum, running denominator and running accumulator, seeded
    # from split 0.
    m = tl.load(Metadata_ptr)
    l_sum = tl.load(Metadata_ptr + stride_m2)
    acc = tl.load(Out_splitK_ptr)

    for split_k_idx in range(1, split_k):
        Metadata_ptr = Metadata_ptr + stride_ms
        Out_splitK_ptr = Out_splitK_ptr + stride_osk_s
        m_k = tl.load(Metadata_ptr)
        l_k = tl.load(Metadata_ptr + stride_m2)
        acc_k = tl.load(Out_splitK_ptr)

        m_new = tl.maximum(m, m_k)
        # Only the side with the smaller maximum needs rescaling; the other
        # side is already expressed relative to m_new.
        if m_k < m:
            # Scale incoming values
            alpha = tl.math.exp2(m_k - m_new)
            acc_k = acc_k * alpha
            l_k = l_k * alpha
        else:
            # Scale our values
            alpha = tl.math.exp2(m - m_new)
            acc = acc * alpha
            l_sum = l_sum * alpha

        m = m_new
        l_sum = l_sum + l_k
        acc = acc + acc_k

    # Final softmax normalization.
    acc = acc / l_sum
    Out_ptr = (
        Out
        + stride_oz * off_z
        + stride_oh * off_h
        + stride_og * off_g
        + stride_om * off_m
        + tl.arange(0, BLOCK_SIZE)
    )
    tl.store(Out_ptr, acc)

    l_ptrs = LSE + off_zhg * stride_lse_zhg + off_m
    # m and log2(l_sum) are base-2 quantities; dividing by log2(e) converts
    # the log-sum-exp to natural-log units.
    tl.store(l_ptrs, (m + tl.math.log2(l_sum)) / 1.44269504)
else:
_fwd_kernel_splitK = None
_splitK_reduce = None
@register_operator
class FwOp(AttentionFwOpBase):
    """Flash-Attention with Split-K. Supports fused int-4 K/V quantization.

    Quantized path will be taken if input K/V have type int32.

    Quantization can be row-wise or group-wise (when cls.NUM_GROUPS > 1) along
    the last dimension of K and V. Currently 1, 2, 4, or 8 groups per row are supported.
    Quantization coefficients (scale and shift) are represented as two
    float16 constants per group, packed into int32. Quantization coefficients of
    all groups are placed at the beginning of the row. So, if unquantized K/V have head
    dimension D, the quantized versions have head dimension D // 8 + NUM_GROUPS
    and dtype int32.

    Pseudocode for dequantizing one row can look like:

        group_size = D // 8 // NUM_GROUPS  # packed int32 elements per group
        for i in range(NUM_GROUPS):
            # coefficients of group i live in the i-th element of the row
            scale, shift = unpack_int32_into_float16x2(K[..., i])
            group_start = NUM_GROUPS + i * group_size
            group_quant = K[..., group_start: group_start + group_size]
            # every int32 holds eight 4-bit values
            group_dequant = unpack_int4x8(group_quant) * scale + shift
        ...
    """

    OPERATOR = _fwd_kernel_splitK
    SUPPORTED_DEVICES = {"cuda"}
    CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
    SUPPORTED_DTYPES = {
        torch.half,
        torch.bfloat16,
    }  # Those are dtypes of Q. In the quantized case K/V has dtype int32
    SUPPORTED_MAX_K = 128
    SUPPORTED_ATTN_BIAS_TYPES: Set[Any] = {
        type(None),
        BlockDiagonalCausalWithOffsetPaddedKeysMask,
    }
    SUPPORTS_DROPOUT = False
    SUPPORTS_CUSTOM_SCALE = True
    SUPPORTS_BMGHK = True
    NAME = "triton_splitKF"

    # Number of key/value splits; None means "use the get_split_k heuristic".
    SPLIT_K: Optional[int] = None
    # Tile sizes passed to the forward kernel.
    BLOCK_M = 16
    BLOCK_N = 64

    NUM_GROUPS = 1  # Default quantization is row-wise

    @classmethod
    def shape_not_supported_reasons(
        cls, Mq: int, Mkv: int, K: int, Kv: int
    ) -> List[str]:
        """Return the base-class shape objections plus one if K is not 16, 32, 64 or 128."""
        reasons = super().shape_not_supported_reasons(Mq, Mkv, K, Kv)
        if K not in {16, 32, 64, 128}:
            reasons.append(f"Embed dim {K} not supported")
        return reasons

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        """Collect human-readable reasons why this operator cannot run on ``d``.

        On top of the base-class checks this verifies last-dimension
        alignment (skipped for K/V when they are int32, i.e. quantized),
        Triton availability, the CUDA compute capability, and the query
        length restrictions that apply with multiquery inputs or with an
        attention bias.

        Returns:
            A list of reasons; an empty list means the inputs are supported.
        """
        reasons = super(FwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        if d.key.dtype != torch.int32:
            check_lastdim_alignment_stride1(reasons, "key", d.key, 8)
            check_lastdim_alignment_stride1(reasons, "value", d.value, 8)
        if cls.OPERATOR is None:
            reasons.append("triton is not available")
        if d.device.type == "cuda":
            # Has only been tested on 8.0 / 9.0.
            if torch.cuda.get_device_capability(d.device) < (8, 0):
                reasons.append(
                    "requires GPU with sm80 minimum compute capacity, e.g., A100/H100/L4"
                )

        q_len = d.query.shape[1]
        if isinstance(d.attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask):
            seqinfo = d.attn_bias.q_seqinfo
            if q_len != seqinfo.seqstart_py[-1]:
                reasons.append(
                    f"Expected total {seqinfo.seqstart_py[-1]} queries not {q_len}"
                )
            # With a bias, the relevant query length is the per-sequence one.
            q_len = seqinfo.min_seqlen
            if q_len != seqinfo.max_seqlen:
                reasons.append(
                    "Variable query len is not supported in the presence of causal mask."
                )

        if d.key.ndim in [4, 5] and d.key.shape[-2] != 1:
            # A zero stride along the head dimension means the heads share
            # the same K/V data (multiquery).
            if d.key.stride(-2) == 0 and d.value.stride(-2) == 0 and q_len > 1:
                reasons.append("multiquery is only supported with query seqlen=1")

        if d.attn_bias is not None and q_len > 1:
            reasons.append(
                "query with seqlen > 1 is not supported in the presence of causal mask"
            )
        return reasons

    @classmethod
    def get_split_k(cls, B: int, H: int, Mk: int) -> int:
        """Heuristic for the number of splits.

        Args:
            B: batch size.
            H: number of heads.
            Mk: key/value sequence length.

        Returns:
            A split count between 1 and 64 (inclusive).
        """
        # Clamp to 1 so that an empty batch (B * H == 0) yields a valid split
        # count instead of raising ZeroDivisionError.
        bh = max(B * H, 1)
        split_k = max(Mk, 1024) // bh
        max_chunk_size = 64 if Mk <= 512 and bh <= 64 else 128
        # Halve the split count until every split covers at least
        # max_chunk_size keys (or no splits are left).
        while split_k > 0 and Mk / split_k < max_chunk_size:
            split_k = split_k // 2
        split_k = min(split_k, 64)
        split_k = max(split_k, 1)
        return split_k

    @classmethod
    def apply(
        cls, inp: Inputs, needs_gradient: bool
    ) -> Tuple[torch.Tensor, Optional[Context]]:
        """Run split-K attention forward.

        Launches the unrolled ``_fwd_kernel_splitK`` over a grid of
        (query blocks, B * G * H, split_k) to produce per-split partial
        results, then ``_splitK_reduce`` over (B * G * H, M, 1) to merge them.

        Args:
            inp: attention inputs; K/V of dtype int32 select the quantized path.
            needs_gradient: not read by this implementation.

        Returns:
            ``(out, Context(out=out, lse=lse))`` where ``lse`` is the
            log-sum-exp written by the reduction kernel.
        """
        attn_bias = inp.attn_bias
        seq_len = None
        q, k, v = inp.get_qkv_in_bmghk()

        if attn_bias is not None:
            assert isinstance(attn_bias, BlockDiagonalCausalWithOffsetPaddedKeysMask)
            # TODO: do we really need to do this cast? seems fishy but
            # I just copied it from the decoder.py
            attn_bias.k_seqinfo.to(inp.query.device)
            attn_bias.q_seqinfo.to(inp.query.device)
            seq_len = attn_bias.k_seqinfo.seqlen
            B = len(seq_len)
            G, H, Kq = q.shape[-3:]
            Kkv = v.shape[-1]

            # assume kv has been padded
            q = q.reshape(B, -1, G, H, Kq)
            k = k.reshape(B, -1, G, H, Kkv)
            v = v.reshape(B, -1, G, H, Kkv)

        # Transpose in the case of MQA/GQA
        mqa_swap_seqlen_head = False
        if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0:
            # All heads share K/V: treat the query heads as the query
            # sequence dimension and keep a single K/V head.
            mqa_swap_seqlen_head = True
            assert q.shape[1] == 1
            q = q.transpose(1, 3)
            k = k[:, :, :, :1]
            v = v[:, :, :, :1]

        if k.dtype == torch.int32:
            # Quantized K/V
            PACKED_PER_VAL = 8
            # Drop the NUM_GROUPS coefficient slots, then 8 values per int32.
            Lk = (k.shape[-1] - cls.NUM_GROUPS) * 8
        else:
            Lk = k.shape[-1]
            PACKED_PER_VAL = 1

        B, Mk, G, H, Kkv = k.shape
        B, M, G, H, Kq = q.shape
        assert Lk == Kq, f"Keys have head dim {Lk} but queries have head dim {Kq}"

        BLOCK_M = cls.BLOCK_M
        BLOCK_N = cls.BLOCK_N
        if cls.SPLIT_K is not None:
            split_k = cls.SPLIT_K
        else:
            # Use heuristics
            split_k = cls.get_split_k(B, H, Mk)

        # Round the query length up to a multiple of BLOCK_M; the forward
        # kernel writes metadata for full BLOCK_M-row blocks.
        M_ceil = (M + BLOCK_M - 1) // BLOCK_M * BLOCK_M
        o_splitk = torch.empty(
            [B * G * H, split_k, M_ceil, Kq], dtype=torch.float32, device=q.device
        )
        metadata = torch.empty(
            [B * G * H, 2, split_k, M_ceil], dtype=torch.float32, device=q.device
        )
        lse = torch.empty((B * G * H, M), device=q.device, dtype=torch.float32)
        grid = (triton.cdiv(M, BLOCK_M), B * G * H, split_k)

        num_warps = 2
        # Number of key positions handled by each split.
        split_size = (Mk + split_k - 1) // split_k
        use_seq_len = seq_len is not None
        _fwd_kernel_splitK_unrolled = unroll_varargs(
            _fwd_kernel_splitK, N=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1
        )

        _fwd_kernel_splitK_unrolled[grid](
            Q=q,
            K=k,
            V=v,
            sm_scale=inp.scale_float,
            Out_splitK=o_splitk,
            Metadata=metadata,
            Seq_len=seq_len,
            **_strides(q, "qz", "qm", "qg", "qh", "qk"),
            **_strides(k, "kz", "kn", "kg", "kh", "kk"),
            **_strides(v, "vz", "vn", "vg", "vh", "vk"),
            **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
            **_strides(metadata, "mzhg", "m2", "ms", "mm"),
            Z=B,
            H=H,
            G=G,
            N_CTX_Q=M,
            N_CTX_K=Mk,
            BLOCK_N_PER_SPLIT=split_size,
            BLOCK_M=BLOCK_M,
            BLOCK_N=BLOCK_N,
            BLOCK_DMODEL=Lk,
            BOUNDS_CHECKS_N=(split_size % BLOCK_N) > 0 or use_seq_len,
            USE_SEQ_LEN=use_seq_len,
            num_warps=num_warps,
            num_stages=1,
            PACKED_PER_VAL=PACKED_PER_VAL,
            N_GROUPS=cls.NUM_GROUPS if PACKED_PER_VAL > 1 else 1,
        )

        if mqa_swap_seqlen_head:
            # Allocate in the layout of the final (un-swapped) output and
            # view it transposed, so the swap back below is a plain view.
            out = torch.empty(
                (B, H, G, M, Kq), device=q.device, dtype=q.dtype
            ).transpose(1, 3)
        else:
            out = torch.empty((B, M, G, H, Kq), device=q.device, dtype=q.dtype)

        # Merge together
        grid = (B * G * H, M, 1)
        _splitK_reduce[grid](
            o_splitk,
            metadata,
            out,
            lse,
            split_k=split_k,
            **_strides(o_splitk, "osk_zhg", "osk_s", "osk_m", "osk_k"),
            **_strides(metadata, "mzhg", "m2", "ms", "mm"),
            **_strides(out, "oz", "om", "og", "oh", "ok"),
            **_strides(lse, "lse_zhg", "lse_m"),
            BLOCK_SIZE=out.shape[-1],
            G=G,
            H=H,
            # TODO: Tune num_warps
        )
        lse = lse.reshape([B, G, H, M])
        if mqa_swap_seqlen_head:
            # H/M dimensions have been swapped
            out = out.transpose(1, 3)
            lse = lse.transpose(2, 3)
        if inp.query.ndim == 4:
            # BMGHK -> BMHK
            assert G == 1
            out = out[:, :, 0]
            lse = lse[:, 0]

        return out, Context(out=out, lse=lse)
class FwOp_S1(FwOp):
    """FwOp with the number of key/value splits fixed to 1 instead of the heuristic."""

    SPLIT_K = 1
    NAME = "triton_splitK1"
class FwOp_S2(FwOp):
    """FwOp with the number of key/value splits fixed to 2 instead of the heuristic."""

    SPLIT_K = 2
    NAME = "triton_splitK2"
class FwOp_S4(FwOp):
    """FwOp with the number of key/value splits fixed to 4 instead of the heuristic."""

    SPLIT_K = 4
    NAME = "triton_splitK4"
class FwOp_S8(FwOp):
    """FwOp with the number of key/value splits fixed to 8 instead of the heuristic."""

    SPLIT_K = 8
    NAME = "triton_splitK8"
class FwOp_S16(FwOp):
    """FwOp with the number of key/value splits fixed to 16 instead of the heuristic."""

    SPLIT_K = 16
    NAME = "triton_splitK16"
class FwOp_S32(FwOp):
    """FwOp with the number of key/value splits fixed to 32 instead of the heuristic."""

    SPLIT_K = 32
    NAME = "triton_splitK32"
class FwOp_S64(FwOp):
    """FwOp with the number of key/value splits fixed to 64 instead of the heuristic."""

    SPLIT_K = 64
    NAME = "triton_splitK64"
class FwOp_S128(FwOp):
    """FwOp with the number of key/value splits fixed to 128 instead of the heuristic."""

    SPLIT_K = 128
    NAME = "triton_splitK128"