# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import torch
import torch.nn as nn

from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv, next_power_of_2

from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .cumsum import chunk_local_cumsum
from .fused_recurrent import fused_recurrent_gated_delta_rule_fwd_kernel
from .index import prepare_chunk_indices
from .l2norm import l2norm_fwd
from .op import exp, log
from .solve_tril import solve_tril
from .utils import is_amd

# Autotuning search spaces for `kda_gate_fwd_kernel` (row-tile sizes and warp
# counts); AMD gets a smaller warp range.
BT_LIST_AUTOTUNE = [32, 64, 128]
NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]


def fused_recurrent_kda_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the shared recurrent gated-delta-rule kernel in KDA mode.

    Args:
        q, k: Query/key tensors of shape ``[B, T, H, K]``.
        v: Value tensor of shape ``[B, T, HV, V]``.
        g: Gate tensor, forwarded to the kernel unchanged.
        beta: Beta tensor. If ``beta.ndim == v.ndim`` it is treated as
            head-wise (``IS_BETA_HEADWISE``).
        scale: Query scale.
        initial_state: State tensor read by the kernel; required.
        inplace_final_state: If True, ``initial_state`` itself is passed as
            the final-state buffer and returned. Otherwise a new buffer of
            shape ``[T, HV, K, V]`` is allocated (presumably one state slot
            per token).
        cu_seqlens: Optional cumulative sequence lengths; when given, the
            number of sequences is ``len(cu_seqlens) - 1``.
        ssm_state_indices: Optional 1-D or 2-D index tensor; its strides are
            forwarded to the kernel.
        num_accepted_tokens: Optional tensor forwarded to the kernel.
        use_qk_l2norm_in_kernel: Whether the kernel L2-normalizes q and k.

    Returns:
        ``(o, final_state)`` where ``o`` has the shape of ``v`` and the dtype
        of ``k``.
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = next_power_of_2(K), min(next_power_of_2(V), 8)
    NK, NV = cdiv(K, BK), cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1

    # The grid below spans N * HV value heads and NV blocks over V, so the
    # output is laid out like `v` ([B, T, HV, V]). The previous
    # `torch.empty_like(k)` was only correct when H == HV and K == V.
    o = k.new_empty(v.shape)
    if inplace_final_state:
        final_state = initial_state
    else:
        final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)

    stride_init_state_token = initial_state.stride(0)
    stride_final_state_token = final_state.stride(0)

    if ssm_state_indices is None:
        stride_indices_seq, stride_indices_tok = 1, 1
    elif ssm_state_indices.ndim == 1:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    else:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()

    grid = (NK, NV, N * HV)
    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=num_accepted_tokens,
        scale=scale,
        N=N,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        stride_init_state_token=stride_init_state_token,
        stride_final_state_token=stride_final_state_token,
        stride_indices_seq=stride_indices_seq,
        stride_indices_tok=stride_indices_tok,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        INPLACE_FINAL_STATE=inplace_final_state,
        IS_KDA=True,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return o, final_state


def fused_recurrent_kda(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor | None = None,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
    inplace_final_state: bool = True,
    use_qk_l2norm_in_kernel: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    ssm_state_indices: torch.LongTensor | None = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Recurrent (token-by-token) KDA forward pass.

    Makes all inputs contiguous and delegates to ``fused_recurrent_kda_fwd``.

    Args:
        q, k: Query/key tensors of shape ``[B, T, H, K]``.
        v: Value tensor of shape ``[B, T, HV, V]``.
        g: Gate tensor.
        beta: Beta tensor. If None, ones shaped like ``v[..., 0]`` are used.
        scale: Query scale. Defaults to ``K ** -0.5``.
        initial_state: State tensor; required by the kernel launcher.
        inplace_final_state: See ``fused_recurrent_kda_fwd``.
        use_qk_l2norm_in_kernel: Whether the kernel L2-normalizes q and k.
        cu_seqlens: Cumulative sequence lengths for packed variable-length
            input; requires batch size 1.
        ssm_state_indices: Optional state index tensor.
        **kwargs: Accepted and ignored.

    Returns:
        ``(o, final_state)``.

    Raises:
        ValueError: If ``cu_seqlens`` is given and ``q.shape[0] != 1``.
    """
    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
            "Please flatten variable-length inputs before processing."
        )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    if beta is None:
        # Previously `beta.contiguous()` crashed on the default None. Use
        # beta == 1 with one value per (token, value head).
        beta = torch.ones_like(v[..., 0])

    o, final_state = fused_recurrent_kda_fwd(
        q=q.contiguous(),
        k=k.contiguous(),
        v=v.contiguous(),
        g=g.contiguous(),
        beta=beta.contiguous(),
        scale=scale,
        initial_state=initial_state,
        inplace_final_state=inplace_final_state,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=None,
        use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
    )
    return o, final_state


# Gated (RMS/Layer)Norm forward, tiled variant: each program normalizes a tile
# of BT rows of a row-major [T, D] input, then applies the output gate `g`.
@triton.heuristics(
    {
        "STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
        "HAS_RESIDUAL": lambda args: args["residual"] is not None,
        "HAS_WEIGHT": lambda args: args["w"] is not None,
        "HAS_BIAS": lambda args: args["b"] is not None,
    }
)
@triton.jit
def layer_norm_gated_fwd_kernel(
    x,  # pointer to the input
    g,  # pointer to the gate
    y,  # pointer to the output
    w,  # pointer to the weights
    b,  # pointer to the biases
    residual,  # pointer to the residual
    residual_out,  # pointer to the residual
    mean,  # pointer to the mean
    rstd,  # pointer to the 1/std
    eps,  # epsilon to avoid division by zero
    T,  # number of rows in x
    D: tl.constexpr,  # number of columns in x
    BT: tl.constexpr,
    BD: tl.constexpr,
    ACTIVATION: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    i_t = tl.program_id(0)

    o_d = tl.arange(0, BD)
    m_d = o_d < D

    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    if HAS_RESIDUAL:
        p_res = tl.make_block_ptr(
            residual, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)
        )
        b_x += tl.load(p_res, boundary_check=(0, 1)).to(tl.float32)
    if STORE_RESIDUAL_OUT:
        p_res_out = tl.make_block_ptr(
            residual_out, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)
        )
        tl.store(p_res_out, b_x.to(p_res_out.dtype.element_ty), boundary_check=(0, 1))
    if not IS_RMS_NORM:
        b_mean = tl.sum(b_x, axis=1) / D
        p_mean = tl.make_block_ptr(mean, (T,), (1,), (i_t * BT,), (BT,), (0,))
        tl.store(p_mean, b_mean.to(p_mean.dtype.element_ty), boundary_check=(0,))
        b_xbar = tl.where(m_d[None, :], b_x - b_mean[:, None], 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
    else:
        b_xbar = tl.where(m_d[None, :], b_x, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=1) / D
    b_rstd = 1 / tl.sqrt(b_var + eps)

    p_rstd = tl.make_block_ptr(rstd, (T,), (1,), (i_t * BT,), (BT,), (0,))
    tl.store(p_rstd, b_rstd.to(p_rstd.dtype.element_ty), boundary_check=(0,))

    if HAS_WEIGHT:
        b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
    if HAS_BIAS:
        b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
    b_x_hat = (
        (b_x - b_mean[:, None]) * b_rstd[:, None]
        if not IS_RMS_NORM
        else b_x * b_rstd[:, None]
    )
    b_y = b_x_hat * b_w[None, :] if HAS_WEIGHT else b_x_hat
    if HAS_BIAS:
        b_y = b_y + b_b[None, :]

    # swish/sigmoid output gate
    p_g = tl.make_block_ptr(g, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
    if ACTIVATION == "swish" or ACTIVATION == "silu":
        b_y = b_y * b_g * tl.sigmoid(b_g)
    elif ACTIVATION == "sigmoid":
        b_y = b_y * tl.sigmoid(b_g)

    # Write output
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))


# Gated (RMS/Layer)Norm forward, one-row-per-program variant used for wider
# feature dims. The math is identical to `layer_norm_gated_fwd_kernel`.
@triton.heuristics(
    {
        "STORE_RESIDUAL_OUT": lambda args: args["residual_out"] is not None,
        "HAS_RESIDUAL": lambda args: args["residual"] is not None,
        "HAS_WEIGHT": lambda args: args["w"] is not None,
        "HAS_BIAS": lambda args: args["b"] is not None,
    }
)
@triton.jit
def layer_norm_gated_fwd_kernel1(
    x,  # pointer to the input
    g,  # pointer to the gate
    y,  # pointer to the output
    w,  # pointer to the weights
    b,  # pointer to the biases
    residual,  # pointer to the residual
    residual_out,  # pointer to the residual
    mean,  # pointer to the mean
    rstd,  # pointer to the 1/std
    eps,  # epsilon to avoid division by zero
    D: tl.constexpr,  # number of columns in x
    BD: tl.constexpr,
    ACTIVATION: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    i_t = tl.program_id(0)
    # Advance every row-major pointer to this program's row.
    x += i_t * D
    y += i_t * D
    g += i_t * D
    if HAS_RESIDUAL:
        residual += i_t * D
    if STORE_RESIDUAL_OUT:
        residual_out += i_t * D

    o_d = tl.arange(0, BD)
    m_d = o_d < D
    b_x = tl.load(x + o_d, mask=m_d, other=0.0).to(tl.float32)
    if HAS_RESIDUAL:
        b_x += tl.load(residual + o_d, mask=m_d, other=0.0).to(tl.float32)
    if STORE_RESIDUAL_OUT:
        tl.store(residual_out + o_d, b_x, mask=m_d)
    if not IS_RMS_NORM:
        b_mean = tl.sum(b_x, axis=0) / D
        tl.store(mean + i_t, b_mean)
        b_xbar = tl.where(m_d, b_x - b_mean, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
    else:
        b_xbar = tl.where(m_d, b_x, 0.0)
        b_var = tl.sum(b_xbar * b_xbar, axis=0) / D
    b_rstd = 1 / tl.sqrt(b_var + eps)
    tl.store(rstd + i_t, b_rstd)

    if HAS_WEIGHT:
        b_w = tl.load(w + o_d, mask=m_d).to(tl.float32)
    if HAS_BIAS:
        b_b = tl.load(b + o_d, mask=m_d).to(tl.float32)
    b_x_hat = (b_x - b_mean) * b_rstd if not IS_RMS_NORM else b_x * b_rstd
    b_y = b_x_hat * b_w if HAS_WEIGHT else b_x_hat
    if HAS_BIAS:
        b_y = b_y + b_b

    # swish/sigmoid output gate
    b_g = tl.load(g + o_d, mask=m_d, other=0.0).to(tl.float32)
    if ACTIVATION == "swish" or ACTIVATION == "silu":
        b_y = b_y * b_g * tl.sigmoid(b_g)
    elif ACTIVATION == "sigmoid":
        b_y = b_y * tl.sigmoid(b_g)

    # Write output
    tl.store(y + o_d, b_y, mask=m_d)


def layer_norm_gated_fwd(
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor | None,
    bias: torch.Tensor | None,
    activation: str = "swish",
    eps: float = 1e-5,
    residual: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    residual_dtype: torch.dtype | None = None,
    is_rms_norm: bool = False,
):
    """Gated LayerNorm/RMSNorm forward over a 2-D ``[T, D]`` input.

    Computes ``norm(x + residual) * weight + bias`` and then applies the
    output gate ``g`` (``swish``/``silu`` or ``sigmoid``).

    NOTE: when ``out_dtype`` is None the output is written into ``x`` in place
    (``y is x``). In that case, if no separate ``residual_out`` buffer is
    allocated, the fourth return value is ``x``, which by then holds the
    normalized output rather than the pre-norm input. Callers that need the
    pre-norm input must pass an explicit ``out_dtype``.

    Returns:
        ``(y, mean, rstd, residual_out_or_x)``. ``mean`` is None for RMSNorm.

    Raises:
        RuntimeError: If a row does not fit in a single 64KB fused block.
    """
    if residual is not None:
        residual_dtype = residual.dtype
    T, D = x.shape
    if residual is not None:
        assert residual.shape == (T, D)
    if weight is not None:
        assert weight.shape == (D,)
    if bias is not None:
        assert bias.shape == (D,)
    # allocate output (in place into `x` unless a dtype is requested; see docstring)
    y = x if out_dtype is None else torch.empty_like(x, dtype=out_dtype)
    if residual is not None or (
        residual_dtype is not None and residual_dtype != x.dtype
    ):
        residual_out = torch.empty(T, D, device=x.device, dtype=residual_dtype)
    else:
        residual_out = None
    mean = (
        torch.empty((T,), dtype=torch.float, device=x.device)
        if not is_rms_norm
        else None
    )
    rstd = torch.empty((T,), dtype=torch.float, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # Small feature dims: tile BT rows per program; otherwise one row per program.
    if D <= 512:
        BT = 32
        layer_norm_gated_fwd_kernel[(cdiv(T, BT),)](
            x=x,
            g=g,
            y=y,
            w=weight,
            b=bias,
            residual=residual,
            residual_out=residual_out,
            mean=mean,
            rstd=rstd,
            eps=eps,
            T=T,
            D=D,
            BD=BD,
            BT=BT,
            ACTIVATION=activation,
            IS_RMS_NORM=is_rms_norm,
            num_warps=4,
        )
    else:
        layer_norm_gated_fwd_kernel1[(T,)](
            x=x,
            g=g,
            y=y,
            w=weight,
            b=bias,
            residual=residual,
            residual_out=residual_out,
            mean=mean,
            rstd=rstd,
            eps=eps,
            D=D,
            BD=BD,
            ACTIVATION=activation,
            IS_RMS_NORM=is_rms_norm,
            num_warps=4,
        )
    # residual_out is None if residual is None and residual_dtype == input_dtype
    return y, mean, rstd, residual_out if residual_out is not None else x


def rms_norm_gated(
    x: torch.Tensor,
    g: torch.Tensor,
    weight: torch.Tensor | None,
    bias: torch.Tensor | None,
    activation: str = "swish",
    residual: torch.Tensor | None = None,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
    eps: float = 1e-6,
):
    """Gated RMSNorm over the last dimension of ``x``.

    Args:
        x: Input of any shape ``[..., D]``.
        g: Gate tensor, flattened to ``[-1, g.shape[-1]]``.
        weight, bias: Optional affine parameters of shape ``[D]``.
        activation: Output-gate activation (``swish``/``silu``/``sigmoid``).
        residual: Optional residual with the same shape as ``x``.
        prenorm: If True, also return the pre-norm (``x + residual``) tensor.
        residual_in_fp32: Store the pre-norm tensor in fp32 when no residual
            is given.
        eps: Numerical-stability epsilon.

    Returns:
        ``y`` (shaped like ``x``), or ``(y, prenorm_input)`` when ``prenorm``.

    NOTE: without ``prenorm`` the output is computed in place into the
    flattened view of ``x`` (see ``layer_norm_gated_fwd``); for an already
    contiguous ``x`` that view shares storage with the caller's tensor.
    """
    x_shape_og = x.shape
    # reshape input data into 2D tensor
    x = x.contiguous().reshape(-1, x.shape[-1])
    g = g.contiguous().reshape(-1, g.shape[-1])
    if residual is not None:
        assert residual.shape == x_shape_og
        residual = residual.contiguous().reshape(-1, residual.shape[-1])
    residual_dtype = (
        residual.dtype
        if residual is not None
        else (torch.float if residual_in_fp32 else None)
    )
    y, _, _, residual_out = layer_norm_gated_fwd(
        x=x,
        g=g,
        weight=weight,
        bias=bias,
        activation=activation,
        eps=eps,
        residual=residual,
        # With prenorm the caller needs the untouched input back (it is
        # returned as `residual_out` when no separate buffer is allocated),
        # so request a fresh output buffer instead of writing into `x`.
        out_dtype=x.dtype if prenorm else None,
        residual_dtype=residual_dtype,
        is_rms_norm=True,
    )
    y = y.reshape(x_shape_og)
    return y if not prenorm else (y, residual_out.reshape(x_shape_og))


class FusedRMSNormGated(nn.Module):
    """Gated RMSNorm module backed by the fused Triton kernels above.

    Computes ``rms_norm(x) * weight`` and then applies an output gate on ``g``
    (``swish``/``silu``: ``y * g * sigmoid(g)``; ``sigmoid``: ``y * sigmoid(g)``).
    The module never has a bias.
    """

    def __init__(
        self,
        hidden_size: int,
        elementwise_affine: bool = True,
        eps: float = 1e-5,
        activation: str = "swish",
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.hidden_size = hidden_size
        self.elementwise_affine = elementwise_affine
        self.eps = eps
        self.activation = activation

        if self.activation not in ["swish", "silu", "sigmoid"]:
            raise ValueError(f"Unsupported activation: {self.activation}")

        if elementwise_affine:
            self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
        self.register_parameter("bias", None)
        # `torch.empty` leaves the weight uninitialized; give it the identity
        # scale so the module is well-defined before checkpoint loading.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the affine weight (if any) to ones."""
        if self.weight is not None:
            nn.init.ones_(self.weight)

    def forward(
        self,
        x: torch.Tensor,
        g: torch.Tensor,
        residual: torch.Tensor | None = None,
        prenorm: bool = False,
        residual_in_fp32: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Apply gated RMSNorm; returns a tuple when ``prenorm`` is True."""
        return rms_norm_gated(
            x,
            g,
            self.weight,
            self.bias,
            self.activation,
            residual=residual,
            eps=self.eps,
            prenorm=prenorm,
            residual_in_fp32=residual_in_fp32,
        )


# Off-diagonal (i_i > i_j) BC x BC sub-blocks of the intra-chunk matrices:
#   A   += (k_i * exp(g_i - gn)) @ (k_j * exp(gn - g_j))^T, then scaled by beta_i
#   Aqk += (q_i * exp(g_i - gn) * scale) @ (k_j * exp(gn - g_j))^T
# where gn is the gate at the first row of sub-block i_i.
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["BC"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter(
    q,
    k,
    g,
    beta,
    A,
    Aqk,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    NC: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    i_i, i_j = i_c // NC, i_c % NC
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if i_t * BT + i_i * BC >= T:
        return
    # Diagonal and upper sub-blocks are handled elsewhere / stay zero.
    if i_i <= i_j:
        return

    q += (bos * H + i_h) * K
    k += (bos * H + i_h) * K
    g += (bos * H + i_h) * K
    A += (bos * H + i_h) * BT
    Aqk += (bos * H + i_h) * BT

    p_b = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT + i_i * BC,), (BC,), (0,)
    )
    b_b = tl.load(p_b, boundary_check=(0,))

    b_A = tl.zeros([BC, BC], dtype=tl.float32)
    b_Aqk = tl.zeros([BC, BC], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(
            q, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
        )
        p_k = tl.make_block_ptr(
            k, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
        )
        p_g = tl.make_block_ptr(
            g, (T, K), (H * K, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0)
        )
        # Transposed views of k and g for sub-block i_j.
        p_kt = tl.make_block_ptr(
            k, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
        )
        p_gk = tl.make_block_ptr(
            g, (K, T), (1, H * K), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1)
        )

        o_k = i_k * BK + tl.arange(0, BK)
        m_k = o_k < K
        # [BK,]
        b_gn = tl.load(g + (i_t * BT + i_i * BC) * H * K + o_k, mask=m_k, other=0)
        # [BC, BK]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1)) * exp(b_g - b_gn[None, :])
        # [BK, BC]
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kt = tl.load(p_kt, boundary_check=(0, 1))
        # [BK, BC]
        b_ktg = b_kt * exp(b_gn[:, None] - b_gk)
        # [BC, BC]
        b_A += tl.dot(b_k, b_ktg)

        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_qg = b_q * exp(b_g - b_gn[None, :]) * scale
        b_Aqk += tl.dot(b_qg, b_ktg)

    b_A *= b_b[:, None]

    p_A = tl.make_block_ptr(
        A, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)
    )
    tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))
    p_Aqk = tl.make_block_ptr(
        Aqk, (T, BT), (H * BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0)
    )
    tl.store(p_Aqk, b_Aqk.to(Aqk.dtype.element_ty), boundary_check=(0, 1))


# Diagonal BC x BC sub-blocks of the intra-chunk matrices, one column j at a
# time: A keeps strictly-lower entries (o_i > j), Aqk keeps lower entries
# including the diagonal (o_i >= j), scaled by `scale`.
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
    key=["BK", "BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra(
    q,
    k,
    g,
    beta,
    A,
    Aqk,
    scale,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BC: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if i_t * BT + i_i * BC >= T:
        return

    o_i = tl.arange(0, BC)
    o_k = tl.arange(0, BK)
    m_k = o_k < K
    m_A = (i_t * BT + i_i * BC + o_i) < T
    o_A = (bos + i_t * BT + i_i * BC + o_i) * H * BT + i_h * BT + i_i * BC

    p_q = tl.make_block_ptr(
        q + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    p_k = tl.make_block_ptr(
        k + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    p_g = tl.make_block_ptr(
        g + (bos * H + i_h) * K,
        (T, K),
        (H * K, 1),
        (i_t * BT + i_i * BC, 0),
        (BC, BK),
        (1, 0),
    )
    b_q = tl.load(p_q, boundary_check=(0, 1))
    b_k = tl.load(p_k, boundary_check=(0, 1))
    b_g = tl.load(p_g, boundary_check=(0, 1))

    p_b = beta + (bos + i_t * BT + i_i * BC + o_i) * H + i_h
    b_k = b_k * tl.load(p_b, mask=m_A, other=0)[:, None]

    # Row pointers into k / g for column j of the sub-block; advanced by one
    # token (H * K elements) per iteration.
    p_kt = k + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k
    p_gk = g + (bos + i_t * BT + i_i * BC) * H * K + i_h * K + o_k

    for j in range(0, min(BC, T - i_t * BT - i_i * BC)):
        b_kt = tl.load(p_kt, mask=m_k, other=0).to(tl.float32)
        b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)
        b_ktg = b_kt[None, :] * exp(b_g - b_gk[None, :])
        b_A = tl.sum(b_k * b_ktg, 1)
        b_A = tl.where(o_i > j, b_A, 0.0)
        b_Aqk = tl.sum(b_q * b_ktg, 1)
        b_Aqk = tl.where(o_i >= j, b_Aqk * scale, 0.0)
        tl.store(A + o_A + j, b_A, mask=m_A)
        tl.store(Aqk + o_A + j, b_Aqk, mask=m_A)
        p_kt += H * K
        p_gk += H * K


def chunk_kda_scaled_dot_kkt_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    gk: torch.Tensor | None = None,
    beta: torch.Tensor | None = None,
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Compute the gated intra-chunk ``beta * K * K^T`` and ``scale * Q * K^T`` matrices.

    Args:
        q (torch.Tensor):
            The query tensor of shape `[B, T, H, K]`.
        k (torch.Tensor):
            The key tensor of shape `[B, T, H, K]`. `K` must be <= 256.
        gk (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor.
            Required by the kernels despite the `None` default.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`.
        scale (float):
            Scale applied to the `Q * K^T` scores.
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor.
            Default: None
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensors. Default: `torch.float32`

    Returns:
        `(A, Aqk)`, both of shape `[B, T, H, BT]` where `BT` is the chunk size.
        Within each chunk, `A` holds the strictly-lower-triangular part of
        `beta * K * K^T` and `Aqk` the lower-triangular part (diagonal included)
        of `scale * Q * K^T`; all other entries are zero.
    """
    B, T, H, K = k.shape
    assert K <= 256
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    BC = min(16, BT)
    NC = cdiv(BT, BC)
    BK = max(next_power_of_2(K), 16)
    # Zero-initialized: the kernels never write the upper-triangular sub-blocks.
    A = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
    Aqk = torch.zeros(B, T, H, BT, device=k.device, dtype=output_dtype)
    grid = (NT, NC * NC, B * H)
    chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_inter[grid](
        q=q,
        k=k,
        g=gk,
        beta=beta,
        A=A,
        Aqk=Aqk,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        NC=NC,
    )

    grid = (NT, NC, B * H)
    chunk_kda_scaled_dot_kkt_fwd_kernel_intra_sub_intra[grid](
        q=q,
        k=k,
        g=gk,
        beta=beta,
        A=A,
        Aqk=Aqk,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        BT=BT,
        BC=BC,
        BK=BK,
    )
    return A, Aqk


# Per chunk: u = A @ (beta * v) and w = A @ (beta * k * exp(gk)); optionally
# also stores qg = q * exp(gk) and kg = k * exp(gk_last - gk), where gk_last is
# the gate at the chunk's last valid row.
@triton.heuristics(
    {
        "STORE_QG": lambda args: args["qg"] is not None,
        "STORE_KG": lambda args: args["kg"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    q,
    k,
    qg,
    kg,
    v,
    beta,
    w,
    u,
    A,
    gk,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    STORE_QG: tl.constexpr,
    STORE_KG: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    b_b = tl.load(p_b, boundary_check=(0,))

    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    b_A = tl.load(p_A, boundary_check=(0, 1))

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_u = tl.make_block_ptr(
            u + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_b[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, input_precision=DOT_PRECISION)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

    for i_k in range(tl.cdiv(K, BK)):
        p_w = tl.make_block_ptr(
            w + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_k = tl.make_block_ptr(
            k + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_b[:, None]

        p_gk = tl.make_block_ptr(
            gk + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_gk = tl.load(p_gk, boundary_check=(0, 1))
        b_kb *= exp(b_gk)
        if STORE_QG:
            p_q = tl.make_block_ptr(
                q + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            p_qg = tl.make_block_ptr(
                qg + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            b_q = tl.load(p_q, boundary_check=(0, 1))
            b_qg = b_q * exp(b_gk)
            tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1))
        if STORE_KG:
            # Last valid row of this chunk (the final chunk may be partial).
            last_idx = min(i_t * BT + BT, T) - 1

            o_k = i_k * BK + tl.arange(0, BK)
            m_k = o_k < K
            b_gn = tl.load(
                gk + ((bos + last_idx) * H + i_h) * K + o_k, mask=m_k, other=0.0
            )
            b_kg = b_k * exp(b_gn - b_gk)

            p_kg = tl.make_block_ptr(
                kg + (bos * H + i_h) * K,
                (T, K),
                (H * K, 1),
                (i_t * BT, i_k * BK),
                (BT, BK),
                (1, 0),
            )
            tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1))

        # Use the same dot precision as the `u` computation above.
        b_w = tl.dot(b_A, b_kb.to(b_k.dtype), input_precision=DOT_PRECISION)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))


def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    A: torch.Tensor,
    q: torch.Tensor | None = None,
    gk: torch.Tensor | None = None,
    cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, None, torch.Tensor | None]:
    """Compute the chunked ``w``/``u`` tensors (and gated keys ``kg``).

    The chunk size is taken from ``A.shape[-1]``. ``gk`` is read
    unconditionally by the kernel, so it must be provided even though it
    defaults to None. ``qg`` is never requested, so the third return value
    is always None.

    Returns:
        ``(w, u, None, kg)``: ``w`` is shaped like ``k``, ``u`` like ``v``, and
        ``kg`` like ``k`` (or None when ``gk`` is None).
    """
    B, T, H, K, V = *k.shape, v.shape[-1]
    BT = A.shape[-1]
    BK = 64
    BV = 64
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    w = torch.empty_like(k)
    u = torch.empty_like(v)
    kg = torch.empty_like(k) if gk is not None else None
    recompute_w_u_fwd_kernel[(NT, B * H)](
        q=q,
        k=k,
        qg=None,
        kg=kg,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        gk=gk,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        DOT_PRECISION="ieee",
    )
    return w, u, None, kg


# Chunk output: o = (q * scale * exp(g)) @ h[chunk] + tril(A) @ v.
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64]
        for BV in [64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_gla_fwd_kernel_o(
    q,
    v,
    g,
    h,
    o,
    A,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Global chunk index (for `h`) before remapping i_t to the local one.
        i_tg = i_t
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # Causal (lower-triangular, diagonal included) mask within the chunk.
    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(
            q + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_g = tl.make_block_ptr(
            g + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_h = tl.make_block_ptr(
            h + (i_tg * H + i_h) * K * V,
            (K, V),
            (V, 1),
            (i_k * BK, i_v * BV),
            (BK, BV),
            (1, 0),
        )

        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        b_q = (b_q * scale).to(b_q.dtype)
        # [BT, BK]
        b_g = tl.load(p_g, boundary_check=(0, 1))
        # [BT, BK]
        b_qg = (b_q * exp(b_g)).to(b_q.dtype)
        # [BK, BV]
        b_h = tl.load(p_h, boundary_check=(0, 1))
        # NOTE: `i_k >= 0` is always true. The guard is inherited verbatim from
        # upstream, whose authors noted it works without knowing why; kept so
        # Triton code generation is not perturbed.
        # [BT, BV]
        if i_k >= 0:
            b_o += tl.dot(b_qg, b_h.to(b_qg.dtype))
    p_v = tl.make_block_ptr(
        v + (bos * H + i_h) * V,
        (T, V),
        (H * V, 1),
        (i_t * BT, i_v * BV),
        (BT, BV),
        (1, 0),
    )
    p_o = tl.make_block_ptr(
        o + (bos * H + i_h) * V,
        (T, V),
        (H * V, 1),
        (i_t * BT, i_v * BV),
        (BT, BV),
        (1, 0),
    )
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    # [BT, BV]
    b_v = tl.load(p_v, boundary_check=(0, 1))
    # [BT, BT]
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_A = tl.where(m_s, b_A, 0.0).to(b_v.dtype)
    # `input_precision="ieee"` is the non-deprecated spelling of allow_tf32=False.
    b_o += tl.dot(b_A, b_v, input_precision="ieee")
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


def chunk_gla_fwd_o_gk(
    q: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    A: torch.Tensor,
    h: torch.Tensor,
    o: torch.Tensor,
    scale: float,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
):
    """Write the chunked GLA-style output into ``o`` and return it.

    ``q``/``g`` are ``[B, T, H, K]``; ``v``/``o`` have last dim ``V``; ``A``
    holds per-chunk ``[BT]`` score rows; ``h`` holds one ``[K, V]`` state per
    (chunk, head).
    """
    B, T, H, K, V = *q.shape, v.shape[-1]
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, chunk_size)
        if cu_seqlens is not None
        else None
    )
    NT = cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    def grid(meta):
        return (cdiv(V, meta["BV"]), NT, B * H)

    chunk_gla_fwd_kernel_o[grid](
        q=q,
        v=v,
        g=g,
        h=h,
        o=o,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        scale=scale,
        T=T,
        H=H,
        K=K,
        V=V,
        BT=BT,
    )
    return o


def chunk_kda_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor | None,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    """Chunked KDA forward pass (chunk size 64).

    NOTE: the output is written into ``v``'s storage (``o=v`` below) to
    avoid an extra allocation; ``v`` is no longer read by then (the output
    kernel reads ``v_new``).

    Returns:
        ``(o, final_state)``.
    """
    chunk_size = 64
    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
    # the intra Aqk is kept in fp32
    # the computation has very marginal effect on the entire throughput
    A, Aqk = chunk_kda_scaled_dot_kkt_fwd(
        q=q,
        k=k,
        gk=g,
        beta=beta,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
        output_dtype=torch.float32,
    )
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u, _, kg = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        gk=g,
        cu_seqlens=cu_seqlens,
    )
    del A
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=kg,
        w=w,
        u=u,
        gk=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    del w, u, kg
    o = chunk_gla_fwd_o_gk(
        q=q,
        v=v_new,
        g=g,
        A=Aqk,
        h=h,
        o=v,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    del Aqk, v_new, h
    return o, final_state


def chunk_kda(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: torch.LongTensor | None = None,
    **kwargs,
):
    """Chunked KDA forward pass with input normalization and contiguity fixes.

    Args:
        q, k: Query/key tensors of shape ``[B, T, H, K]``.
        v: Value tensor; its storage is reused for the output when ``v`` is
            already contiguous.
        g: Gate tensor (cumulatively summed per chunk internally).
        beta: Beta tensor of shape ``[B, T, H]``.
        scale: Query scale; defaults to ``K ** -0.5``.
        initial_state: Optional initial recurrent state.
        output_final_state: Whether to return the final state.
        use_qk_l2norm_in_kernel: L2-normalize q and k before the pass.
        cu_seqlens: Optional cumulative sequence lengths.
        **kwargs: Accepted and ignored.

    Returns:
        ``(o, final_state)``.
    """
    if scale is None:
        scale = k.shape[-1] ** -0.5

    if use_qk_l2norm_in_kernel:
        q = l2norm_fwd(q.contiguous())
        k = l2norm_fwd(k.contiguous())

    # The kernels assume dense (H * K, 1) strides, so make q/k contiguous even
    # when no normalization is applied (a no-op for contiguous inputs).
    o, final_state = chunk_kda_fwd(
        q=q.contiguous(),
        k=k.contiguous(),
        v=v.contiguous(),
        g=g.contiguous(),
        beta=beta.contiguous(),
        scale=scale,
        # `initial_state` defaults to None; don't call methods on it then.
        initial_state=(
            initial_state.contiguous() if initial_state is not None else None
        ),
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    return o, final_state


# KDA gate: y = -exp(A[h]) * softplus(g + g_bias, beta, threshold), computed in
# fp32 over a row-major [T, H * D] input, one (row tile, head) per program.
@triton.autotune(
    configs=[
        triton.Config({"BT": bt}, num_warps=nw, num_stages=ns)
        for bt in BT_LIST_AUTOTUNE
        for nw in NUM_WARPS_AUTOTUNE
        for ns in [2, 3]
    ],
    key=["H", "D"],
)
@triton.jit
def kda_gate_fwd_kernel(
    g,
    A,
    y,
    g_bias,
    beta: tl.constexpr,
    threshold: tl.constexpr,
    T,
    H,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    i_t, i_h = tl.program_id(0), tl.program_id(1)
    n_t = i_t * BT

    b_a = tl.load(A + i_h).to(tl.float32)
    b_a = -tl.exp(b_a)

    stride_row = H * D
    stride_col = 1

    g_ptr = tl.make_block_ptr(
        base=g + i_h * D,
        shape=(T, D),
        strides=(stride_row, stride_col),
        offsets=(n_t, 0),
        block_shape=(BT, BD),
        order=(1, 0),
    )

    y_ptr = tl.make_block_ptr(
        base=y + i_h * D,
        shape=(T, D),
        strides=(stride_row, stride_col),
        offsets=(n_t, 0),
        block_shape=(BT, BD),
        order=(1, 0),
    )

    b_g = tl.load(g_ptr, boundary_check=(0, 1)).to(tl.float32)

    if HAS_BIAS:
        n_d = tl.arange(0, BD)
        bias_mask = n_d < D
        b_bias = tl.load(g_bias + i_h * D + n_d, mask=bias_mask, other=0.0).to(
            tl.float32
        )
        b_g = b_g + b_bias[None, :]

    # softplus(x, beta) = (1/beta) * log(1 + exp(beta * x)); for
    # beta * x > threshold switch to the linear approximation x.
    g_scaled = b_g * beta
    use_linear = g_scaled > threshold
    sp = tl.where(use_linear, b_g, (1.0 / beta) * log(1.0 + tl.exp(g_scaled)))
    b_y = b_a * sp

    tl.store(y_ptr, b_y.to(y.dtype.element_ty), boundary_check=(0, 1))


def fused_kda_gate(
    g: torch.Tensor,
    A: torch.Tensor,
    head_k_dim: int,
    g_bias: torch.Tensor | None = None,
    beta: float = 1.0,
    threshold: float = 20.0,
) -> torch.Tensor:
    """
    Forward pass for KDA gate:
      input g: [..., H*D]
      param A: [H] or [1, 1, H, 1]
      beta: softplus beta parameter
      threshold: softplus threshold parameter
      return  : [..., H, D] in float32
    """
    orig_shape = g.shape[:-1]

    g = g.view(-1, g.shape[-1])
    T = g.shape[0]
    HD = g.shape[1]
    H = A.numel()
    assert H * head_k_dim == HD

    y = torch.empty_like(g, dtype=torch.float32)

    def grid(meta):
        return (cdiv(T, meta["BT"]), H)

    kda_gate_fwd_kernel[grid](
        g,
        A,
        y,
        g_bias,
        beta,
        threshold,
        T,
        H,
        head_k_dim,
        BD=next_power_of_2(head_k_dim),
        HAS_BIAS=g_bias is not None,
    )

    y = y.view(*orig_shape, H, head_k_dim)
    return y