# Files
# 2026-01-19 10:38:50 +08:00
#
# 1352 lines
# 37 KiB
# Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import torch
import torch.nn as nn
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv, next_power_of_2
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .cumsum import chunk_local_cumsum
from .fused_recurrent import fused_recurrent_gated_delta_rule_fwd_kernel
from .index import prepare_chunk_indices
from .l2norm import l2norm_fwd
from .op import exp, log
from .solve_tril import solve_tril
from .utils import is_amd
# Autotuning search spaces for kda_gate_fwd_kernel: candidate row-block sizes
# (BT) and warp counts. The warp candidates depend on `is_amd` (from .utils).
BT_LIST_AUTOTUNE = [32, 64, 128]
if is_amd:
    NUM_WARPS_AUTOTUNE = [2, 4, 8, 16]
else:
    NUM_WARPS_AUTOTUNE = [4, 8, 16, 32]
def fused_recurrent_kda_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused recurrent gated delta-rule kernel in KDA mode.

    Args:
        q: Query tensor; expected to match ``k``'s shape.
        k: Key tensor of shape ``[B, T, H, K]``.
        v: Value tensor of shape ``[B, T, HV, V]``.
        g: Gate tensor, forwarded to the kernel with ``IS_KDA=True``.
        beta: Beta tensor. ``IS_BETA_HEADWISE`` is set when it has as many
            dims as ``v``.
        scale: Query scale, forwarded to the kernel.
        initial_state: Recurrent state buffer; its ``stride(0)`` is passed
            as ``stride_init_state_token``.
        inplace_final_state: If True, the final state is written into
            ``initial_state``; otherwise a new ``[T, HV, K, V]`` buffer with
            ``initial_state``'s dtype is allocated.
        cu_seqlens: Optional cumulative sequence lengths; the number of
            sequences is ``len(cu_seqlens) - 1`` (``B`` otherwise).
        ssm_state_indices: Optional 1-D or 2-D state-slot index tensor; its
            strides are forwarded as ``stride_indices_seq``/``_tok``.
        num_accepted_tokens: Optional, forwarded unchanged to the kernel.
        use_qk_l2norm_in_kernel: Forwarded as ``USE_QK_L2NORM_IN_KERNEL``.

    Returns:
        ``(o, final_state)``; ``o`` has the shape of ``v`` and ``k``'s dtype.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = next_power_of_2(K), min(next_power_of_2(V), 8)
    NK, NV = cdiv(K, BK), cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1

    # The grid tiles over value heads (N * HV) and the value dim (NV tiles of
    # BV), so the output holds one V-vector per (token, value head): the shape
    # of `v`. `torch.empty_like(k)` was only right when HV == H and V == K.
    # k's dtype/device are kept as before.
    o = k.new_empty(v.shape)
    if inplace_final_state:
        final_state = initial_state
    else:
        final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)

    stride_init_state_token = initial_state.stride(0)
    stride_final_state_token = final_state.stride(0)

    if ssm_state_indices is None:
        stride_indices_seq, stride_indices_tok = 1, 1
    elif ssm_state_indices.ndim == 1:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    else:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()

    grid = (NK, NV, N * HV)
    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=num_accepted_tokens,
        scale=scale,
        N=N,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        stride_init_state_token=stride_init_state_token,
        stride_final_state_token=stride_final_state_token,
        stride_indices_seq=stride_indices_seq,
        stride_indices_tok=stride_indices_tok,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        INPLACE_FINAL_STATE=inplace_final_state,
        IS_KDA=True,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return o, final_state
def fused_recurrent_kda(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor = None,
scale: float = None,
initial_state: torch.Tensor = None,
inplace_final_state: bool = True,
use_qk_l2norm_in_kernel: bool = True,
cu_seqlens: torch.LongTensor | None = None,
ssm_state_indices: torch.LongTensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
if cu_seqlens is not None and q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing."
)
if scale is None:
scale = k.shape[-1] ** -0.5
o, final_state = fused_recurrent_kda_fwd(
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
g=g.contiguous(),
beta=beta.contiguous(),
scale=scale,
initial_state=initial_state,
inplace_final_state=inplace_final_state,
cu_seqlens=cu_seqlens,
ssm_state_indices=ssm_state_indices,
num_accepted_tokens=None,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)
return o, final_state
# Row-tiled gated LayerNorm/RMSNorm forward. Each program handles BT rows of a
# row-major [T, D] input (row stride D) and all D columns in one BD-wide tile
# (the launching wrapper guarantees BD >= D). Output: norm(x [+ residual]),
# optionally scaled by w and shifted by b, then multiplied by an activation
# of the gate g.
@triton.heuristics(
    {
        "STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
        "HAS_RESIDUAL": lambda args: args["residual"] is not None,
        "HAS_WEIGHT": lambda args: args["w"] is not None,
        "HAS_BIAS": lambda args: args["b"] is not None,
    }
)
@triton.jit
def layer_norm_gated_fwd_kernel(
    x,  # pointer to the input
    g,  # pointer to the gate
    y,  # pointer to the output
    w,  # pointer to the weights
    b,  # pointer to the biases
    residual,  # pointer to the residual
    residual_out,  # pointer to the output buffer for the pre-norm sum
    mean,  # pointer to the per-row mean (LayerNorm only)
    rstd,  # pointer to the per-row 1/std
    eps,  # epsilon to avoid division by zero
    T,  # number of rows in x
    D: tl.constexpr,  # number of columns in x
    BT: tl.constexpr,  # rows per program
    BD: tl.constexpr,  # column tile width
    ACTIVATION: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    i_t = tl.program_id(0)
    o_d = tl.arange(0, BD)
    m_d = o_d < D
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    if HAS_RESIDUAL:
        p_res = tl.make_block_ptr(
            residual, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)
        )
        b_x += tl.load(p_res, boundary_check=(0, 1)).to(tl.float32)
    if STORE_RESIDUAL_OUT:
        # Persist the (possibly residual-augmented) input before normalizing.
        p_res_out = tl.make_block_ptr(
            residual_out, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)
        )
        tl.store(p_res_out, b_x.to(p_res_out.dtype.element_ty), boundary_check=(0, 1))
    if not IS_RMS_NORM:
        b_mean = tl.sum(b_x, axis=1) / D
        p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t * BT,), (BT,), (0,))
        tl.store(p_mean, b_mean.to(p_mean.dtype.element_ty), boundary_check=(0,))
        # Mask padded columns (>= D) so they do not contribute to the variance.
        b_xbar = tl.where(m_d[None, :], b_x - b_mean[:, None], 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
    else:
        b_xbar = tl.where(m_d[None, :], b_x, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
    b_rstd = 1 / tl.sqrt(b_var + eps)

    p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,))
    tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,))

    if HAS_WEIGHT:
        b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
    if HAS_BIAS:
        b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
    b_x_hat = (
        (b_x - b_mean[:, None]) * b_rstd[:, None]
        if not IS_RMS_NORM
        else b_x * b_rstd[:, None]
    )
    b_y = b_x_hat * b_w[None, :] if HAS_WEIGHT else b_x_hat
    if HAS_BIAS:
        b_y = b_y + b_b[None, :]

    # swish/sigmoid output gate; any other ACTIVATION leaves b_y ungated
    p_g = tl.make_block_ptr(g, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
    if ACTIVATION == "swish" or ACTIVATION == "silu":
        b_y = b_y * b_g * tl.sigmoid(b_g)
    elif ACTIVATION == "sigmoid":
        b_y = b_y * tl.sigmoid(b_g)

    # Write output
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
# One-row-per-program gated LayerNorm/RMSNorm forward (the wrapper uses it
# for D > 512 with grid (T,)). Program i_t normalizes row i_t of a row-major
# [T, D] input using a single BD-wide tile (the wrapper guarantees BD >= D).
@triton.heuristics(
    {
        "STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
        "HAS_RESIDUAL": lambda args: args["residual"] is not None,
        "HAS_WEIGHT": lambda args: args["w"] is not None,
        "HAS_BIAS": lambda args: args["b"] is not None,
    }
)
@triton.jit
def layer_norm_gated_fwd_kernel1(
    x,  # pointer to the input
    g,  # pointer to the gate
    y,  # pointer to the output
    w,  # pointer to the weights
    b,  # pointer to the biases
    residual,  # pointer to the residual
    residual_out,  # pointer to the output buffer for the pre-norm sum
    mean,  # pointer to the per-row mean (LayerNorm only)
    rstd,  # pointer to the per-row 1/std
    eps,  # epsilon to avoid division by zero
    D: tl.constexpr,  # number of columns in x
    BD: tl.constexpr,  # column tile width
    ACTIVATION: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    i_t = tl.program_id(0)
    # Advance every row-major [T, D] pointer to row i_t.
    x += i_t * D
    y += i_t * D
    g += i_t * D
    if HAS_RESIDUAL:
        residual += i_t * D
    if STORE_RESIDUAL_OUT:
        residual_out += i_t * D

    o_d = tl.arange(0, BD)
    m_d = o_d < D
    b_x = tl.load(x + o_d, mask=m_d, other=0.0).to(tl.float32)
    if HAS_RESIDUAL:
        b_x += tl.load(residual + o_d, mask=m_d, other=0.0).to(tl.float32)
    if STORE_RESIDUAL_OUT:
        tl.store(residual_out + o_d, b_x, mask=m_d)
    if not IS_RMS_NORM:
        b_mean = tl.sum(b_x, axis=0) / D
        tl.store(mean + i_t, b_mean)
        # Mask padded columns (>= D) so they do not contribute to the variance.
        b_xbar = tl.where(m_d, b_x - b_mean, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
    else:
        b_xbar = tl.where(m_d, b_x, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
    b_rstd = 1 / tl.sqrt(b_var + eps)
    tl.store(rstd + i_t, b_rstd)

    if HAS_WEIGHT:
        b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
    if HAS_BIAS:
        b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
    b_x_hat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd
    b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat
    if HAS_BIAS:
        b_y = b_y + b_b

    # swish/sigmoid output gate; any other ACTIVATION leaves b_y ungated
    b_g = tl.load(g + o_d, mask=m_d, other=0.0).to(tl.float32)
    if ACTIVATION == "swish" or ACTIVATION == "silu":
        b_y = b_y * b_g * tl.sigmoid(b_g)
    elif ACTIVATION == "sigmoid":
        b_y = b_y * tl.sigmoid(b_g)

    # Write output
    tl.store(y + o_d, b_y, mask=m_d)
def layer_norm_gated_fwd(
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    activation: str = "swish",
    eps: float = 1e-5,
    residual: torch.Tensor = None,
    out_dtype: torch.dtype = None,
    residual_dtype: torch.dtype = None,
    is_rms_norm: bool = False,
):
    """Run the gated LayerNorm/RMSNorm forward on a 2-D ``[T, D]`` input.

    Args:
        x: Input of shape ``[T, D]``.
        g: Gate, read with the same row-major ``[T, D]`` layout as ``x``.
        weight, bias: Optional ``[D]`` affine parameters.
        activation: ``"swish"``/``"silu"`` or ``"sigmoid"``; any other value
            leaves the output ungated.
        eps: Added to the variance before the reciprocal square root.
        residual: Optional ``[T, D]`` tensor added to ``x`` before
            normalizing; its dtype overrides ``residual_dtype``.
        out_dtype: If None, the output is written INTO ``x`` (no new
            buffer); otherwise a new tensor of this dtype is allocated.
        residual_dtype: If set and different from ``x.dtype``, the pre-norm
            sum is stored in a new buffer of this dtype.
        is_rms_norm: Skip mean subtraction and do not allocate ``mean``.

    Returns:
        ``(y, mean, rstd, residual_out)``. ``mean`` is None for RMSNorm.
        When no residual buffer is allocated the last element is ``x``
        itself; with ``out_dtype=None`` that tensor then holds ``y``.

    Raises:
        RuntimeError: if a row of ``x`` exceeds 64 KiB.
    """
    if residual is not None:
        residual_dtype = residual.dtype
    num_rows, num_cols = x.shape
    if residual is not None:
        assert residual.shape == (num_rows, num_cols)
    if weight is not None:
        assert weight.shape == (num_cols,)
    if bias is not None:
        assert bias.shape == (num_cols,)

    # Output buffer: reuse x unless a specific dtype was requested.
    y = torch.empty_like(x, dtype=out_dtype) if out_dtype is not None else x

    keep_residual = residual is not None or (
        residual_dtype is not None and residual_dtype != x.dtype
    )
    residual_out = (
        torch.empty(num_rows, num_cols, device=x.device, dtype=residual_dtype)
        if keep_residual
        else None
    )
    mean = (
        None
        if is_rms_norm
        else torch.empty((num_rows,), dtype=torch.float, device=x.device)
    )
    rstd = torch.empty((num_rows,), dtype=torch.float, device=x.device)

    # A whole row must fit in one tile; cap the tile at 64 KiB of input.
    tile_cap = 65536 // x.element_size()
    block_d = min(tile_cap, next_power_of_2(num_cols))
    if num_cols > block_d:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    launch_kwargs = dict(
        x=x,
        g=g,
        y=y,
        w=weight,
        b=bias,
        residual=residual,
        residual_out=residual_out,
        mean=mean,
        rstd=rstd,
        eps=eps,
        D=num_cols,
        BD=block_d,
        ACTIVATION=activation,
        IS_RMS_NORM=is_rms_norm,
        num_warps=4,
    )
    # Narrow rows: tile several rows per program; wide rows: one row each.
    if num_cols <= 512:
        rows_per_program = 32
        layer_norm_gated_fwd_kernel[(cdiv(num_rows, rows_per_program),)](
            T=num_rows, BT=rows_per_program, **launch_kwargs
        )
    else:
        layer_norm_gated_fwd_kernel1[(num_rows,)](**launch_kwargs)

    return y, mean, rstd, (x if residual_out is None else residual_out)
def rms_norm_gated(
x: torch.Tensor,
g: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
activation: str = "swish",
residual: torch.Tensor | None = None,
prenorm: bool = False,
residual_in_fp32: bool = False,
eps: float = 1e-6,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.contiguous().reshape(-1, x.shape[-1])
g = g.contiguous().reshape(-1, g.shape[-1])
if residual is not None:
assert residual.shape == x_shape_og
residual = residual.contiguous().reshape(-1, residual.shape[-1])
residual_dtype = (
residual.dtype
if residual is not None
else (torch.float if residual_in_fp32 else None)
)
y, _, _, residual_out = layer_norm_gated_fwd(
x=x,
g=g,
weight=weight,
bias=bias,
activation=activation,
eps=eps,
residual=residual,
residual_dtype=residual_dtype,
is_rms_norm=True,
)
y = y.reshape(x_shape_og)
return y if not prenorm else (y, residual_out.reshape(x_shape_og))
class FusedRMSNormGated(nn.Module):
def __init__(
self,
hidden_size: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
activation: str = "swish",
device: torch.device | None = None,
dtype: torch.dtype | None = None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.hidden_size = hidden_size
self.elementwise_affine = elementwise_affine
self.eps = eps
self.activation = activation
if self.activation not in ["swish", "silu", "sigmoid"]:
raise ValueError(f"Unsupported activation: {self.activation}")
if elementwise_affine:
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
def forward(
self,
x: torch.Tensor,
g: torch.Tensor,
residual: torch.Tensor | None = None,
prenorm: bool = False,
residual_in_fp32: bool = False,
) -> torch.Tensor:
return rms_norm_gated(
x,
g,
self.weight,
self.bias,
self.activation,
residual=residual,
eps=self.eps,
prenorm=prenorm,
residual_in_fp32=residual_in_fp32,
)
# Off-diagonal sub-blocks of the per-chunk matrices. The chunk of BT rows is
# split into NC sub-blocks of BC rows. For each pair (i_i, i_j) with
# i_i > i_j (strictly below the diagonal), this kernel computes
#   A[i, j]   = beta_i * sum_k k_i[k] * k_j[k] * exp(g_i[k] - g_j[k])
#   Aqk[i, j] = scale  * sum_k q_i[k] * k_j[k] * exp(g_i[k] - g_j[k])
# The exponent is split through the gate of the sub-block's first row (b_gn),
# so both sides can be formed separately and combined with tl.dot.
# Grid: (chunk, i_i * NC + i_j, batch * head). q/k/g are addressed as a dense
# [B*T, H, K] layout; A/Aqk as [B*T, H, BT].
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["BC"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter(
    q,
    k,
    g,
    beta,
    A,
    Aqk,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_i, i_j = i_c // NC, i_c % NC
    if IS_VARLEN:
        # chunk_indices holds (sequence id, chunk id within sequence) pairs.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Row sub-block starts past the sequence end: nothing to do.
    if i_t * BT + i_i * BC >= T:
        return
    # Only strictly-lower sub-blocks here; the diagonal ones are handled by
    # the *_intra_sub_intra kernel, and upper ones stay zero.
    if i_i <= i_j:
        return

    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    g += (bos * H + i_h) * K
    A += (bos * H + i_h) * BT
    Aqk += (bos * H + i_h) * BT

    p_b = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,)
    )
    b_b = tl.load(p_b, boundary_check=(0,))

    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(
            q, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
        )
        p_k = tl.make_block_ptr(
            k, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
        )
        p_g = tl.make_block_ptr(
            g, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
        )
        # NOTE: despite the b_ prefix this is a block pointer (transposed view
        # of the column sub-block's keys); it is replaced by its loaded value below.
        b_kt = tl.make_block_ptr(
            k, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
        )
        p_gk = tl.make_block_ptr(
            g, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
        )

        o_k = i_k * BK + tl.arange(0, BK)
        m_k = o_k < K
        # [BK,] gate at the first row of the row sub-block (reference point)
        b_gn = tl.load(g + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)
        # [BC, BK]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1)) * exp(b_g - b_gn[None, :])
        # [BK, BC]
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kt = tl.load(b_kt, boundary_check=(0, 1))
        # [BK, BC]
        b_ktg = b_kt * exp(b_gn[:, None] - b_gk)
        # [BC, BC]
        b_A += tl.dot(b_k, b_ktg)

        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_qg = b_q * exp(b_g - b_gn[None, :]) * scale
        b_Aqk += tl.dot(b_qg, b_ktg)

    # Row-wise beta scaling of A (applied once, after the K reduction).
    b_A *= b_b[:, None]

    p_A = tl.make_block_ptr(
        A, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)
    )
    tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
    p_Aqk = tl.make_block_ptr(
        Aqk, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)
    )
    tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1))
# Diagonal (i_i == i_j) BC x BC sub-blocks of the per-chunk matrices, one
# column j at a time:
#   A[:, j]   = beta_i * sum_k k_i[k] * k_j[k] * exp(g_i[k] - g_j[k]), rows i > j
#   Aqk[:, j] = scale  * sum_k q_i[k] * k_j[k] * exp(g_i[k] - g_j[k]), rows i >= j
# Other entries of the column are written as 0. The whole K dimension is
# loaded in one BK-wide tile (the wrapper sets BK >= K).
# Grid: (chunk, sub-block, batch * head).
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
    key=["BK", "BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra(
    q,
    k,
    g,
    beta,
    A,
    Aqk,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices holds (sequence id, chunk id within sequence) pairs.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if i_t * BT + i_i * BC >= T:
        return

    o_i = tl.arange(0, BC)
    o_k = tl.arange(0, BK)
    m_k = o_k < K
    # Rows of this sub-block that lie inside the sequence.
    m_A = (i_t * BT + i_i * BC + o_i) < T
    # Flat offsets of (row, column i_i * BC) in the [B*T, H, BT] outputs.
    o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC

    p_q = tl.make_block_ptr(
        q + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    p_k = tl.make_block_ptr(
        k + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    p_g = tl.make_block_ptr(
        g + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_g = tl.load(p_g, boundary_check=(0, 1))

    # Fold beta into the row keys once, up front.
    p_b = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h
    b_k = b_k * tl.load(p_b, mask=m_A, other=0)[:, None]

    # Pointers to the key/gate vector of column j; advanced one token per step.
    p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
    p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k

    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
        b_ktg = b_kt[None, :] * exp(b_g - b_gk[None, :])
        b_A = tl.sum(b_k * b_ktg, 1)
        b_A = tl.where(o_i > j, b_A, 0.0)
        b_Aqk = tl.sum(b_q * b_ktg, 1)
        b_Aqk = tl.where(o_i >= j, b_Aqk * scale, 0.0)
        tl.store(A + o_A + j, b_A, mask=m_A)
        tl.store(Aqk + o_A + j, b_Aqk, mask=m_A)
        p_kt += H * K
        p_gk += H * K
def chunk_kda_scaled_dot_kkt_fwd(
q: torch.Tensor,
k: torch.Tensor,
gk: torch.Tensor | None = None,
beta: torch.Tensor | None = None,
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
r"""
Compute beta * K * K^T.
Args:
k (torch.Tensor):
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
gk (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float32`
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
B, T, H, K = k.shape
assert K <= 256
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BC = min(16, BT)
NC = cdiv(BT, BC)
BK = max(next_power_of_2(K), 16)
A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
Aqk = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
grid = (NT, NC * NC, B * H)
chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
q=q,
k=k,
g=gk,
beta=beta,
A=A,
Aqk=Aqk,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BC=BC,
NC=NC,
)
grid = (NT, NC, B * H)
chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
q=q,
k=k,
g=gk,
beta=beta,
A=A,
Aqk=Aqk,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
BT=BT,
BC=BC,
BK=BK,
)
return A, Aqk
# Per-chunk recomputation (one program per (chunk, batch * head)):
#   u  = A @ (beta * v)
#   w  = A @ (beta * k * exp(gk))
#   qg = q * exp(gk)                 (only if STORE_QG)
#   kg = k * exp(gk_last - gk)       (only if STORE_KG; gk_last = gk at the
#                                     chunk's last valid row)
# A is read as one [BT, BT] tile per chunk from a [B*T, H, BT] layout.
# NOTE(review): gk is dereferenced unconditionally below, so callers must
# pass a non-None gk even though STORE_KG is optional.
@triton.heuristics(
    {
        "STORE_QG": lambda args: args["qg"] is not None,
        "STORE_KG": lambda args: args["kg"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    q,
    k,
    qg,
    kg,
    v,
    beta,
    w,
    u,
    A,
    gk,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    STORE_QG: tl.constexpr,
    STORE_KG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices holds (sequence id, chunk id within sequence) pairs.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    b_b = tl.load(p_b, boundary_check=(0,))

    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    b_A = tl.load(p_A, boundary_check=(0, 1))

    # u = A @ (beta * v), tiled over V.
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_u = tl.make_block_ptr(
            u + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_b[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, input_precision=DOT_PRECISION)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

    # w = A @ (beta * k * exp(gk)), tiled over K; qg/kg produced alongside.
    for i_k in range(tl.cdiv(K, BK)):
        p_w = tl.make_block_ptr(
            w + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_k = tl.make_block_ptr(
            k + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_b[:, None]

        p_gk = tl.make_block_ptr(
            gk + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kb *= exp(b_gk)
        if STORE_QG:
            p_q = tl.make_block_ptr(
                q + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            p_qg = tl.make_block_ptr(
                qg + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_qg = b_q * exp(b_gk)
            tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1))
        if STORE_KG:
            # Gate at the last valid row of this chunk (handles a short tail chunk).
            last_idx = min(i_t * BT + BT, T) - 1

            o_k = i_k * BK + tl.arange(0, BK)
            m_k = o_k < K
            b_gn = tl.load(
                gk + ((bos + last_idx) * H + i_h) * K + o_k, mask=m_k, other=0.0
            )
            b_kg = b_k * exp(b_gn - b_gk)

            p_kg = tl.make_block_ptr(
                kg + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1))

        b_w = tl.dot(b_A, b_kb.to(b_k.dtype))
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
def recompute_w_u_fwd(
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
A: torch.Tensor,
q: torch.Tensor | None = None,
gk: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = A.shape[-1]
BK = 64
BV = 64
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
w = torch.empty_like(k)
u = torch.empty_like(v)
kg = torch.empty_like(k) if gk is not None else None
recompute_w_u_fwd_kernel[(NT, B * H)](
q=q,
k=k,
qg=None,
kg=kg,
v=v,
beta=beta,
w=w,
u=u,
A=A,
gk=gk,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
DOT_PRECISION="ieee",
)
return w, u, None, kg
# Chunked gated-linear-attention output. One program per (V tile, chunk,
# batch * head):
#   o = (q * scale * exp(g)) @ h[chunk] + tril(A) @ v
# tril keeps the lower triangle including the diagonal (mask m_s). h is read
# as one [K, V] block per (global chunk index i_tg, head).
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for BV in [64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gla_fwd_kernel_o(
    q,
    v,
    g,
    h,
    o,
    A,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Global chunk index (for h) before i_t is remapped to the local one.
        i_tg = i_t
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # Causal (lower-triangular, diagonal included) mask within the chunk.
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(
            q + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_g = tl.make_block_ptr(
            g + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_h = tl.make_block_ptr(
            h + (i_tg * H + i_h) * K * V,
            (K, V),
            (V, 1),
            (i_k * BK, i_v * BV),
            (BK, BV),
            (1, 0),
        )

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BK]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        # [BT, BK]
        b_qg = (b_q * exp(b_g)).to(b_q.dtype)
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # works but dkw, owing to divine benevolence
        # NOTE(review): `i_k >= 0` is always true; the guard looks like a
        # compiler workaround kept from upstream. Do not remove it without testing.
        # [BT, BV]
        if i_k >= 0:
            b_o += tl.dot(b_qg, b_h.to(b_qg.dtype))
    p_v = tl.make_block_ptr(
        v + (bos * H + i_h) * V,
        (T, V),
        (H * V, 1),
        (i_t * BT, i_v * BV),
        (BT, BV),
        (1, 0),
    )
    p_o = tl.make_block_ptr(
        o + (bos * H + i_h) * V,
        (T, V),
        (H * V, 1),
        (i_t * BT, i_v * BV),
        (BT, BV),
        (1, 0),
    )
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    # [BT, BT]
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_A = tl.where(m_s, b_A, 0.0).to(b_v.dtype)
    b_o += tl.dot(b_A, b_v, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
def chunk_gla_fwd_o_gk(
    q: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    A: torch.Tensor,
    h: torch.Tensor,
    o: torch.Tensor,
    scale: float,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
):
    """Write the chunked gated-linear-attention output into ``o``.

    Launches ``chunk_gla_fwd_kernel_o`` over (V tiles, chunks, batch * heads).
    That kernel adds an inter-chunk term ``(q * scale * exp(g)) @ h`` and an
    intra-chunk term ``tril(A) @ v``.

    Args:
        q, g: ``[B, T, H, K]``.
        v: ``[B, T, H, V]``.
        A: Per-chunk ``[B, T, H, chunk_size]`` intra-chunk scores.
        h: Per-chunk states, read as ``[K, V]`` per (chunk, head).
        o: Output buffer, filled in place.
        scale: Query scale.
        cu_seqlens: Optional cumulative sequence lengths.
        chunk_size: Chunk length ``BT``.

    Returns:
        ``o`` (the same tensor that was passed in).
    """
    batch, seq_len, num_heads, k_dim = q.shape
    v_dim = v.shape[-1]
    if cu_seqlens is None:
        chunk_indices = None
        num_chunks = cdiv(seq_len, chunk_size)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        num_chunks = len(chunk_indices)

    # BV is chosen by the autotuner, so the V-tile count is resolved lazily.
    launch_grid = lambda meta: (cdiv(v_dim, meta["BV"]), num_chunks, batch * num_heads)  # noqa: E731
    chunk_gla_fwd_kernel_o[launch_grid](
        q=q,
        v=v,
        g=g,
        h=h,
        o=o,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=seq_len,
        H=num_heads,
        K=k_dim,
        V=v_dim,
        BT=chunk_size,
    )
    return o
def chunk_kda_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    """Chunked KDA forward pipeline with chunk size 64.

    Steps:
      1. chunk-local cumulative sum of ``g``;
      2. intra-chunk ``A`` (fp32) and ``Aqk`` (fp32);
      3. ``solve_tril`` on ``A`` (output in ``k.dtype``);
      4. recompute ``w``, ``u`` and the gated keys ``kg``;
      5. inter-chunk states ``h`` and ``v_new`` via ``chunk_gated_delta_rule_fwd_h``;
      6. output via ``chunk_gla_fwd_o_gk``.

    Intermediates are ``del``-ed as soon as they are no longer needed to keep
    peak memory down.

    NOTE: the output is written into ``v``'s storage (``o=v`` below). ``v`` is
    last read by ``recompute_w_u_fwd``; the output kernel reads ``v_new``.
    Callers must not rely on ``v``'s contents afterwards.

    Returns:
        ``(o, final_state)``; ``final_state`` is whatever
        ``chunk_gated_delta_rule_fwd_h`` returns (presumably None unless
        ``output_final_state`` — confirm in chunk_delta_h).
    """
    chunk_size = 64
    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
    # the intra Aqk is kept in fp32
    # the computation has very marginal effect on the entire throughput
    A, Aqk = chunk_kda_scaled_dot_kkt_fwd(
        q=q,
        k=k,
        gk=g,
        beta=beta,
        scale=scale,
        cu_seqlens=cu_seqlens,
        output_dtype=torch.float32,
    )
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u, _, kg = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        gk=g,
        cu_seqlens=cu_seqlens,
    )
    del A
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=kg,
        w=w,
        u=u,
        gk=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    del w, u, kg
    # Reuse v's buffer for the output (v is not read again past this point).
    o = chunk_gla_fwd_o_gk(
        q=q,
        v=v_new,
        g=g,
        A=Aqk,
        h=h,
        o=v,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    del Aqk, v_new, h
    return o, final_state
def chunk_kda(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: torch.LongTensor | None = None,
**kwargs,
):
if scale is None:
scale = k.shape[-1] ** -0.5
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q.contiguous())
k = l2norm_fwd(k.contiguous())
o, final_state = chunk_kda_fwd(
q=q,
k=k,
v=v.contiguous(),
g=g.contiguous(),
beta=beta.contiguous(),
scale=scale,
initial_state=initial_state.contiguous(),
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)
return o, final_state
# KDA gate: y[t, h, :] = -exp(A[h]) * softplus(g[t, h, :] + g_bias[h, :]).
# softplus uses the given beta and switches to the identity when
# beta * x > threshold. Grid: (row block of BT rows, head). g and y are
# addressed as dense [T, H * D] matrices (row stride H * D).
@triton.autotune(
    configs=[
        triton.Config({"BT": bt}, num_warps=nw, num_stages=ns)
        for bt in BT_LIST_AUTOTUNE
        for nw in NUM_WARPS_AUTOTUNE
        for ns in [2, 3]
    ],
    key=["H", "D"],
)
@triton.jit
def kda_gate_fwd_kernel(
    g,
    A,
    y,
    g_bias,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    T,
    H,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    i_t, i_h = tl.program_id(0), tl.program_id(1)
    n_t = i_t * BT

    # Per-head decay coefficient -exp(A[h]), read from a flat A.
    b_a = tl.load(A + i_h).to(tl.float32)
    b_a = -tl.exp(b_a)

    stride_row = H * D
    stride_col = 1

    g_ptr = tl.make_block_ptr(
        base=g + i_h * D,
        shape=(T, D),
        strides=(stride_row, stride_col),
        offsets=(n_t, 0),
        block_shape=(BT, BD),
        order=(1, 0),
    )

    y_ptr = tl.make_block_ptr(
        base=y + i_h * D,
        shape=(T, D),
        strides=(stride_row, stride_col),
        offsets=(n_t, 0),
        block_shape=(BT, BD),
        order=(1, 0),
    )

    b_g = tl.load(g_ptr, boundary_check=(0, 1)).to(tl.float32)

    if HAS_BIAS:
        n_d = tl.arange(0, BD)
        bias_mask = n_d < D
        b_bias = tl.load(g_bias + i_h * D + n_d, mask=bias_mask, other=0.0).to(
            tl.float32
        )
        b_g = b_g + b_bias[None, :]

    # softplus(x, beta) = (1/beta) * log(1 + exp(beta * x))
    # When beta * x > threshold, use linear approximation x
    # Use threshold to switch to linear when beta*x > threshold
    g_scaled = b_g * beta
    use_linear = g_scaled > threshold
    sp = tl.where(use_linear, b_g, (1.0 / beta) * log(1.0 + tl.exp(g_scaled)))
    b_y = b_a * sp

    tl.store(y_ptr, b_y.to(y.dtype.element_ty), boundary_check=(0, 1))
def fused_kda_gate(
g: torch.Tensor,
A: torch.Tensor,
head_k_dim: int,
g_bias: torch.Tensor | None = None,
beta: float = 1.0,
threshold: float = 20.0,
) -> torch.Tensor:
"""
Forward pass for KDA gate:
input g: [..., H*D]
param A: [H] or [1, 1, H, 1]
beta: softplus beta parameter
threshold: softplus threshold parameter
return : [..., H, D]
"""
orig_shape = g.shape[:-1]
g = g.view(-1, g.shape[-1])
T = g.shape[0]
HD = g.shape[1]
H = A.numel()
assert H * head_k_dim == HD
y = torch.empty_like(g, dtype=torch.float32)
def grid(meta):
return (cdiv(T, meta["BT"]), H)
kda_gate_fwd_kernel[grid](
g,
A,
y,
g_bias,
beta,
threshold,
T,
H,
head_k_dim,
BD=next_power_of_2(head_k_dim),
HAS_BIAS=g_bias is not None,
)
y = y.view(*orig_shape, H, head_k_dim)
return y