155 lines
5.7 KiB
Python
155 lines
5.7 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
|
#
|
|
# This file contains code copied from the flash-linear-attention project.
|
|
# The original source code was licensed under the MIT license and included
|
|
# the following copyright notice:
|
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
# ruff: noqa: E501
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
import xtorch_ops
|
|
|
|
|
|
class FusedRecurrentFunction(torch.autograd.Function):
    """Autograd-function wrapper around the ``xtorch_ops`` fused recurrent
    gated delta rule forward kernel.

    Only ``forward`` is defined on this class.
    """

    @staticmethod
    def forward(ctx,
                q: torch.Tensor,
                k: torch.Tensor,
                v: torch.Tensor,
                g: torch.Tensor,
                beta: torch.Tensor,
                scale: float,
                initial_state: torch.Tensor,
                inplace_final_state: bool = True,
                cu_seqlens: Optional[torch.LongTensor] = None,
                ssm_state_indices: Optional[torch.Tensor] = None,
                num_accepted_tokens: Optional[torch.Tensor] = None,
                use_qk_l2norm_in_kernel: bool = False):
        """Call ``xtorch_ops.fused_recurrent_gated_delta_rule_fwdv2``.

        ``q``, ``k``, ``v``, ``g`` and ``beta`` are made contiguous before
        the call; ``scale`` and ``initial_state`` are forwarded unchanged.
        ``ssm_state_indices`` is handed to the kernel under the keyword
        ``h0_indices``. ``ctx`` is not used.

        Returns:
            The ``(o, final_state)`` pair produced by the kernel.
        """
        # Same evaluation order as the kernel's positional tensor arguments.
        q_c, k_c, v_c, g_c, beta_c = [
            tensor.contiguous() for tensor in (q, k, v, g, beta)
        ]

        kernel_options = dict(
            inplace_final_state=inplace_final_state,
            cu_seqlens=cu_seqlens,
            h0_indices=ssm_state_indices,
            num_accepted_tokens=num_accepted_tokens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
            # NOTE(review): always passed as True; presumably describes the
            # memory layout of `initial_state` -- confirm against the kernel.
            is_h0_transposed=True,
        )

        out, last_state = xtorch_ops.fused_recurrent_gated_delta_rule_fwdv2(
            q_c,
            k_c,
            v_c,
            g_c,
            beta_c,
            scale,
            initial_state,
            **kernel_options,
        )
        return out, last_state
|
|
|
|
|
|
def fused_recurrent_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: Optional[torch.Tensor] = None,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    inplace_final_state: bool = True,
    cu_seqlens: Optional[torch.LongTensor] = None,
    ssm_state_indices: Optional[torch.Tensor] = None,
    num_accepted_tokens: Optional[torch.Tensor] = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Validate the inputs, fill in defaults for `scale` and `beta`, and run the
    fused recurrent gated delta rule through `FusedRecurrentFunction`.

    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, HV, V]`.
            GVA is applied if `HV > H`.
        g (torch.Tensor):
            g (decays) of shape `[B, T, HV]`.
        beta (Optional[torch.Tensor]):
            betas of shape `[B, T, HV]`.
            If not provided, a tensor of ones with that shape (and the
            dtype/device of `q`) is used. Default: `None`.
        scale (Optional[float]):
            Scale factor for the attention scores. Must be positive.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, HV, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        inplace_final_state (bool):
            Whether to store the final state in-place to save memory.
            Default: `True`.
        cu_seqlens (Optional[torch.LongTensor]):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length
            inputs, consistent with the FlashAttention API. When given, the
            batch size of `q` must be 1.
        ssm_state_indices (Optional[torch.Tensor]):
            Indices to map the input sequences to the initial/final states.
        num_accepted_tokens (Optional[torch.Tensor]):
            Number of accepted tokens for each sequence during decoding.
        use_qk_l2norm_in_kernel (bool):
            Forwarded unchanged to the kernel; presumably makes the kernel
            L2-normalize `q` and `k` itself (confirm against the kernel).
            Default: `False`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HV, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, HV, K, V]`.

    Raises:
        ValueError: if `cu_seqlens` is given and the batch size of `q` is not 1.
        AssertionError: if an explicit `scale` is not positive.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, HV, V, device='cuda')
        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, HV, K, V, device='cuda')
        >>> o, ht = fused_recurrent_gated_delta_rule(
        ...     q, k, v, g, beta,
        ...     initial_state=h0,
        ... )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
        ...     q, k, v, g, beta,
        ...     initial_state=h0,
        ...     cu_seqlens=cu_seqlens
        ... )
    """
    if cu_seqlens is not None and q.shape[0] != 1:
        # Variable-length mode packs every sequence along the time axis, so
        # the leading batch dimension has to be exactly 1.
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
            "Please flatten variable-length inputs before processing.")
    if scale is None:
        scale = k.shape[-1]**-0.5
    else:
        assert scale > 0, "scale must be positive"
    if beta is None:
        # Default to "no gating". beta is documented as `[B, T, HV]`, so take
        # the head count from `v` rather than `q`; the two only differ when
        # GVA is in use (`HV > H`). dtype and device follow `q`.
        beta = q.new_ones(v.shape[:-1])
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        inplace_final_state,
        cu_seqlens,
        ssm_state_indices,
        num_accepted_tokens,
        use_qk_l2norm_in_kernel,
    )
    return o, final_state
|