1352 lines
37 KiB
Python
1352 lines
37 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
|
#
|
|
# This file contains code copied from the flash-linear-attention project.
|
|
# The original source code was licensed under the MIT license and included
|
|
# the following copyright notice:
|
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
# ruff: noqa: E501
|
|
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm.triton_utils import tl, triton
|
|
from vllm.utils.math_utils import cdiv, next_power_of_2
|
|
|
|
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
|
|
from .cumsum import chunk_local_cumsum
|
|
from .fused_recurrent import fused_recurrent_gated_delta_rule_fwd_kernel
|
|
from .index import prepare_chunk_indices
|
|
from .l2norm import l2norm_fwd
|
|
from .op import exp, log
|
|
from .solve_tril import solve_tril
|
|
from .utils import is_amd
|
|
|
|
# Candidate row-tile sizes (BT) explored by the autotuner of `kda_gate_fwd_kernel`.
BT_LIST_AUTOTUNE = [32, 64, 128]

# Candidate warp counts for the same autotuner; AMD gets a smaller set.
if is_amd:
    NUM_WARPS_AUTOTUNE = [2, 4, 8, 16]
else:
    NUM_WARPS_AUTOTUNE = [4, 8, 16, 32]
|
|
|
|
|
|
def fused_recurrent_kda_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused recurrent gated-delta-rule kernel in KDA mode.

    Args:
        q: query tensor forwarded to the kernel.
        k: key tensor, unpacked here as ``[B, T, H, K]``.
        v: value tensor; ``v.shape[2]`` is the number of value heads ``HV``
            and ``v.shape[-1]`` the value dim ``V``.
        g: gate tensor forwarded to the kernel.
        beta: write-strength tensor; ``IS_BETA_HEADWISE`` is set when
            ``beta.ndim == v.ndim``.
        scale: query scale forwarded to the kernel.
        initial_state: state read by the kernel (``h0``); ``stride(0)`` is
            passed as the per-state stride.
        inplace_final_state: when True the final state is written back into
            ``initial_state``; otherwise a new ``[T, HV, K, V]`` buffer with
            ``initial_state.dtype`` is allocated.
        cu_seqlens: optional cumulative sequence lengths; when given the
            number of sequences is ``len(cu_seqlens) - 1``.
        ssm_state_indices: optional 1-D or 2-D index tensor; its strides are
            forwarded to the kernel.
        num_accepted_tokens: forwarded to the kernel unchanged.
        use_qk_l2norm_in_kernel: forwarded as ``USE_QK_L2NORM_IN_KERNEL``.

    Returns:
        ``(o, final_state)`` where ``o`` has shape ``[B, T, HV, V]`` and the
        dtype/device of ``k``.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    # The whole key dim is processed in one tile; the value dim is split into
    # tiles of at most 8 columns, one program per tile.
    BK, BV = next_power_of_2(K), min(next_power_of_2(V), 8)
    NK, NV = cdiv(K, BK), cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1

    # The grid below tiles over value heads (N * HV) and the value dim (NV),
    # so the output lives in value space. `torch.empty_like(k)` ([B, T, H, K])
    # was only the right shape when HV == H and V == K.
    o = k.new_empty(B, T, HV, V)
    if inplace_final_state:
        final_state = initial_state
    else:
        # NOTE(review): T rows, one per token; with B > 1 and no cu_seqlens,
        # confirm the kernel never writes past row T - 1.
        final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)

    stride_init_state_token = initial_state.stride(0)
    stride_final_state_token = final_state.stride(0)

    if ssm_state_indices is None:
        stride_indices_seq, stride_indices_tok = 1, 1
    elif ssm_state_indices.ndim == 1:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    else:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()

    grid = (NK, NV, N * HV)
    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=num_accepted_tokens,
        scale=scale,
        N=N,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        stride_init_state_token=stride_init_state_token,
        stride_final_state_token=stride_final_state_token,
        stride_indices_seq=stride_indices_seq,
        stride_indices_tok=stride_indices_tok,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        INPLACE_FINAL_STATE=inplace_final_state,
        IS_KDA=True,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return o, final_state
|
|
|
|
|
|
def fused_recurrent_kda(
|
|
q: torch.Tensor,
|
|
k: torch.Tensor,
|
|
v: torch.Tensor,
|
|
g: torch.Tensor,
|
|
beta: torch.Tensor = None,
|
|
scale: float = None,
|
|
initial_state: torch.Tensor = None,
|
|
inplace_final_state: bool = True,
|
|
use_qk_l2norm_in_kernel: bool = True,
|
|
cu_seqlens: torch.LongTensor | None = None,
|
|
ssm_state_indices: torch.LongTensor | None = None,
|
|
**kwargs,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
if cu_seqlens is not None and q.shape[0] != 1:
|
|
raise ValueError(
|
|
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
|
f"Please flatten variable-length inputs before processing."
|
|
)
|
|
if scale is None:
|
|
scale = k.shape[-1] ** -0.5
|
|
|
|
o, final_state = fused_recurrent_kda_fwd(
|
|
q=q.contiguous(),
|
|
k=k.contiguous(),
|
|
v=v.contiguous(),
|
|
g=g.contiguous(),
|
|
beta=beta.contiguous(),
|
|
scale=scale,
|
|
initial_state=initial_state,
|
|
inplace_final_state=inplace_final_state,
|
|
cu_seqlens=cu_seqlens,
|
|
ssm_state_indices=ssm_state_indices,
|
|
num_accepted_tokens=None,
|
|
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
|
)
|
|
return o, final_state
|
|
|
|
|
|
# Gated LayerNorm/RMSNorm forward kernel, tiled version: each program handles
# a tile of BT rows x BD columns of a row-major [T, D] input.
# Output: y = norm(x [+ residual]) [* w] [+ b], then gated by g
# ("swish"/"silu": * g * sigmoid(g); "sigmoid": * sigmoid(g)).
# Per-row mean (LayerNorm only) and 1/std are stored to `mean` / `rstd`.
@triton.heuristics(
    {
        "STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
        "HAS_RESIDUAL": lambda args: args["residual"] is not None,
        "HAS_WEIGHT": lambda args: args["w"] is not None,
        "HAS_BIAS": lambda args: args["b"] is not None,
    }
)
@triton.jit
def layer_norm_gated_fwd_kernel(
    x,  # pointer to the input
    g,  # pointer to the gate
    y,  # pointer to the output
    w,  # pointer to the weights
    b,  # pointer to the biases
    residual,  # pointer to the residual
    residual_out,  # pointer to the residual output (stores x + residual)
    mean,  # pointer to the mean
    rstd,  # pointer to the 1/std
    eps,  # epsilon to avoid division by zero
    T,  # number of rows in x
    D: tl.constexpr,  # number of columns in x
    BT: tl.constexpr,  # rows per program
    BD: tl.constexpr,  # padded column count (power of two >= D)
    ACTIVATION: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    i_t = tl.program_id(0)

    o_d = tl.arange(0, BD)
    # Mask of real (non-padding) columns.
    m_d = o_d < D

    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    # NOTE(review): the mean below sums b_x over all BD columns, so this relies
    # on out-of-bounds columns loading as zero; Triton documents the default
    # block-pointer padding as undefined — consider padding_option="zero".
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    if HAS_RESIDUAL:
        p_res = tl.make_block_ptr(
            residual, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)
        )
        b_x += tl.load(p_res, boundary_check=(0, 1)).to(tl.float32)
    if STORE_RESIDUAL_OUT:
        # Save the pre-norm sum (x + residual) for the caller.
        p_res_out = tl.make_block_ptr(
            residual_out, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)
        )
        tl.store(p_res_out, b_x.to(p_res_out.dtype.element_ty), boundary_check=(0, 1))
    if not IS_RMS_NORM:
        b_mean = tl.sum(b_x, axis=1) / D
        p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t * BT,), (BT,), (0,))
        tl.store(p_mean, b_mean.to(p_mean.dtype.element_ty), boundary_check=(0,))
        # Zero the padding columns so they do not contribute to the variance.
        b_xbar = tl.where(m_d[None, :], b_x - b_mean[:, None], 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
    else:
        # RMSNorm: mean of squares, no centering.
        b_xbar = tl.where(m_d[None, :], b_x, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
    b_rstd = 1 / tl.sqrt(b_var + eps)

    p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,))
    tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,))

    if HAS_WEIGHT:
        b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
    if HAS_BIAS:
        b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
    # IS_RMS_NORM is constexpr, so b_mean is only referenced when it exists.
    b_x_hat = (
        (b_x - b_mean[:, None]) * b_rstd[:, None]
        if not IS_RMS_NORM
        else b_x * b_rstd[:, None]
    )
    b_y = b_x_hat * b_w[None, :] if HAS_WEIGHT else b_x_hat
    if HAS_BIAS:
        b_y = b_y + b_b[None, :]

    # swish/sigmoid output gate
    p_g = tl.make_block_ptr(g, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
    if ACTIVATION == "swish" or ACTIVATION == "silu":
        b_y = b_y * b_g * tl.sigmoid(b_g)
    elif ACTIVATION == "sigmoid":
        b_y = b_y * tl.sigmoid(b_g)

    # Write output
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
|
|
# Gated LayerNorm/RMSNorm forward kernel, one-row-per-program version (used by
# the wrapper for wide rows). Each program handles one row of a row-major
# [T, D] input using masked pointer arithmetic. Same math as the tiled kernel:
# y = norm(x [+ residual]) [* w] [+ b], then gated by g.
@triton.heuristics(
    {
        "STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
        "HAS_RESIDUAL": lambda args: args["residual"] is not None,
        "HAS_WEIGHT": lambda args: args["w"] is not None,
        "HAS_BIAS": lambda args: args["b"] is not None,
    }
)
@triton.jit
def layer_norm_gated_fwd_kernel1(
    x,  # pointer to the input
    g,  # pointer to the gate
    y,  # pointer to the output
    w,  # pointer to the weights
    b,  # pointer to the biases
    residual,  # pointer to the residual
    residual_out,  # pointer to the residual output (stores x + residual)
    mean,  # pointer to the mean
    rstd,  # pointer to the 1/std
    eps,  # epsilon to avoid division by zero
    D: tl.constexpr,  # number of columns in x
    BD: tl.constexpr,  # padded column count (power of two >= D)
    ACTIVATION: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    i_t = tl.program_id(0)
    # Advance every row-major pointer to this program's row.
    # NOTE(review): i_t * D is computed in 32-bit; confirm T * D stays < 2**31.
    x += i_t * D
    y += i_t * D
    g += i_t * D
    if HAS_RESIDUAL:
        residual += i_t * D
    if STORE_RESIDUAL_OUT:
        residual_out += i_t * D

    o_d = tl.arange(0, BD)
    # Mask of real (non-padding) columns; padding loads as 0.
    m_d = o_d < D
    b_x = tl.load(x + o_d, mask=m_d, other=0.0).to(tl.float32)
    if HAS_RESIDUAL:
        b_x += tl.load(residual + o_d, mask=m_d, other=0.0).to(tl.float32)
    if STORE_RESIDUAL_OUT:
        tl.store(residual_out + o_d, b_x, mask=m_d)
    if not IS_RMS_NORM:
        b_mean = tl.sum(b_x, axis=0) / D
        tl.store(mean + i_t, b_mean)
        # Zero the padding columns so they do not contribute to the variance.
        b_xbar = tl.where(m_d, b_x - b_mean, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
    else:
        # RMSNorm: mean of squares, no centering.
        b_xbar = tl.where(m_d, b_x, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
    b_rstd = 1 / tl.sqrt(b_var + eps)
    tl.store(rstd + i_t, b_rstd)

    if HAS_WEIGHT:
        b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
    if HAS_BIAS:
        b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
    b_x_hat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd
    b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat
    if HAS_BIAS:
        b_y = b_y + b_b

    # swish/sigmoid output gate
    b_g = tl.load(g + o_d, mask=m_d, other=0.0).to(tl.float32)
    if ACTIVATION == "swish" or ACTIVATION == "silu":
        b_y = b_y * b_g * tl.sigmoid(b_g)
    elif ACTIVATION == "sigmoid":
        b_y = b_y * tl.sigmoid(b_g)

    # Write output
    tl.store(y + o_d, b_y, mask=m_d)
|
|
|
|
|
|
def layer_norm_gated_fwd(
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    activation: str = "swish",
    eps: float = 1e-5,
    residual: torch.Tensor = None,
    out_dtype: torch.dtype = None,
    residual_dtype: torch.dtype = None,
    is_rms_norm: bool = False,
):
    """Gated LayerNorm / RMSNorm forward over a 2-D ``[T, D]`` input.

    The normalized (optionally affine) rows are multiplied by an activation of
    the gate ``g``: ``y * g * sigmoid(g)`` for "swish"/"silu" and
    ``y * sigmoid(g)`` for "sigmoid".

    Args:
        x: input of shape ``[T, D]``.
        g: gate, read with the same ``[T, D]`` layout as ``x``.
        weight: optional ``[D]`` scale.
        bias: optional ``[D]`` shift.
        activation: gate activation name forwarded to the kernel.
        eps: added to the variance before the reciprocal square root.
        residual: optional ``[T, D]`` tensor added to ``x`` before normalizing.
        out_dtype: dtype of a freshly allocated output. When ``None`` the
            output is written into ``x`` in place.
        residual_dtype: dtype of the stored pre-norm sum; overridden by
            ``residual.dtype`` when ``residual`` is given.
        is_rms_norm: RMSNorm (no mean subtraction) instead of LayerNorm.

    Returns:
        ``(y, mean, rstd, residual_out)``. ``mean`` is ``None`` for RMSNorm.
        When no separate residual buffer is needed, the last element is ``x``
        itself -- which, if ``out_dtype`` is ``None``, already holds ``y``.

    Raises:
        RuntimeError: if a padded row does not fit in a 64KB tile.
    """
    num_rows, num_cols = x.shape
    if residual is not None:
        # An explicit residual always dictates the stored residual dtype.
        residual_dtype = residual.dtype
        assert residual.shape == (num_rows, num_cols)
    if weight is not None:
        assert weight.shape == (num_cols,)
    if bias is not None:
        assert bias.shape == (num_cols,)

    # Output buffer: reuse x in place unless a specific dtype was requested.
    if out_dtype is None:
        y = x
    else:
        y = torch.empty_like(x, dtype=out_dtype)

    # The pre-norm sum needs its own buffer only when a residual is added or
    # when it must be kept in a dtype different from x.
    keep_residual = residual is not None or (
        residual_dtype is not None and residual_dtype != x.dtype
    )
    residual_out = (
        torch.empty(num_rows, num_cols, device=x.device, dtype=residual_dtype)
        if keep_residual
        else None
    )
    mean = (
        None
        if is_rms_norm
        else torch.empty((num_rows,), dtype=torch.float, device=x.device)
    )
    rstd = torch.empty((num_rows,), dtype=torch.float, device=x.device)

    # A whole (padded) feature row must fit in a single 64KB tile.
    block_d = min(65536 // x.element_size(), next_power_of_2(num_cols))
    if num_cols > block_d:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    launch_kwargs = dict(
        x=x,
        g=g,
        y=y,
        w=weight,
        b=bias,
        residual=residual,
        residual_out=residual_out,
        mean=mean,
        rstd=rstd,
        eps=eps,
        D=num_cols,
        BD=block_d,
        ACTIVATION=activation,
        IS_RMS_NORM=is_rms_norm,
        num_warps=4,
    )
    if num_cols <= 512:
        # Narrow rows: each program normalizes a tile of 32 rows.
        rows_per_program = 32
        layer_norm_gated_fwd_kernel[(cdiv(num_rows, rows_per_program),)](
            T=num_rows,
            BT=rows_per_program,
            **launch_kwargs,
        )
    else:
        # Wide rows: one program per row.
        layer_norm_gated_fwd_kernel1[(num_rows,)](**launch_kwargs)

    # residual_out is None if residual is None and residual_dtype == input_dtype
    return y, mean, rstd, x if residual_out is None else residual_out
|
|
|
|
|
|
def rms_norm_gated(
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    activation: str = "swish",
    residual: torch.Tensor | None = None,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
    eps: float = 1e-6,
):
    """Gated RMSNorm over the last dimension of ``x``.

    ``x``, ``g`` (and ``residual``) are flattened to 2-D ``[-1, D]`` and passed
    to ``layer_norm_gated_fwd`` with ``is_rms_norm=True``.

    Args:
        x: input; normalized over its last dimension.
        g: gate with the same trailing dimension as ``x``.
        weight, bias: optional ``[D]`` affine parameters.
        activation: "swish"/"silu" or "sigmoid" gate activation.
        residual: optional tensor of ``x``'s shape added before normalizing.
        prenorm: also return the pre-norm sum (``x + residual``).
        residual_in_fp32: keep that pre-norm sum in float32 when no
            ``residual`` is given.
        eps: numerical epsilon.

    Returns:
        ``y`` reshaped to ``x``'s original shape, or ``(y, residual_out)`` when
        ``prenorm`` is True.

    Note:
        When ``prenorm`` is False the output is written into the flattened
        ``x`` buffer. If the caller's ``x`` is already contiguous, that buffer
        is the caller's tensor, so it gets overwritten.
    """
    x_shape_og = x.shape
    # reshape input data into 2D tensor
    x = x.contiguous().reshape(-1, x.shape[-1])
    g = g.contiguous().reshape(-1, g.shape[-1])
    if residual is not None:
        assert residual.shape == x_shape_og
        residual = residual.contiguous().reshape(-1, residual.shape[-1])
    residual_dtype = (
        residual.dtype
        if residual is not None
        else (torch.float if residual_in_fp32 else None)
    )
    y, _, _, residual_out = layer_norm_gated_fwd(
        x=x,
        g=g,
        weight=weight,
        bias=bias,
        activation=activation,
        eps=eps,
        residual=residual,
        # With out_dtype=None the output overwrites x in place, and when no
        # separate residual buffer is allocated, x itself is returned as the
        # residual. For prenorm that "residual" would already hold y, so
        # request a fresh output buffer (same dtype) in that case.
        out_dtype=x.dtype if prenorm else None,
        residual_dtype=residual_dtype,
        is_rms_norm=True,
    )
    y = y.reshape(x_shape_og)
    return y if not prenorm else (y, residual_out.reshape(x_shape_og))
|
|
|
|
|
|
class FusedRMSNormGated(nn.Module):
|
|
def __init__(
|
|
self,
|
|
hidden_size: int,
|
|
elementwise_affine: bool = True,
|
|
eps: float = 1e-5,
|
|
activation: str = "swish",
|
|
device: torch.device | None = None,
|
|
dtype: torch.dtype | None = None,
|
|
) -> None:
|
|
factory_kwargs = {"device": device, "dtype": dtype}
|
|
super().__init__()
|
|
|
|
self.hidden_size = hidden_size
|
|
self.elementwise_affine = elementwise_affine
|
|
self.eps = eps
|
|
self.activation = activation
|
|
|
|
if self.activation not in ["swish", "silu", "sigmoid"]:
|
|
raise ValueError(f"Unsupported activation: {self.activation}")
|
|
|
|
if elementwise_affine:
|
|
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
|
else:
|
|
self.register_parameter("weight", None)
|
|
self.register_parameter("bias", None)
|
|
|
|
def forward(
|
|
self,
|
|
x: torch.Tensor,
|
|
g: torch.Tensor,
|
|
residual: torch.Tensor | None = None,
|
|
prenorm: bool = False,
|
|
residual_in_fp32: bool = False,
|
|
) -> torch.Tensor:
|
|
return rms_norm_gated(
|
|
x,
|
|
g,
|
|
self.weight,
|
|
self.bias,
|
|
self.activation,
|
|
residual=residual,
|
|
eps=self.eps,
|
|
prenorm=prenorm,
|
|
residual_in_fp32=residual_in_fp32,
|
|
)
|
|
|
|
|
|
# Off-diagonal sub-blocks of the intra-chunk matrices.
# Each chunk of BT rows is split into NC sub-blocks of BC rows. Program
# (i_t, i_c, i_bh) handles sub-block pair (i_i, i_j) = divmod(i_c, NC) and only
# does work for i_i > i_j (strictly below the diagonal). It writes:
#   A[i_i, i_j]   = beta_i * sum_k k_i * k_j * exp(g_i - g_j)
#   Aqk[i_i, i_j] = scale  * sum_k q_i * k_j * exp(g_i - g_j)
# The gate difference is split through b_gn (gate at the first row of block
# i_i), presumably to keep both exp factors bounded -- confirm.
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["BC"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter(
    q,
    k,
    g,
    beta,
    A,
    Aqk,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    # Row sub-block i_i and column sub-block i_j within the chunk.
    i_i, i_j = i_c // NC, i_c % NC
    if IS_VARLEN:
        # chunk_indices holds (sequence id, chunk id within sequence) pairs.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Row sub-block lies past the end of the sequence.
    if i_t * BT + i_i * BC >= T:
        return
    # Diagonal/upper sub-blocks are handled elsewhere (or stay zero).
    if i_i <= i_j:
        return

    # Move to this sequence/head; inputs are [.., T, H, K] and A/Aqk [.., T, H, BT].
    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    g += (bos * H + i_h) * K
    A += (bos * H + i_h) * BT
    Aqk += (bos * H + i_h) * BT

    p_b = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,)
    )
    b_b = tl.load(p_b, boundary_check=(0,))

    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
    # Accumulate over the key dimension in BK-wide slices.
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(
            q, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
        )
        p_k = tl.make_block_ptr(
            k, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
        )
        p_g = tl.make_block_ptr(
            g, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
        )
        # Transposed view of the keys of column sub-block i_j (reassigned to
        # the loaded values below).
        b_kt = tl.make_block_ptr(
            k, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
        )
        p_gk = tl.make_block_ptr(
            g, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
        )

        o_k = i_k * BK + tl.arange(0, BK)
        m_k = o_k < K
        # [BK,]
        b_gn = tl.load(g + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)
        # [BC, BK]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1)) * exp(b_g - b_gn[None, :])
        # [BK, BC]
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kt = tl.load(b_kt, boundary_check=(0, 1))
        # [BC, BC]
        b_ktg = b_kt * exp(b_gn[:, None] - b_gk)
        b_A += tl.dot(b_k, b_ktg)

        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_qg = b_q * exp(b_g - b_gn[None, :]) * scale
        b_Aqk += tl.dot(b_qg, b_ktg)

    # Scale each row by its beta.
    b_A *= b_b[:, None]

    p_A = tl.make_block_ptr(
        A, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)
    )
    tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
    p_Aqk = tl.make_block_ptr(
        Aqk, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)
    )
    tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
|
|
# Diagonal sub-blocks of the intra-chunk matrices.
# Program (i_t, i_i, i_bh) fills the BC x BC diagonal sub-block i_i one column
# j at a time. A gets only the strictly-lower part (row > j), scaled by beta of
# the row. Aqk gets the lower part including the diagonal (row >= j), scaled by
# `scale`. A single BK tile must cover the whole key dim (there is no K loop
# here; the wrapper passes BK = max(next_power_of_2(K), 16)).
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
    key=["BK", "BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra(
    q,
    k,
    g,
    beta,
    A,
    Aqk,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices holds (sequence id, chunk id within sequence) pairs.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if i_t * BT + i_i * BC >= T:
        return

    o_i = tl.arange(0, BC)
    o_k = tl.arange(0, BK)
    m_k = o_k < K
    # Rows of this sub-block that lie inside the sequence.
    m_A = (i_t * BT + i_i * BC + o_i) < T
    # Flat offset of (row, column i_i * BC) in the [.., T, H, BT] output.
    o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC

    p_q = tl.make_block_ptr(
        q + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    p_k = tl.make_block_ptr(
        k + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    p_g = tl.make_block_ptr(
        g + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_g = tl.load(p_g, boundary_check=(0, 1))

    # Pre-scale the row keys by beta so b_A below is already beta-weighted.
    p_b = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h
    b_k = b_k * tl.load(p_b, mask=m_A, other=0)[:, None]

    # Pointers to the key / gate vector of column j (advanced each iteration).
    p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
    p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k

    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
        b_ktg = b_kt[None, :] * exp(b_g - b_gk[None, :])
        b_A = tl.sum(b_k * b_ktg, 1)
        b_A = tl.where(o_i > j, b_A, 0.0)
        b_Aqk = tl.sum(b_q * b_ktg, 1)
        b_Aqk = tl.where(o_i >= j, b_Aqk * scale, 0.0)
        tl.store(A + o_A + j, b_A, mask=m_A)
        tl.store(Aqk + o_A + j, b_Aqk, mask=m_A)
        p_kt += H * K
        p_gk += H * K
|
|
|
|
|
|
def chunk_kda_scaled_dot_kkt_fwd(
|
|
q: torch.Tensor,
|
|
k: torch.Tensor,
|
|
gk: torch.Tensor | None = None,
|
|
beta: torch.Tensor | None = None,
|
|
scale: float | None = None,
|
|
cu_seqlens: torch.LongTensor | None = None,
|
|
chunk_size: int = 64,
|
|
output_dtype: torch.dtype = torch.float32,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
r"""
|
|
Compute beta * K * K^T.
|
|
|
|
Args:
|
|
k (torch.Tensor):
|
|
The key tensor of shape `[B, T, H, K]`.
|
|
beta (torch.Tensor):
|
|
The beta tensor of shape `[B, T, H]`.
|
|
gk (torch.Tensor):
|
|
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
|
|
cu_seqlens (torch.LongTensor):
|
|
The cumulative sequence lengths of the input tensor.
|
|
Default: None
|
|
chunk_size (int):
|
|
The chunk size. Default: 64.
|
|
output_dtype (torch.dtype):
|
|
The dtype of the output tensor. Default: `torch.float32`
|
|
|
|
Returns:
|
|
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
|
|
"""
|
|
B, T, H, K = k.shape
|
|
assert K <= 256
|
|
BT = chunk_size
|
|
chunk_indices = (
|
|
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
|
)
|
|
NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
|
|
|
BC = min(16, BT)
|
|
NC = cdiv(BT, BC)
|
|
BK = max(next_power_of_2(K), 16)
|
|
A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
|
|
Aqk = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
|
|
grid = (NT, NC * NC, B * H)
|
|
chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
|
|
q=q,
|
|
k=k,
|
|
g=gk,
|
|
beta=beta,
|
|
A=A,
|
|
Aqk=Aqk,
|
|
scale=scale,
|
|
cu_seqlens=cu_seqlens,
|
|
chunk_indices=chunk_indices,
|
|
T=T,
|
|
H=H,
|
|
K=K,
|
|
BT=BT,
|
|
BC=BC,
|
|
NC=NC,
|
|
)
|
|
|
|
grid = (NT, NC, B * H)
|
|
chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
|
|
q=q,
|
|
k=k,
|
|
g=gk,
|
|
beta=beta,
|
|
A=A,
|
|
Aqk=Aqk,
|
|
scale=scale,
|
|
cu_seqlens=cu_seqlens,
|
|
chunk_indices=chunk_indices,
|
|
T=T,
|
|
H=H,
|
|
K=K,
|
|
BT=BT,
|
|
BC=BC,
|
|
BK=BK,
|
|
)
|
|
return A, Aqk
|
|
|
|
|
|
# Per-chunk recomputation of the delta-rule operands.
# For each chunk (BT rows) of one (sequence, head), with the [BT, BT] block b_A
# of A:
#   u  = A @ (beta * v)                               (always)
#   w  = A @ (beta * k * exp(gk))                     (always)
#   qg = q * exp(gk)                                  (if STORE_QG)
#   kg = k * exp(gk[last row of chunk] - gk)          (if STORE_KG)
# gk is loaded unconditionally, so it must not be None.
@triton.heuristics(
    {
        "STORE_QG": lambda args: args["qg"] is not None,
        "STORE_KG": lambda args: args["kg"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    q,
    k,
    qg,
    kg,
    v,
    beta,
    w,
    u,
    A,
    gk,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    STORE_QG: tl.constexpr,
    STORE_KG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # chunk_indices holds (sequence id, chunk id within sequence) pairs.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    b_b = tl.load(p_b, boundary_check=(0,))

    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    b_A = tl.load(p_A, boundary_check=(0, 1))

    # u = A @ (beta * v), one BV-wide value slice at a time.
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_u = tl.make_block_ptr(
            u + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_b[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, input_precision=DOT_PRECISION)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

    # w = A @ (beta * k * exp(gk)), plus the optional qg / kg outputs,
    # one BK-wide key slice at a time.
    for i_k in range(tl.cdiv(K, BK)):
        p_w = tl.make_block_ptr(
            w + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_k = tl.make_block_ptr(
            k + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_b[:, None]

        p_gk = tl.make_block_ptr(
            gk + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kb *= exp(b_gk)
        if STORE_QG:
            p_q = tl.make_block_ptr(
                q + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            p_qg = tl.make_block_ptr(
                qg + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_qg = b_q * exp(b_gk)
            tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1))
        if STORE_KG:
            # Last valid row of this chunk (the chunk may be partial).
            last_idx = min(i_t * BT + BT, T) - 1

            o_k = i_k * BK + tl.arange(0, BK)
            m_k = o_k < K
            b_gn = tl.load(
                gk + ((bos + last_idx) * H + i_h) * K + o_k, mask=m_k, other=0.0
            )
            b_kg = b_k * exp(b_gn - b_gk)

            p_kg = tl.make_block_ptr(
                kg + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1))

        b_w = tl.dot(b_A, b_kb.to(b_k.dtype))
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
|
|
def recompute_w_u_fwd(
|
|
k: torch.Tensor,
|
|
v: torch.Tensor,
|
|
beta: torch.Tensor,
|
|
A: torch.Tensor,
|
|
q: torch.Tensor | None = None,
|
|
gk: torch.Tensor | None = None,
|
|
cu_seqlens: torch.LongTensor | None = None,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
B, T, H, K, V = *k.shape, v.shape[-1]
|
|
BT = A.shape[-1]
|
|
BK = 64
|
|
BV = 64
|
|
|
|
chunk_indices = (
|
|
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
|
)
|
|
NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
|
|
|
w = torch.empty_like(k)
|
|
u = torch.empty_like(v)
|
|
kg = torch.empty_like(k) if gk is not None else None
|
|
recompute_w_u_fwd_kernel[(NT, B * H)](
|
|
q=q,
|
|
k=k,
|
|
qg=None,
|
|
kg=kg,
|
|
v=v,
|
|
beta=beta,
|
|
w=w,
|
|
u=u,
|
|
A=A,
|
|
gk=gk,
|
|
cu_seqlens=cu_seqlens,
|
|
chunk_indices=chunk_indices,
|
|
T=T,
|
|
H=H,
|
|
K=K,
|
|
V=V,
|
|
BT=BT,
|
|
BK=BK,
|
|
BV=BV,
|
|
DOT_PRECISION="ieee",
|
|
)
|
|
return w, u, None, kg
|
|
|
|
|
|
# Chunked output kernel. For one (value tile i_v, chunk i_t, batch*head):
#   o = (scale * q * exp(g)) @ h[chunk]  +  tril(A) @ v
# where h holds one [K, V] state per (global chunk index i_tg, head), and A is
# the per-chunk [BT, BT] Aqk matrix, masked to its lower triangle (incl. the
# diagonal) via m_s.
# o may alias v (chunk_kda_fwd passes o=v): each program loads exactly the v
# tile it later overwrites in o (identical block-pointer geometry), and the
# load happens before the store.
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for BV in [64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gla_fwd_kernel_o(
    q,
    v,
    g,
    h,
    o,
    A,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Global chunk index (used to address h) before i_t becomes local.
        i_tg = i_t
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # Causal (lower-triangular incl. diagonal) mask within a chunk.
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    # Inter-chunk contribution, accumulated over BK-wide key slices.
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(
            q + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_g = tl.make_block_ptr(
            g + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_h = tl.make_block_ptr(
            h + (i_tg * H + i_h) * K * V,
            (K, V),
            (V, 1),
            (i_k * BK, i_v * BV),
            (BK, BV),
            (1, 0),
        )

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BK]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        # [BT, BK]
        b_qg = (b_q * exp(b_g)).to(b_q.dtype)
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # works but dkw, owing to divine benevolence
        # (upstream note: `i_k >= 0` is always true here; the guard is kept
        # as-is from the original implementation.)
        # [BT, BV]
        if i_k >= 0:
            b_o += tl.dot(b_qg, b_h.to(b_qg.dtype))
    p_v = tl.make_block_ptr(
        v + (bos * H + i_h) * V,
        (T, V),
        (H * V, 1),
        (i_t * BT, i_v * BV),
        (BT, BV),
        (1, 0),
    )
    p_o = tl.make_block_ptr(
        o + (bos * H + i_h) * V,
        (T, V),
        (H * V, 1),
        (i_t * BT, i_v * BV),
        (BT, BV),
        (1, 0),
    )
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    # [BT, BT]
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_A = tl.where(m_s, b_A, 0.0).to(b_v.dtype)
    # Intra-chunk contribution.
    b_o += tl.dot(b_A, b_v, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
|
|
def chunk_gla_fwd_o_gk(
    q: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    A: torch.Tensor,
    h: torch.Tensor,
    o: torch.Tensor,
    scale: float,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
):
    """Fill ``o`` with the chunked output and return it.

    Args:
        q: query tensor, unpacked as ``[B, T, H, K]``.
        v: value tensor with value dim ``v.shape[-1]``.
        g: cumulative gate with ``q``'s layout.
        A: per-chunk ``[BT, BT]`` query-key matrix (masked causally in-kernel).
        h: per-chunk recurrent states.
        o: output buffer, written in place (may alias ``v``).
        scale: query scale.
        cu_seqlens: optional cumulative sequence lengths.
        chunk_size: chunk length ``BT``.

    Returns:
        ``o``.
    """
    batch, seq_len, num_heads, key_dim = q.shape
    value_dim = v.shape[-1]

    if cu_seqlens is None:
        chunk_indices = None
        num_chunks = cdiv(seq_len, chunk_size)
    else:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        num_chunks = len(chunk_indices)

    def grid(meta):
        # One program per (value tile, chunk, batch*head); BV is autotuned.
        return (cdiv(value_dim, meta["BV"]), num_chunks, batch * num_heads)

    chunk_gla_fwd_kernel_o[grid](
        q=q,
        v=v,
        g=g,
        h=h,
        o=o,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=seq_len,
        H=num_heads,
        K=key_dim,
        V=value_dim,
        BT=chunk_size,
    )
    return o
|
|
|
|
|
|
def chunk_kda_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    """Chunked KDA forward pipeline (chunk size 64).

    Stages:
      1. chunk-local cumulative sum of the gate ``g``;
      2. intra-chunk ``A`` (k-k) and ``Aqk`` (q-k) matrices, in float32;
      3. ``solve_tril`` on ``A`` (result cast to ``k.dtype``);
      4. recompute ``w``, ``u`` and the gated keys ``kg``;
      5. chunk states ``h``, corrected values ``v_new`` and the final state;
      6. output from ``q``, ``h``, ``Aqk`` and ``v_new``.

    The output is written into ``v``'s buffer (``v`` is no longer needed by
    then). Intermediates are ``del``-ed as soon as they are consumed to keep
    peak memory low.

    Returns:
        ``(o, final_state)``.
    """
    chunk_size = 64
    # Rebinding g drops this frame's reference to the raw gate early.
    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)

    # the intra Aqk is kept in fp32
    # the computation has very marginal effect on the entire throughput
    a_kk, a_qk = chunk_kda_scaled_dot_kkt_fwd(
        q=q,
        k=k,
        gk=g,
        beta=beta,
        scale=scale,
        cu_seqlens=cu_seqlens,
        output_dtype=torch.float32,
    )
    a_kk = solve_tril(A=a_kk, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u, _, kg = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=a_kk,
        gk=g,
        cu_seqlens=cu_seqlens,
    )
    del a_kk

    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=kg,
        w=w,
        u=u,
        gk=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    del w, u, kg

    # `v` has been fully consumed above, so its buffer is reused for the output.
    o = chunk_gla_fwd_o_gk(
        q=q,
        v=v_new,
        g=g,
        A=a_qk,
        h=h,
        o=v,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    del a_qk, v_new, h
    return o, final_state
|
|
|
|
|
|
def chunk_kda(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: torch.LongTensor | None = None,
    **kwargs,
):
    """Chunked KDA forward pass.

    Args:
        q, k: query/key tensors (unpacked downstream as ``[B, T, H, K]``).
        v: value tensor.
        g: raw log-gate; its chunk-local cumulative sum is taken downstream.
        beta: per-token, per-head beta.
        scale: query scale. Default: ``k.shape[-1] ** -0.5``.
        initial_state: optional initial recurrent state; ``None`` is
            forwarded as ``None``. NOTE(review): assumes
            ``chunk_gated_delta_rule_fwd_h`` treats a missing state as zeros.
        output_final_state: also return the final state.
        use_qk_l2norm_in_kernel: L2-normalize q and k (``l2norm_fwd``) first.
        cu_seqlens: optional cumulative sequence lengths.
        **kwargs: accepted and ignored.

    Returns:
        ``(o, final_state)``.

    Note:
        The output is written into the contiguous ``v`` buffer. If the caller's
        ``v`` is already contiguous, that tensor is overwritten.
    """
    if scale is None:
        scale = k.shape[-1] ** -0.5

    if use_qk_l2norm_in_kernel:
        q = l2norm_fwd(q.contiguous())
        k = l2norm_fwd(k.contiguous())

    o, final_state = chunk_kda_fwd(
        # Every kernel assumes dense [.., T, H, K] rows; `.contiguous()` is a
        # no-op when that already holds.
        q=q.contiguous(),
        k=k.contiguous(),
        v=v.contiguous(),
        g=g.contiguous(),
        beta=beta.contiguous(),
        scale=scale,
        # The annotated default is None; don't call `.contiguous()` on it.
        initial_state=(
            initial_state.contiguous() if initial_state is not None else None
        ),
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    return o, final_state
|
|
|
|
|
|
# KDA gate kernel: for head i_h and a tile of BT rows,
#   y = -exp(A[i_h]) * softplus(g [+ g_bias], beta, threshold)
# computed in float32. g and y are addressed as [T, H, D] with a fixed row
# stride of H * D, so both must be dense in that layout.
@triton.autotune(
    configs=[
        triton.Config({"BT": bt}, num_warps=nw, num_stages=ns)
        for bt in BT_LIST_AUTOTUNE
        for nw in NUM_WARPS_AUTOTUNE
        for ns in [2, 3]
    ],
    key=["H", "D"],
)
@triton.jit
def kda_gate_fwd_kernel(
    g,
    A,
    y,
    g_bias,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    T,
    H,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    i_t, i_h = tl.program_id(0), tl.program_id(1)
    n_t = i_t * BT

    # Per-head decay coefficient -exp(A[h]).
    b_a = tl.load(A + i_h).to(tl.float32)
    b_a = -tl.exp(b_a)

    stride_row = H * D
    stride_col = 1

    g_ptr = tl.make_block_ptr(
        base=g + i_h * D,
        shape=(T, D),
        strides=(stride_row, stride_col),
        offsets=(n_t, 0),
        block_shape=(BT, BD),
        order=(1, 0),
    )

    y_ptr = tl.make_block_ptr(
        base=y + i_h * D,
        shape=(T, D),
        strides=(stride_row, stride_col),
        offsets=(n_t, 0),
        block_shape=(BT, BD),
        order=(1, 0),
    )

    b_g = tl.load(g_ptr, boundary_check=(0, 1)).to(tl.float32)

    if HAS_BIAS:
        # Per-head bias slice g_bias[i_h * D : (i_h + 1) * D].
        n_d = tl.arange(0, BD)
        bias_mask = n_d < D
        b_bias = tl.load(g_bias + i_h * D + n_d, mask=bias_mask, other=0.0).to(
            tl.float32
        )
        b_g = b_g + b_bias[None, :]

    # softplus(x, beta) = (1/beta) * log(1 + exp(beta * x))
    # When beta * x > threshold, use linear approximation x
    # Use threshold to switch to linear when beta*x > threshold
    g_scaled = b_g * beta
    use_linear = g_scaled > threshold
    sp = tl.where(use_linear, b_g, (1.0 / beta) * log(1.0 + tl.exp(g_scaled)))
    b_y = b_a * sp

    tl.store(y_ptr, b_y.to(y.dtype.element_ty), boundary_check=(0, 1))
|
|
|
|
|
|
def fused_kda_gate(
|
|
g: torch.Tensor,
|
|
A: torch.Tensor,
|
|
head_k_dim: int,
|
|
g_bias: torch.Tensor | None = None,
|
|
beta: float = 1.0,
|
|
threshold: float = 20.0,
|
|
) -> torch.Tensor:
|
|
"""
|
|
Forward pass for KDA gate:
|
|
input g: [..., H*D]
|
|
param A: [H] or [1, 1, H, 1]
|
|
beta: softplus beta parameter
|
|
threshold: softplus threshold parameter
|
|
return : [..., H, D]
|
|
"""
|
|
orig_shape = g.shape[:-1]
|
|
|
|
g = g.view(-1, g.shape[-1])
|
|
T = g.shape[0]
|
|
HD = g.shape[1]
|
|
H = A.numel()
|
|
assert H * head_k_dim == HD
|
|
|
|
y = torch.empty_like(g, dtype=torch.float32)
|
|
|
|
def grid(meta):
|
|
return (cdiv(T, meta["BT"]), H)
|
|
|
|
kda_gate_fwd_kernel[grid](
|
|
g,
|
|
A,
|
|
y,
|
|
g_bias,
|
|
beta,
|
|
threshold,
|
|
T,
|
|
H,
|
|
head_k_dim,
|
|
BD=next_power_of_2(head_k_dim),
|
|
HAS_BIAS=g_bias is not None,
|
|
)
|
|
|
|
y = y.view(*orig_shape, H, head_k_dim)
|
|
return y
|