# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# ruff: noqa: E501
"""Chunked gated delta rule (forward only) built on ``torch.ops.xspeedgate_ops``.

The main entry point is :func:`chunk_gated_delta_rule`. The pure-PyTorch
helpers :func:`torch_solve_tril` and :func:`recompute_w_u_fwd_torch` are
reference implementations kept for debugging/verification and are not called
on the main path.
"""
import warnings
from typing import Optional

import cocopod  # noqa
import torch
import torch.nn.functional as F
from einops import rearrange  # noqa: F401

from .index import prepare_chunk_indices, prepare_chunk_offsets
from .l2norm import l2norm_fwd
from .utils import SUPPRESS_LEVEL, input_guard


def torch_solve_tril(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    output_dtype: torch.dtype = torch.float,
):
    """Pure-PyTorch reference for the chunk-wise triangular solve.

    Pads the time axis of ``A`` to a multiple of 64, negates it, and within
    each 64-row chunk applies a row-by-row forward-substitution recurrence.
    It then adds the identity and crops the padding back off.

    NOTE(review): this mirrors the upstream fla reference that computes
    ``(I + A)^-1`` for a strictly lower-triangular ``A``. The strict-lower
    masking is NOT applied here, so the caller must pass an already-masked
    ``A`` — confirm.

    Args:
        A: Time on dim 1 and chunk-local columns last. Presumably
            ``[B, T, H, 64]`` — TODO confirm against callers.
        cu_seqlens: Accepted for interface parity but unused. For
            variable-length inputs, call once per sequence.
        output_dtype: Accepted for interface parity but unused. The result
            keeps ``A``'s dtype.

    Returns:
        Tensor with the same layout and time length as ``A``.
    """
    chunk_size = 64
    # [B, T, H, C] -> [B, H, T, C]. Negation also produces a fresh tensor, so
    # the in-place updates below never touch the caller's ``A``.
    A = -A.transpose(1, 2)
    sequence_length = A.shape[-2]
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    A = F.pad(A, (0, 0, 0, pad_size))
    # Split time into chunks: [B, H, NT, chunk_size, C].
    A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
    # Each row depends on the already-updated rows above it, so this must
    # stay sequential.
    for i in range(1, chunk_size):
        row = A[..., i, :i].clone()
        sub = A[..., :i, :i].clone()
        A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
    A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)
    # Undo the chunking, drop the padding and restore [B, T, H, C].
    return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[
        :, :, :sequence_length, :
    ].transpose(1, 2)


def recompute_w_u_fwd_torch(
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, H, V]
    beta: torch.Tensor,  # [B, T, H]
    g: torch.Tensor,  # [B, T, H], log-space gates (presumably chunk-local cumsum)
    A: torch.Tensor,  # [B, T, H, 64], output of the triangular solve
):
    """Pure-PyTorch reference for ``recompute_w_u_fwd``.

    This is the simplest version: it assumes a single batch of equal-length
    sequences (no ``cu_seqlens``). Key heads are repeated with
    ``repeat_interleave`` to match the number of value heads, so
    ``H`` should be a multiple of ``Hk``.

    Within each 64-step chunk it computes:
        ``u = A @ (v * beta)``
        ``w = A @ (k * beta * exp(g))``

    Returns:
        ``(w, u)``, shaped ``[B, T, H, K]`` and ``[B, T, H, V]``, in float32.
    """
    chunk_size = 64
    num_v_heads, num_k_heads = v.shape[2], k.shape[2]
    k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
    # Head-major [B, H, T, ...] float32 copies of every input.
    k, v, beta, g, A = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (k, v, beta, g, A)
    ]
    _, _, sequence_length, _ = k.shape
    # Pad the time axis up to a multiple of chunk_size.
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    k = F.pad(k, (0, 0, 0, pad_size))
    v = F.pad(v, (0, 0, 0, pad_size))
    beta = F.pad(beta, (0, pad_size))
    g = F.pad(g, (0, pad_size))
    A = F.pad(A, (0, 0, 0, pad_size))

    def _chunk(x: torch.Tensor) -> torch.Tensor:
        # [B, H, T, D] -> [B, H, NT, chunk_size, D]
        return x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])

    def _unchunk(x: torch.Tensor) -> torch.Tensor:
        # [B, H, NT, chunk_size, D] -> [B, T, H, D], padding removed.
        return (
            x.reshape(x.shape[0], x.shape[1], -1, x.shape[-1])[:, :, :sequence_length, :]
            .transpose(1, 2)
            .contiguous()
        )

    A = _chunk(A)
    v_beta = _chunk(v * beta.unsqueeze(-1))
    k_beta = _chunk(k * beta.unsqueeze(-1))
    g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
    u = A @ v_beta
    w = A @ (k_beta * g.exp().unsqueeze(-1))
    return _unchunk(w), _unchunk(u)


def split_by_value(tensor, chunk_size=64):
    """Return the sorted union of ``tensor``'s values and intra-segment split points.

    For each consecutive pair ``(start, end)`` of values, every point
    ``start + n * chunk_size`` with ``n >= 1`` and strictly below ``end`` is
    added. For non-decreasing input, no resulting segment is therefore longer
    than ``chunk_size``. Duplicate values are removed.

    Example: ``[0, 150, 200]`` with ``chunk_size=64`` gives
    ``[0, 64, 128, 150, 200]``.

    Args:
        tensor: 1-D integer tensor of boundaries (e.g. ``cu_seqlens``).
        chunk_size: Maximum segment length. Must be positive.

    Returns:
        Sorted 1-D tensor with ``tensor``'s dtype and device.

    Raises:
        ValueError: If ``chunk_size`` is not positive. The previous loop
            never terminated in that case.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    indices = tensor.tolist()
    result = set(indices)  # a set so inserted boundaries never duplicate
    for start, end in zip(indices, indices[1:]):
        # Aligned boundaries start + n*chunk_size inside the open interval (start, end).
        result.update(range(start + chunk_size, end, chunk_size))
    return torch.tensor(sorted(result), dtype=tensor.dtype, device=tensor.device)


def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Run the chunked gated-delta-rule forward pass on ``xspeedgate_ops`` kernels.

    Stages, one custom op each:
      1. ``chunk_local_cumsum``: chunk-local cumulative sum of ``g``
         (``reverse=False``).
      2. ``chunk_scaled_dot_kkt_fwd``: per-chunk matrix ``A`` from k, beta and g.
      3. ``solve_tril_ns``: triangular solve applied to ``A`` in place.
      4. ``recompute_w_u_fwd``: ``w`` and ``u`` from k, v, beta, A and g.
      5. ``chunk_gated_delta_rule_fwd_h``: chunk states ``h``, ``v_new`` and
         the final state.
      6. ``chunk_fwd_o``: the output ``o``.

    Returns:
        ``(g, o, A, final_state, w, h, v_new)``, where ``g`` is the cumulated
        gate tensor. ``w``, ``h`` and ``v_new`` are ``None`` unless
        ``SUPPRESS_LEVEL >= 3``.
    """
    chunk_size = 64
    if cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        chunk_offsets = prepare_chunk_offsets(cu_seqlens, chunk_size)
    else:
        chunk_indices = None
        chunk_offsets = None

    g = torch.ops.xspeedgate_ops.chunk_local_cumsum(
        g,
        chunk_size=chunk_size,
        reverse=False,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        head_first=False,
    )
    A = torch.ops.xspeedgate_ops.chunk_scaled_dot_kkt_fwd(
        k, beta, g, cu_seqlens, chunk_indices, chunk_size
    )
    # The result is not reassigned: this op is relied on to update A in place.
    torch.ops.xspeedgate_ops.solve_tril_ns(A, cu_seqlens, chunk_indices, chunk_size)
    w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        chunk_size=chunk_size,
    )
    # The kernel takes int32 chunk offsets. They only exist for varlen input,
    # so pass None otherwise instead of calling .to() on None.
    chunk_offsets_i32 = (
        chunk_offsets.to(torch.int32) if chunk_offsets is not None else None
    )
    h, v_new, final_state = torch.ops.xspeedgate_ops.chunk_gated_delta_rule_fwd_h(
        k,
        u,
        w,
        g,
        initial_state,
        cu_seqlens,
        chunk_indices,
        chunk_offsets_i32,
        chunk_size,
        output_final_state,
        True,  # NOTE(review): presumably save_new_value (as in upstream fla) — confirm against the op schema
    )
    o = torch.ops.xspeedgate_ops.chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        chunk_size=chunk_size,
    )
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    return g, o, A, final_state, w, h, v_new


class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
    """Autograd wrapper around :func:`chunk_gated_delta_rule_fwd`.

    Only ``forward`` is implemented. No ``backward`` is defined in this class,
    so gradients cannot be propagated through it.
    """

    @staticmethod
    @input_guard
    @torch.amp.custom_fwd(device_type="cuda")
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: Optional[torch.LongTensor] = None,
        use_qk_l2norm_in_kernel: bool = False,
    ):
        """Optionally L2-normalize q/k, run the forward pass, return ``(o, final_state)``.

        ``o`` is cast to the dtype of ``q``; this is the normalized ``q`` when
        ``use_qk_l2norm_in_kernel`` is set.
        """
        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)
        _, o, _, final_state, _, _, _ = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )
        ctx.scale = scale
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return o.to(q.dtype), final_state


@torch.compiler.disable
def chunk_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
):
    r"""Chunked gated delta rule attention (forward only).

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]`.
        g (torch.Tensor):
            (forget) gating tensor (in log space!) of shape `[B, T, H]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]`.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length inputs,
            consistent with the FlashAttention API. Requires batch size 1.
        head_first (Optional[bool]):
            Deprecated. Passing `True` raises `DeprecationWarning`; inputs must use the
            `[B, T, H, ...]` layout. Default: `False`.
        use_qk_l2norm_in_kernel (bool):
            If `True`, `q` and `k` are L2-normalized (via `l2norm_fwd`) before the kernels run.
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`, in the dtype of `q`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]`, as returned by the fused kernel.
            NOTE(review): presumably `None`/unused when `output_final_state=False`
            (upstream fla contract) — not verified for the xspeedgate op.

    Raises:
        AssertionError: if q/k/v dtypes differ, are float32, or `beta` is not 3-D.
        DeprecationWarning: if `head_first=True`.
        ValueError: if `cu_seqlens` is given with batch size != 1, or the number of
            initial states does not match the number of sequences.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
    """
    assert q.dtype == k.dtype == v.dtype
    assert (
        q.dtype != torch.float32
    ), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
    assert (
        len(beta.shape) == 3
    ), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."

    if head_first:
        # Built-in exceptions take no keyword arguments: the former
        # `stacklevel=2` turned this into a TypeError.
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead."
        )
    if q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
            stacklevel=2,
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                "Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    o, final_state = ChunkGatedDeltaRuleFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
        use_qk_l2norm_in_kernel,
    )
    return o, final_state