# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Songlin Yang, Yu Zhang # # This file contains code copied from the flash-linear-attention project. # The original source code was licensed under the MIT license and included # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 import warnings import torch from einops import rearrange from .chunk_delta_h import chunk_gated_delta_rule_fwd_h from .chunk_o import chunk_fwd_o from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd from .cumsum import chunk_local_cumsum from .l2norm import l2norm_fwd from .solve_tril import solve_tril from .utils import SUPPRESS_LEVEL, input_guard from .wy_fast import recompute_w_u_fwd def chunk_gated_delta_rule_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, scale: float, initial_state: torch.Tensor, output_final_state: bool, cu_seqlens: torch.LongTensor | None = None, ): g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) # obtain WY representation. u is actually the new v. A = chunk_scaled_dot_kkt_fwd( k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32 ) A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) w, u = recompute_w_u_fwd( k=k, v=v, beta=beta, A=A, g_cumsum=g, cu_seqlens=cu_seqlens, ) h, v_new, final_state = chunk_gated_delta_rule_fwd_h( k=k, w=w, u=u, g=g, initial_state=initial_state, output_final_state=output_final_state, cu_seqlens=cu_seqlens, ) o = chunk_fwd_o( q=q, k=k, v=v_new, h=h, g=g, scale=scale, cu_seqlens=cu_seqlens, ) if SUPPRESS_LEVEL < 3: return g, o, A, final_state, None, None, None elif SUPPRESS_LEVEL >= 3: return g, o, A, final_state, w, h, v_new class ChunkGatedDeltaRuleFunction(torch.autograd.Function): @staticmethod @input_guard @torch.amp.custom_fwd(device_type="cuda") def forward( ctx, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, scale: float, initial_state: torch.Tensor, output_final_state: bool, cu_seqlens: torch.LongTensor | None = None, use_qk_l2norm_in_kernel: bool = False, ): if use_qk_l2norm_in_kernel: q = l2norm_fwd(q) k = l2norm_fwd(k) g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( q=q, k=k, v=v, g=g, beta=beta, scale=scale, initial_state=initial_state, output_final_state=output_final_state, cu_seqlens=cu_seqlens, ) ctx.scale = scale ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel return o.to(q.dtype), final_state @torch.compiler.disable def chunk_gated_delta_rule( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, scale: float = None, initial_state: torch.Tensor = None, output_final_state: bool = False, cu_seqlens: torch.LongTensor | None = None, head_first: bool = False, use_qk_l2norm_in_kernel: bool = False, ): r""" Args: q (torch.Tensor): queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. k (torch.Tensor): keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. v (torch.Tensor): values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. g (torch.Tensor): (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. beta (torch.Tensor): betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. scale (Optional[int]): Scale factor for the RetNet attention scores. If not provided, it will default to `1 / sqrt(K)`. Default: `None`. initial_state (Optional[torch.Tensor]): Initial state of shape `[N, H, K, V]` for `N` input sequences. For equal-length input sequences, `N` equals the batch size `B`. Default: `None`. output_final_state (Optional[bool]): Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. cu_seqlens (torch.LongTensor): Cumulative sequence lengths of shape `[N+1]` used for variable-length training, consistent with the FlashAttention API. head_first (Optional[bool]): Whether the inputs are in the head-first format, which is not supported for variable-length inputs. Default: `False`. Returns: o (torch.Tensor): Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. final_state (torch.Tensor): Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. Examples:: >>> import torch >>> import torch.nn.functional as F >>> from einops import rearrange >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule # inputs with equal lengths >>> B, T, H, K, V = 4, 2048, 4, 512, 512 >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') >>> o, ht = chunk_gated_delta_rule( q, k, v, g, beta, initial_state=h0, output_final_state=True ) # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) >>> o_var, ht_var = chunk_gated_delta_rule( q, k, v, g, beta, initial_state=h0, output_final_state=True, cu_seqlens=cu_seqlens ) """ assert q.dtype == k.dtype == v.dtype assert q.dtype != torch.float32, ( "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." ) assert len(beta.shape) == 3, ( "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." ) if head_first: raise DeprecationWarning( "head_first is deprecated and will be removed in a future version. " "Please use head_first=False for now instead.", stacklevel=2, ) q, k, v, beta, g = map( lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g) ) if not head_first and q.shape[1] < q.shape[2]: warnings.warn( f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " "This may indicate the inputs were passed in head-first format [B, H, T, ...] " "when head_first=False was specified. " "Please verify your input tensor format matches the expected shape [B, T, H, ...].", stacklevel=2, ) if cu_seqlens is not None: if q.shape[0] != 1: raise ValueError( f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." f"Please flatten variable-length inputs before processing." ) if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: raise ValueError( f"The number of initial states is expected to be equal to the number of input sequences, " f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." ) if scale is None: scale = k.shape[-1] ** -0.5 o, final_state = ChunkGatedDeltaRuleFunction.apply( q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel, ) if head_first: o = rearrange(o, "b t h ... -> b h t ...") return o, final_state