# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
from typing import Optional

import torch
import xtorch_ops


class FusedRecurrentFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                g: torch.Tensor,
                beta: torch.Tensor,
                scale: float,
                initial_state: torch.Tensor,
                inplace_final_state: bool = True,
                cu_seqlens: Optional[torch.LongTensor] = None,
                ssm_state_indices: Optional[torch.Tensor] = None,
                num_accepted_tokens: Optional[torch.Tensor] = None,
                use_qk_l2norm_in_kernel: bool = False):
        o, final_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwdv2(
            q.contiguous(),
            k.contiguous(),
            v.contiguous(),
            g.contiguous(),
            beta.contiguous(),
            scale,
            initial_state,
            inplace_final_state=inplace_final_state,
            cu_seqlens=cu_seqlens,
            h0_indices=ssm_state_indices,
            num_accepted_tokens=num_accepted_tokens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
            is_h0_transposed=True)

        return o, final_state


def fused_recurrent_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    inplace_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    ssm_state_indices: Optional[torch.Tensor] = None,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, HV, V]`.
            GVA is applied if `HV > H`.
        g (torch.Tensor):
            g (decays) of shape `[B, T, HV]`.
        beta (torch.Tensor):
            betas of shape `[B, T, HV]`.
        scale (Optional[float]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, HV, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        inplace_final_state (bool):
            Whether to store the final state in place to save memory.
            Default: `True`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        ssm_state_indices (Optional[torch.Tensor]):
            Indices to map the input sequences to the initial/final states.
        num_accepted_tokens (Optional[torch.Tensor]):
            Number of accepted tokens for each sequence during decoding.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HV, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, HV, K, V]`.
    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, HV, V, device='cuda')
        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
        >>> o, ht = fused_recurrent_gated_delta_rule(
                q, k, v, g, beta,
                initial_state=h0,
            )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
        # for a batch with 4 sequences, a `cu_seqlens` tensor with 5 start/end positions is expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
                q, k, v, g, beta,
                initial_state=h0,
                cu_seqlens=cu_seqlens,
            )
    """
    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} "
            f"when using `cu_seqlens`. "
            "Please flatten variable-length inputs before processing.")
    if scale is None:
        scale = k.shape[-1]**-0.5
    else:
        assert scale > 0, "scale must be positive"
    if beta is None:
        beta = torch.ones_like(q[..., 0])
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        inplace_final_state,
        cu_seqlens,
        ssm_state_indices,
        num_accepted_tokens,
        use_qk_l2norm_in_kernel,
    )
    return o, final_state
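

# The block below is a minimal, hedged smoke test rather than part of the
# original module: it exercises the variable-length (`cu_seqlens`) path
# described in the docstring above with small shapes. The device string and
# tensor sizes are illustrative assumptions; it only runs where `xtorch_ops`
# and the target accelerator are available.
if __name__ == "__main__":
    import torch.nn.functional as F

    B, T, H, HV, K, V = 2, 8, 2, 4, 64, 64
    device = "cuda"  # assumed device; adjust for the platform xtorch_ops targets

    # Flattened [1, B*T, ...] layout, as required when passing `cu_seqlens`.
    q = torch.randn(1, B * T, H, K, device=device)
    k = F.normalize(torch.randn(1, B * T, H, K, device=device), p=2, dim=-1)
    v = torch.randn(1, B * T, HV, V, device=device)
    g = F.logsigmoid(torch.rand(1, B * T, HV, device=device))
    beta = torch.rand(1, B * T, HV, device=device).sigmoid()
    h0 = torch.randn(B, HV, K, V, device=device)  # initial state, shape per the docstring
    # Cumulative sequence boundaries for B equal-length sequences of length T.
    cu_seqlens = torch.arange(B + 1, device=device, dtype=torch.long) * T

    o, ht = fused_recurrent_gated_delta_rule(
        q, k, v, g, beta,
        initial_state=h0,
        cu_seqlens=cu_seqlens,
    )
    print(o.shape, ht.shape)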