Files
2026-02-28 11:15:50 +08:00

493 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings
from typing import Optional
import cocopod # noqa
import torch
import torch.nn.functional as F
from einops import rearrange
from .index import prepare_chunk_indices, prepare_chunk_offsets
from .l2norm import l2norm_fwd
from .utils import SUPPRESS_LEVEL, input_guard
def torch_solve_tril(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
    output_dtype: torch.dtype = torch.float,
):
    """Pure-PyTorch reference for the chunk-wise triangular solve.

    ``A`` is laid out as ``[B, T, H, BT]``. Row ``t`` holds the ``BT`` entries
    of that position's row inside its chunk of ``BT`` consecutive positions.
    ``BT`` is taken from ``A.shape[-1]``.

    For every chunk this returns ``(I + L)^{-1}``, where ``L`` is the strictly
    lower-triangular part of the chunk's ``BT x BT`` block. The diagonal and
    upper triangle of the input are ignored. The inverse is built row by row
    with forward substitution.

    Args:
        A: chunked matrix of shape ``[B, T, H, BT]``.
        cu_seqlens: optional cumulative sequence lengths ``[N + 1]`` of packed
            sequences. When given, chunking restarts at every sequence start and
            the solve is done per sequence. Rows at or past ``cu_seqlens[-1]``
            are left as zeros.
        output_dtype: dtype of the returned tensor. The solve itself runs in
            float32, or in float64 if ``A`` is float64.

    Returns:
        Tensor of shape ``[B, T, H, BT]`` with dtype ``output_dtype``.
    """
    if cu_seqlens is not None:
        out = torch.zeros(A.shape, dtype=output_dtype, device=A.device)
        bounds = cu_seqlens.tolist()
        for bos, eos in zip(bounds[:-1], bounds[1:]):
            # each packed sequence is chunked from its own start position
            out[:, bos:eos] = torch_solve_tril(A[:, bos:eos], output_dtype=output_dtype)
        return out

    chunk_size = A.shape[-1]
    compute_dtype = torch.float64 if A.dtype == torch.float64 else torch.float32
    # [B, T, H, BT] -> [B, H, T, BT]; negated for the forward substitution below
    x = -A.transpose(1, 2).to(compute_dtype)
    batch, heads, seq_len = x.shape[:3]
    pad_size = (-seq_len) % chunk_size
    if pad_size:
        x = F.pad(x, (0, 0, 0, pad_size))
    # explicit chunk count so that an empty sequence reshapes cleanly
    num_chunks = (seq_len + pad_size) // chunk_size
    x = x.reshape(batch, heads, num_chunks, chunk_size, chunk_size)
    # keep only the strictly lower triangle of every chunk block
    strictly_lower = torch.ones(
        chunk_size, chunk_size, dtype=torch.bool, device=x.device
    ).tril(-1)
    x = x.masked_fill(~strictly_lower, 0)
    for i in range(1, chunk_size):
        row = x[..., i, :i].clone()
        # rows < i are already final, so they can be read directly
        x[..., i, :i] = row + (row.unsqueeze(-1) * x[..., :i, :i]).sum(-2)
    x = x + torch.eye(chunk_size, dtype=x.dtype, device=x.device)
    x = x.reshape(batch, heads, num_chunks * chunk_size, chunk_size)[:, :, :seq_len]
    return x.transpose(1, 2).to(output_dtype)
def recompute_w_u_fwd_torch(
    k: torch.Tensor,  # [B, T, Hk, K]
    v: torch.Tensor,  # [B, T, H, V]
    beta: torch.Tensor,  # [B, T, H]
    g: torch.Tensor,  # [B, T, H]
    A: torch.Tensor,  # [B, T, H, BT]
):
    """Pure-PyTorch reference for ``w`` and ``u`` of the chunked gated delta rule.

    Time is split into chunks of ``BT = A.shape[-1]`` positions starting at
    position 0. Within every chunk:

        u = A @ (v * beta)
        w = A @ (k * beta * exp(g))

    Packed variable-length input is therefore not supported; the caller must
    slice it per sequence. Key heads are repeated with ``repeat_interleave`` to
    match the number of value heads.

    Returns:
        ``(w, u)`` as float32 tensors of shape ``[B, T, H, K]`` and ``[B, T, H, V]``.

    Raises:
        ValueError: if the number of value heads is not a multiple of the
            number of key heads.
    """
    chunk_size = A.shape[-1]
    num_v_heads, num_k_heads = v.shape[2], k.shape[2]
    if num_k_heads == 0 or num_v_heads % num_k_heads != 0:
        raise ValueError(
            f"number of value heads ({num_v_heads}) must be a multiple of "
            f"the number of key heads ({num_k_heads})"
        )
    k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
    # [B, T, H, ...] -> [B, H, T, ...]; all math below is done in float32
    k, v, beta, g, A = [
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (k, v, beta, g, A)
    ]
    seq_len = k.shape[2]
    pad_size = (-seq_len) % chunk_size
    num_chunks = (seq_len + pad_size) // chunk_size

    def _chunk(x: torch.Tensor) -> torch.Tensor:
        """Zero-pad the time axis (dim 2) to whole chunks and split it."""
        trailing = x.shape[3:]
        x = F.pad(x, (0, 0) * len(trailing) + (0, pad_size))
        return x.reshape(x.shape[0], x.shape[1], num_chunks, chunk_size, *trailing)

    def _unchunk(x: torch.Tensor) -> torch.Tensor:
        """Merge chunks, drop the padding and restore the [B, T, H, D] layout."""
        x = x.reshape(x.shape[0], x.shape[1], num_chunks * chunk_size, x.shape[-1])
        return x[:, :, :seq_len, :].transpose(1, 2).contiguous()

    beta_c = _chunk(beta).unsqueeze(-1)
    A_c = _chunk(A)
    u = A_c @ (_chunk(v) * beta_c)
    w = A_c @ ((_chunk(k) * beta_c) * _chunk(g).exp().unsqueeze(-1))
    return _unchunk(w), _unchunk(u)
def split_by_value(tensor, chunk_size=64):
    """Insert evenly spaced split points between consecutive boundary values.

    For every consecutive pair ``(start, end)`` of the 1-D ``tensor``, the values
    ``start + n * chunk_size`` (``n >= 1``) that are strictly less than ``end``
    are added. The original values are always kept.

    Args:
        tensor: 1-D tensor of boundaries (e.g. cumulative sequence lengths).
        chunk_size: spacing of the inserted points; must be positive.

    Returns:
        Sorted tensor of the unique values, with the same dtype and device as
        ``tensor``.

    Raises:
        ValueError: if ``chunk_size`` is not positive, which would otherwise
            make the boundary loop run forever.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    bounds = tensor.tolist()
    # a set keeps the original boundaries and the inserted points free of duplicates
    points = set(bounds)
    for start, end in zip(bounds, bounds[1:]):
        # first candidate is start + chunk_size; keep stepping while inside (start, end)
        boundary = start + chunk_size
        while boundary < end:
            points.add(boundary)
            boundary += chunk_size
    return torch.tensor(sorted(points), dtype=tensor.dtype, device=tensor.device)
def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: Optional[torch.LongTensor] = None,
):
    """Forward pass of the chunked gated delta rule on ``xspeedgate_ops`` kernels.

    Pipeline, with a chunk size of 64:
      1. ``g`` is replaced by its chunk-local cumulative sum.
      2. ``A`` is produced by ``chunk_scaled_dot_kkt_fwd(k, beta, g)`` and then
         passed to ``solve_tril_ns``. That op's return value is unused, so ``A``
         is presumably updated in place (TODO confirm).
      3. ``w, u`` come from ``recompute_w_u_fwd``.
      4. ``h, v_new, final_state`` come from ``chunk_gated_delta_rule_fwd_h``.
      5. ``o`` comes from ``chunk_fwd_o``.

    Args:
        q, k, v, g, beta: inputs in the layout expected by the xspeedgate kernels
            (presumably ``[B, T, H, ...]`` — TODO confirm against the kernels).
        scale: scale factor forwarded to ``chunk_fwd_o``.
        initial_state: initial recurrent state forwarded to the ``fwd_h`` kernel.
        output_final_state: forwarded to the ``fwd_h`` kernel.
        cu_seqlens: cumulative sequence lengths for packed variable-length
            input. When given, chunk indices and offsets are derived from it.

    Returns:
        ``(g, o, A, final_state, w, h, v_new)``. ``w``, ``h`` and ``v_new`` are
        ``None`` unless ``SUPPRESS_LEVEL >= 3``.
    """
    chunk_size = 64
    if cu_seqlens is not None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
        # the fwd_h kernel is handed int32 chunk offsets
        chunk_offsets = prepare_chunk_offsets(cu_seqlens, chunk_size).to(torch.int32)
    else:
        # Previously `.to(torch.int32)` was called on None here (AttributeError).
        # NOTE(review): assumes the fwd_h op accepts None offsets, as the other
        # ops accept None chunk_indices — confirm.
        chunk_indices = None
        chunk_offsets = None

    g = torch.ops.xspeedgate_ops.chunk_local_cumsum(
        g,
        chunk_size=chunk_size,
        reverse=False,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        head_first=False,
    )
    A = torch.ops.xspeedgate_ops.chunk_scaled_dot_kkt_fwd(
        k, beta, g, cu_seqlens, chunk_indices, chunk_size
    )
    # return value is not used: presumably solves A in place — TODO confirm
    torch.ops.xspeedgate_ops.solve_tril_ns(A, cu_seqlens, chunk_indices, chunk_size)
    w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        chunk_size=chunk_size,
    )
    h, v_new, final_state = torch.ops.xspeedgate_ops.chunk_gated_delta_rule_fwd_h(
        k,
        u,
        w,
        g,
        initial_state,
        cu_seqlens,
        chunk_indices,
        chunk_offsets,
        chunk_size,
        output_final_state,
        True,  # trailing flag; its meaning is defined by the kernel — TODO document
    )
    o = torch.ops.xspeedgate_ops.chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        chunk_size=chunk_size,
    )
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    return g, o, A, final_state, w, h, v_new
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
    """``torch.autograd.Function`` front-end for ``chunk_gated_delta_rule_fwd``.

    Only ``forward`` is implemented on this class; no ``backward`` is defined.
    """

    @staticmethod
    @input_guard
    @torch.amp.custom_fwd(device_type="cuda")
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: Optional[torch.LongTensor] = None,
        use_qk_l2norm_in_kernel: bool = False,
    ):
        """Run the chunked forward pass and return ``(o, final_state)``.

        When ``use_qk_l2norm_in_kernel`` is set, ``q`` and ``k`` are first passed
        through ``l2norm_fwd``. The output ``o`` is cast to the dtype of the
        (possibly normalized) ``q``. Only ``scale`` and
        ``use_qk_l2norm_in_kernel`` are recorded on ``ctx``; the intermediate
        tensors returned by the forward kernel are discarded.
        """
        if use_qk_l2norm_in_kernel:
            q, k = l2norm_fwd(q), l2norm_fwd(k)

        (
            _g_cumsum,
            out,
            _A,
            final_state,
            _w,
            _h,
            _v_new,
        ) = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )

        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        ctx.scale = scale
        return out.to(q.dtype), final_state
@torch.compiler.disable
def chunk_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
):
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
        v (torch.Tensor):
            values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        g (torch.Tensor):
            (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        beta (torch.Tensor):
            betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        scale (Optional[float]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        head_first (Optional[bool]):
            Deprecated. Whether the inputs are in the head-first format, which is not
            supported for variable-length inputs. When set, a `DeprecationWarning` is
            emitted, the inputs are converted to `[B, T, H, ...]` and `o` is converted back.
            Default: `False`.
        use_qk_l2norm_in_kernel (Optional[bool]):
            Whether `q` and `k` are L2-normalized before the forward pass. Default: `False`.
    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
    """
    assert q.dtype == k.dtype == v.dtype
    assert (
        q.dtype != torch.float32
    ), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
    assert (
        len(beta.shape) == 3
    ), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
    if head_first:
        # Previously `raise DeprecationWarning(..., stacklevel=2)`: exception
        # constructors take no keyword arguments, so that line raised TypeError
        # and left the layout conversion below unreachable.
        warnings.warn(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False for now instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        q, k, v, beta, g = map(
            lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
        )
    if not head_first and q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
            stacklevel=2,
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                "Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    o, final_state = ChunkGatedDeltaRuleFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
        use_qk_l2norm_in_kernel,
    )
    if head_first:
        o = rearrange(o, "b t h ... -> b h t ...")
    return o, final_state