# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings
from typing import Optional
import cocopod # noqa
import torch
import torch.nn.functional as F
from einops import rearrange
from .index import prepare_chunk_indices, prepare_chunk_offsets
from .l2norm import l2norm_fwd
from .utils import SUPPRESS_LEVEL, input_guard


def torch_solve_tril(
A: torch.Tensor,
cu_seqlens: Optional[torch.LongTensor] = None,
output_dtype: torch.dtype = torch.float,
):
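    """
    Pure-PyTorch reference for the chunkwise lower-triangular solve.

    ``A`` holds the strictly lower-triangular part of each 64x64 chunk in
    ``[B, T, H, 64]`` layout. The loop below performs forward substitution,
    returning the inverse of ``I + A`` for every chunk in the same layout;
    e.g. a single 2x2 block ``[[0, 0], [a, 0]]`` yields ``[[1, 0], [-a, 1]]``.
    ``cu_seqlens`` and ``output_dtype`` are accepted for signature parity with
    the fused kernel but are unused in this reference version.
    """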
chunk_size = 64
A = -A.transpose(1, 2)
sequence_length = A.shape[-2]
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
    # mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=A.device), diagonal=0)
    # A = A.masked_fill(mask, 0)
for i in range(1, chunk_size):
row = A[..., i, :i].clone()
sub = A[..., :i, :i].clone()
A[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
A = A + torch.eye(chunk_size, dtype=A.dtype, device=A.device)
return A.reshape(A.shape[0], A.shape[1], -1, A.shape[-1])[
:, :, :sequence_length, :
].transpose(1, 2)


def recompute_w_u_fwd_torch(
k: torch.Tensor, # [B, T, H, K]
v: torch.Tensor, # [B, T, H, V]
beta: torch.Tensor, # [B, T, H]
g: torch.Tensor, # [B, T, H]
A: torch.Tensor, # [B, H, T, T]
):
"""
最简单版本假设等长序列key和value头数相同
"""
chunk_size = 64
num_v_heads, num_k_heads = v.shape[2], k.shape[2]
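    # GQA: repeat each key head so keys match the number of value heads.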
k = k.repeat_interleave(num_v_heads // num_k_heads, dim=2)
k, v, beta, g, A = [
x.transpose(1, 2).contiguous().to(torch.float32) for x in (k, v, beta, g, A)
]
batch_size, num_heads, sequence_length, k_head_dim = k.shape
pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
k = F.pad(k, (0, 0, 0, pad_size))
v = F.pad(v, (0, 0, 0, pad_size))
beta = F.pad(beta, (0, pad_size))
g = F.pad(g, (0, pad_size))
A = F.pad(A, (0, 0, 0, pad_size))
A = A.reshape(A.shape[0], A.shape[1], -1, chunk_size, A.shape[-1])
v_beta = v * beta.unsqueeze(-1)
k_beta = k * beta.unsqueeze(-1)
k, v, k_beta, v_beta = [
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
for x in (k, v, k_beta, v_beta)
]
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
u = A @ v_beta
w = A @ (k_beta * g.exp().unsqueeze(-1))
w = (
w.reshape(w.shape[0], w.shape[1], -1, w.shape[-1])[:, :, :sequence_length, :]
.transpose(1, 2)
.contiguous()
)
u = (
u.reshape(u.shape[0], u.shape[1], -1, u.shape[-1])[:, :, :sequence_length, :]
.transpose(1, 2)
.contiguous()
)
return w, u


def split_by_value(tensor, chunk_size=64):
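    """
    Insert chunk-aligned boundaries between consecutive offsets.

    For each adjacent pair ``(start, end)`` in ``tensor`` (e.g. ``cu_seqlens``),
    add boundaries ``start + chunk_size, start + 2 * chunk_size, ...`` that fall
    strictly inside ``(start, end)``, keeping the original offsets.

    Example:
        >>> split_by_value(torch.tensor([0, 100, 130]))
        tensor([  0,  64, 100, 130])
    """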
indices = tensor.tolist()
    result = set(indices)  # use a set to avoid duplicate boundaries
for i in range(len(indices) - 1):
start = indices[i]
end = indices[i + 1]
        # Compute the first aligned boundary: start + n * chunk_size, where n is
        # the smallest positive integer that places the boundary after start.
first_boundary = start + chunk_size
# 在(start, end)范围内插入所有对齐边界
boundary = first_boundary
while boundary < end:
result.add(boundary)
boundary += chunk_size
return torch.tensor(sorted(result), dtype=tensor.dtype, device=tensor.device)
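

# Forward pass of the chunked gated delta rule. Each stage below dispatches to
# a fused xspeedgate_ops kernel; the commented-out calls preserve the original
# flash-linear-attention reference implementations:
#   1. chunk_local_cumsum           -- per-chunk cumulative sum of the log gates g
#   2. chunk_scaled_dot_kkt_fwd     -- beta-scaled, lower-triangular K @ K^T blocks
#   3. solve_tril_ns                -- in-place chunkwise inversion of (I + A)
#   4. recompute_w_u_fwd            -- w and u tensors from k, v, beta, g, A
#   5. chunk_gated_delta_rule_fwd_h -- chunk states h, updated values v_new, final state
#   6. chunk_fwd_o                  -- attention output o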
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
):
chunk_size = 64
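    # For variable-length input, precompute per-chunk (sequence, chunk) index
    # pairs and per-sequence chunk offsets for the fused kernels below.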
chunk_indices = (
prepare_chunk_indices(cu_seqlens, 64) if cu_seqlens is not None else None
)
chunk_offsets = (
prepare_chunk_offsets(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
    # Fused kernel replacing fla's chunk_local_cumsum:
    # g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
g = torch.ops.xspeedgate_ops.chunk_local_cumsum(
g,
chunk_size=64,
reverse=False,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
head_first=False,
)
    # Fused kernel replacing fla's chunk_scaled_dot_kkt_fwd:
    # A = chunk_scaled_dot_kkt_fwd(k=k,
    #                              beta=beta,
    #                              g_cumsum=g,
    #                              cu_seqlens=cu_seqlens,
    #                              output_dtype=q.dtype)
A = torch.ops.xspeedgate_ops.chunk_scaled_dot_kkt_fwd(
k, beta, g, cu_seqlens, chunk_indices, chunk_size
)
    # Debug hooks for cross-checking against the torch reference (disabled):
    # if get_tensor_model_parallel_rank() == 0:
    #     torch.save(A, "A_in")
    #     torch.save(cu_seqlens, "cu_seqlens")
    # A2 = A.clone()
    # In-place chunkwise inversion of (I + A) via the fused kernel.
    torch.ops.xspeedgate_ops.solve_tril_ns(A, cu_seqlens, chunk_indices, chunk_size)
    # Reference variants kept for debugging:
    # torch.ops.xspeedgate_ops.solve_tril_fwd(A, cu_seqlens)
    # if get_tensor_model_parallel_rank() == 0:
    #     err = torch.max(torch.abs(A - A2))
    #     print("err", err)
    #     if err > 1e-3:
    #         raise
    # A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    # for i in range(len(cu_seqlens) - 1):
    #     A_i = A[:, cu_seqlens[i]:cu_seqlens[i + 1], :, :]
    #     A[:, cu_seqlens[i]:cu_seqlens[i + 1], :, :] = torch_solve_tril(
    #         A=A_i,
    #         cu_seqlens=torch.tensor([0, cu_seqlens[i + 1] - cu_seqlens[i]], device=q.device),
    #         output_dtype=k.dtype,
    #     )
"""
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
u = torch.empty_like(v)
w = k.new_empty(B, T, H, K)
for i in range(len(cu_seqlens)-1):
k_i = k[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
v_i = v[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
beta_i = beta[:, cu_seqlens[i]:cu_seqlens[i+1], :]
A_i = A[:, cu_seqlens[i]:cu_seqlens[i+1], :, :]
g_i = g[:, cu_seqlens[i]:cu_seqlens[i+1], :]
w_i, u_i = recompute_w_u_fwd_torch(
k=k_i,
v=v_i,
beta=beta_i,
A=A_i,
g=g_i,
)
w[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = w_i
u[:, cu_seqlens[i]:cu_seqlens[i+1], :, :] = u_i
"""
    w, u = torch.ops.xspeedgate_ops.recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        chunk_size=64,
    )
"""
w, u = recompute_w_u_fwd(
2025-12-10 17:51:24 +08:00
k=k,
v=v,
beta=beta,
A=A,
g_cumsum=g,
2025-12-10 17:51:24 +08:00
cu_seqlens=cu_seqlens,
)
"""
    # Debug hooks: dump kernel inputs for offline inspection (disabled).
# import os
# if not os.path.exists("/qwen-next/in"):
# os.makedirs("/qwen-next/in")
# torch.save(k, "/qwen-next/in/k.pt")
# torch.save(u, "/qwen-next/in/u.pt")
# torch.save(w, "/qwen-next/in/w.pt")
# torch.save(g, "/qwen-next/in/g.pt")
# torch.save(initial_state, "/qwen-next/in/initial_state.pt")
# torch.save(cu_seqlens, "/qwen-next/in/cu_seqlens.pt")
# torch.save(chunk_indices, "/qwen-next/in/chunk_indices.pt")
# torch.save(chunk_offsets.to(torch.int32), "/qwen-next/in/chunk_offsets.pt")
# torch.save(chunk_size, "/qwen-next/in/chunk_size.pt")
# torch.save(output_final_state, "/qwen-next/in/output_final_state.pt")
h, v_new, final_state = torch.ops.xspeedgate_ops.chunk_gated_delta_rule_fwd_h(
k,
u,
w,
g,
initial_state,
cu_seqlens,
chunk_indices,
chunk_offsets.to(torch.int32),
chunk_size,
output_final_state,
True,
)
# h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
# k=k,
# w=w,
# u=u,
# g=g,
# initial_state=initial_state,
# output_final_state=output_final_state,
# cu_seqlens=cu_seqlens,
# )
# if not os.path.exists("/qwen-next/out"):
# os.makedirs("/qwen-next/out")
# torch.save(h, "/qwen-next/out/h.pt")
# torch.save(v_new, "/qwen-next/out/v_new.pt")
# torch.save(final_state, "/qwen-next/out/final_state.pt")
    o = torch.ops.xspeedgate_ops.chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        chunk_size=64,
    )
"""
o = chunk_fwd_o(
q=q,
k=k,
v=v_new,
h=h,
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
)
"""
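    # With SUPPRESS_LEVEL >= 3, the intermediates (w, h, v_new) are returned as
    # well; upstream fla keeps them to skip recomputation in the backward pass.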
if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None
elif SUPPRESS_LEVEL >= 3:
return g, o, A, final_state, w, h, v_new
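

# Forward-only autograd wrapper around chunk_gated_delta_rule_fwd: no backward
# is defined, so gradients cannot flow through this path (inference use only).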
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@input_guard
@torch.amp.custom_fwd(device_type="cuda")
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: Optional[torch.LongTensor] = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q)
k = l2norm_fwd(k)
g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
q=q,
k=k,
v=v,
g=g,
beta=beta,
scale=scale,
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)
ctx.scale = scale
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
return o.to(q.dtype), final_state


@torch.compiler.disable
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
output_final_state: bool = False,
cu_seqlens: Optional[torch.LongTensor] = None,
head_first: bool = False,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
v (torch.Tensor):
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
g (torch.Tensor):
(forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
beta (torch.Tensor):
betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, H, K, V]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
head_first (Optional[bool]):
Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
Default: `False`.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
final_state (torch.Tensor):
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
Examples::
>>> import torch
>>> import torch.nn.functional as F
>>> from einops import rearrange
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
# inputs with equal lengths
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
>>> o, ht = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True
)
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
    # for a batch with 4 sequences, a `cu_seqlens` with 5 boundary positions is expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o_var, ht_var = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True,
cu_seqlens=cu_seqlens
)
"""
assert q.dtype == k.dtype == v.dtype
assert (
q.dtype != torch.float32
), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
assert (
len(beta.shape) == 3
), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."

    if head_first:
        raise DeprecationWarning(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False instead."
        )
        # Dead code retained from upstream fla; unreachable after the raise above.
        q, k, v, beta, g = map(
            lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
        )
if not head_first and q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2,
)
if cu_seqlens is not None:
if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                "Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
if scale is None:
scale = k.shape[-1] ** -0.5
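    # Disabled alternative (`if False`): a single fused kunlun_ops.gated_delta_rule
    # call that computes the whole op at once, kept for reference; the chunked
    # ChunkGatedDeltaRuleFunction path in the else-branch is the one in use.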
if False:
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
g = g.contiguous()
beta = beta.contiguous()
initial_state = initial_state.contiguous()
o = torch.empty_like(v)
final_state = torch.empty_like(initial_state)
import kunlun_ops
kunlun_ops.gated_delta_rule(
q,
k,
v,
initial_state,
g,
beta,
final_state,
o,
scale,
cu_seqlens.cpu(),
cu_seqlens,
cu_seqlens.cpu(),
cu_seqlens,
use_qk_l2norm_in_kernel=True,
)
else:
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q,
k,
v,
g,
beta,
scale,
initial_state,
output_final_state,
cu_seqlens,
use_qk_l2norm_in_kernel,
)
if head_first:
o = rearrange(o, "b t h ... -> b h t ...")
return o, final_state
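

if __name__ == "__main__":
    # Minimal smoke test for the pure-PyTorch helpers above. This is an
    # illustrative sketch: it assumes the package environment (cocopod and the
    # sibling fla modules) is importable, and it exercises only the reference
    # helpers, not the fused xspeedgate_ops kernels.
    torch.manual_seed(0)
    B, T, H, C = 1, 128, 2, 64
    # Strictly lower-triangular 64x64 blocks; small entries keep the chunkwise
    # solve well conditioned.
    blocks = 0.1 * torch.randn(B, H, T // C, C, C).tril(-1)
    A_in = rearrange(blocks, "b h n c1 c2 -> b (n c1) h c2")
    T_inv = rearrange(torch_solve_tril(A_in), "b (n c1) h c2 -> b h n c1 c2", c1=C)
    eye = torch.eye(C)
    err = ((eye + blocks) @ T_inv - eye).abs().max()
    print(f"solve_tril max |(I + A) @ T - I| = {err.item():.3e}")
    # Chunk-aligned boundary insertion between cu_seqlens-style offsets.
    print(split_by_value(torch.tensor([0, 100, 130])))  # tensor([0, 64, 100, 130])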