add flash linear attention triton kernel (#10239)
This commit is contained in:
242
python/sglang/srt/layers/attention/fla/chunk.py
Normal file
242
python/sglang/srt/layers/attention/fla/chunk.py
Normal file
@@ -0,0 +1,242 @@
|
||||
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/chunk.py
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from sglang.srt.layers.attention.fla.chunk_delta_h import chunk_gated_delta_rule_fwd_h
|
||||
from sglang.srt.layers.attention.fla.chunk_o import chunk_fwd_o
|
||||
from sglang.srt.layers.attention.fla.chunk_scaled_dot_kkt import (
|
||||
chunk_scaled_dot_kkt_fwd,
|
||||
)
|
||||
from sglang.srt.layers.attention.fla.cumsum import chunk_local_cumsum
|
||||
from sglang.srt.layers.attention.fla.l2norm import l2norm_fwd
|
||||
from sglang.srt.layers.attention.fla.solve_tril import solve_tril
|
||||
from sglang.srt.layers.attention.fla.utils import (
|
||||
SUPPRESS_LEVEL,
|
||||
autocast_custom_fwd,
|
||||
input_guard,
|
||||
)
|
||||
from sglang.srt.layers.attention.fla.wy_fast import recompute_w_u_fwd
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
# obtain WY representation. u is actually the new v.
|
||||
A = chunk_scaled_dot_kkt_fwd(
|
||||
k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
|
||||
)
|
||||
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
|
||||
w, u = recompute_w_u_fwd(
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
A=A,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
||||
k=k,
|
||||
w=w,
|
||||
u=u,
|
||||
g=g,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
o = chunk_fwd_o(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_new,
|
||||
h=h,
|
||||
g=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
if SUPPRESS_LEVEL < 3:
|
||||
return g, o, A, final_state, None, None, None
|
||||
elif SUPPRESS_LEVEL >= 3:
|
||||
return g, o, A, final_state, w, h, v_new
|
||||
|
||||
|
||||
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@input_guard
|
||||
@autocast_custom_fwd
|
||||
def forward(
|
||||
ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
q_orig = q
|
||||
k_orig = k
|
||||
|
||||
if use_qk_l2norm_in_kernel:
|
||||
q = l2norm_fwd(q)
|
||||
k = l2norm_fwd(k)
|
||||
|
||||
g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
scale=scale,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
|
||||
@torch.compiler.disable
|
||||
def chunk_gated_delta_rule(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
head_first: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
|
||||
k (torch.Tensor):
|
||||
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
|
||||
v (torch.Tensor):
|
||||
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
||||
g (torch.Tensor):
|
||||
(forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
|
||||
beta (torch.Tensor):
|
||||
betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
|
||||
scale (Optional[int]):
|
||||
Scale factor for the RetNet attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `[N, H, K, V]` for `N` input sequences.
|
||||
For equal-length input sequences, `N` equals the batch size `B`.
|
||||
Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
head_first (Optional[bool]):
|
||||
Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
|
||||
Default: `False`.
|
||||
|
||||
Returns:
|
||||
o (torch.Tensor):
|
||||
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
||||
final_state (torch.Tensor):
|
||||
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
|
||||
|
||||
Examples::
|
||||
>>> import torch
|
||||
>>> import torch.nn.functional as F
|
||||
>>> from einops import rearrange
|
||||
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
|
||||
# inputs with equal lengths
|
||||
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
|
||||
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
|
||||
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
|
||||
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
|
||||
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
|
||||
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
|
||||
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
|
||||
>>> o, ht = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
output_final_state=True
|
||||
)
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> o_var, ht_var = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
output_final_state=True,
|
||||
cu_seqlens=cu_seqlens
|
||||
)
|
||||
"""
|
||||
assert q.dtype == k.dtype == v.dtype
|
||||
assert (
|
||||
q.dtype != torch.float32
|
||||
), "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
|
||||
assert (
|
||||
len(beta.shape) == 3
|
||||
), "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
||||
|
||||
if head_first:
|
||||
raise DeprecationWarning(
|
||||
"head_first is deprecated and will be removed in a future version. "
|
||||
"Please use head_first=False for now instead."
|
||||
)
|
||||
q, k, v, beta, g = map(
|
||||
lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)
|
||||
)
|
||||
# if not head_first and q.shape[1] < q.shape[2]:
|
||||
# warnings.warn(
|
||||
# f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
|
||||
# "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||
# "when head_first=False was specified. "
|
||||
# "Please verify your input tensor format matches the expected shape [B, T, H, ...]."
|
||||
# )
|
||||
if cu_seqlens is not None:
|
||||
if q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing."
|
||||
)
|
||||
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
|
||||
raise ValueError(
|
||||
f"The number of initial states is expected to be equal to the number of input sequences, "
|
||||
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1] ** -0.5
|
||||
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale,
|
||||
initial_state,
|
||||
output_final_state,
|
||||
cu_seqlens,
|
||||
use_qk_l2norm_in_kernel,
|
||||
)
|
||||
if head_first:
|
||||
o = rearrange(o, "b t h ... -> b h t ...")
|
||||
return o, final_state
|
||||
314
python/sglang/srt/layers/attention/fla/chunk_delta_h.py
Normal file
314
python/sglang/srt/layers/attention/fla/chunk_delta_h.py
Normal file
@@ -0,0 +1,314 @@
|
||||
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_delta_h.py
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.fla.index import (
|
||||
prepare_chunk_indices,
|
||||
prepare_chunk_offsets,
|
||||
)
|
||||
from sglang.srt.layers.attention.fla.op import exp, safe_exp
|
||||
from sglang.srt.layers.attention.fla.utils import is_nvidia_hopper
|
||||
|
||||
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
|
||||
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
|
||||
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for num_warps in [2, 4]
|
||||
# for num_stages in [2, 3, 4]
|
||||
# for BV in [32, 64]
|
||||
# ],
|
||||
# key=["H", "K", "V", "BT", "USE_G"],
|
||||
# use_cuda_graph=use_cuda_graph,
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
k,
|
||||
v,
|
||||
w,
|
||||
v_new,
|
||||
g,
|
||||
h,
|
||||
h0,
|
||||
ht,
|
||||
cu_seqlens,
|
||||
chunk_offsets,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr,
|
||||
SAVE_NEW_VALUE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_v, i_nh = tl.program_id(0), tl.program_id(1)
|
||||
i_n, i_h = i_nh // H, i_nh % H
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = i_n * NT
|
||||
|
||||
# [BK, BV]
|
||||
b_h1 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 64:
|
||||
b_h2 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 128:
|
||||
b_h3 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 192:
|
||||
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
|
||||
# calculate offset
|
||||
h += (boh * H + i_h) * K * V
|
||||
v += (bos * H + i_h) * V
|
||||
k += (bos * Hg + i_h // (H // Hg)) * K
|
||||
w += (bos * H + i_h) * K
|
||||
if SAVE_NEW_VALUE:
|
||||
v_new += (bos * H + i_h) * V
|
||||
stride_v = H * V
|
||||
stride_h = H * K * V
|
||||
stride_k = Hg * K
|
||||
stride_w = H * K
|
||||
if USE_INITIAL_STATE:
|
||||
h0 = h0 + i_nh * K * V
|
||||
if STORE_FINAL_STATE:
|
||||
ht = ht + i_nh * K * V
|
||||
|
||||
# load initial state
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
||||
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 64:
|
||||
p_h0_2 = tl.make_block_ptr(
|
||||
h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
|
||||
)
|
||||
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 128:
|
||||
p_h0_3 = tl.make_block_ptr(
|
||||
h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
|
||||
)
|
||||
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 192:
|
||||
p_h0_4 = tl.make_block_ptr(
|
||||
h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
|
||||
)
|
||||
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
# main recurrence
|
||||
for i_t in range(NT):
|
||||
p_h1 = tl.make_block_ptr(
|
||||
h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)
|
||||
)
|
||||
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 64:
|
||||
p_h2 = tl.make_block_ptr(
|
||||
h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
|
||||
)
|
||||
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 128:
|
||||
p_h3 = tl.make_block_ptr(
|
||||
h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
|
||||
)
|
||||
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 192:
|
||||
p_h4 = tl.make_block_ptr(
|
||||
h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
|
||||
)
|
||||
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
p_v = tl.make_block_ptr(
|
||||
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
|
||||
)
|
||||
p_v_new = (
|
||||
tl.make_block_ptr(
|
||||
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
|
||||
)
|
||||
if SAVE_NEW_VALUE
|
||||
else None
|
||||
)
|
||||
b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
p_w = tl.make_block_ptr(
|
||||
w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
|
||||
)
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
|
||||
if K > 64:
|
||||
p_w = tl.make_block_ptr(
|
||||
w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
|
||||
)
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
|
||||
if K > 128:
|
||||
p_w = tl.make_block_ptr(
|
||||
w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
|
||||
)
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
|
||||
if K > 192:
|
||||
p_w = tl.make_block_ptr(
|
||||
w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
|
||||
)
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
|
||||
b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
|
||||
|
||||
if SAVE_NEW_VALUE:
|
||||
p_v_new = tl.make_block_ptr(
|
||||
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
|
||||
)
|
||||
tl.store(
|
||||
p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)
|
||||
)
|
||||
|
||||
if USE_G:
|
||||
last_idx = min((i_t + 1) * BT, T) - 1
|
||||
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
|
||||
p_g = tl.make_block_ptr(
|
||||
g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
|
||||
)
|
||||
b_g = tl.load(p_g, boundary_check=(0,))
|
||||
b_v_new = b_v_new * safe_exp(b_g_last - b_g)[:, None]
|
||||
b_g_last = exp(b_g_last)
|
||||
b_h1 = b_h1 * b_g_last
|
||||
if K > 64:
|
||||
b_h2 = b_h2 * b_g_last
|
||||
if K > 128:
|
||||
b_h3 = b_h3 * b_g_last
|
||||
if K > 192:
|
||||
b_h4 = b_h4 * b_g_last
|
||||
b_v_new = b_v_new.to(k.dtype.element_ty)
|
||||
p_k = tl.make_block_ptr(
|
||||
k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h1 += tl.dot(b_k, b_v_new)
|
||||
if K > 64:
|
||||
p_k = tl.make_block_ptr(
|
||||
k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h2 += tl.dot(b_k, b_v_new)
|
||||
if K > 128:
|
||||
p_k = tl.make_block_ptr(
|
||||
k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h3 += tl.dot(b_k, b_v_new)
|
||||
if K > 192:
|
||||
p_k = tl.make_block_ptr(
|
||||
k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h4 += tl.dot(b_k, b_v_new)
|
||||
|
||||
# epilogue
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 64:
|
||||
p_ht = tl.make_block_ptr(
|
||||
ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
|
||||
)
|
||||
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 128:
|
||||
p_ht = tl.make_block_ptr(
|
||||
ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
|
||||
)
|
||||
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 192:
|
||||
p_ht = tl.make_block_ptr(
|
||||
ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
|
||||
)
|
||||
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd_h(
|
||||
k: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, u.shape[-1]
|
||||
H = u.shape[-2]
|
||||
BT = chunk_size
|
||||
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||
if cu_seqlens is not None
|
||||
else None
|
||||
)
|
||||
# N: the actual number of sequences in the batch with either equal or variable lengths
|
||||
if cu_seqlens is None:
|
||||
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
|
||||
else:
|
||||
N, NT, chunk_offsets = (
|
||||
len(cu_seqlens) - 1,
|
||||
len(chunk_indices),
|
||||
prepare_chunk_offsets(cu_seqlens, BT),
|
||||
)
|
||||
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
||||
|
||||
h = k.new_empty(B, NT, H, K, V)
|
||||
final_state = (
|
||||
k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
|
||||
)
|
||||
|
||||
v_new = torch.empty_like(u) if save_new_value else None
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(V, meta["BV"]), N * H)
|
||||
|
||||
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
|
||||
k=k,
|
||||
v=u,
|
||||
w=w,
|
||||
v_new=v_new,
|
||||
g=g,
|
||||
h=h,
|
||||
h0=initial_state,
|
||||
ht=final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_offsets=chunk_offsets,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BV=32,
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
)
|
||||
return h, v_new, final_state
|
||||
178
python/sglang/srt/layers/attention/fla/chunk_o.py
Normal file
178
python/sglang/srt/layers/attention/fla/chunk_o.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_o.py
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
|
||||
from sglang.srt.layers.attention.fla.op import exp, safe_exp
|
||||
from sglang.srt.layers.attention.fla.utils import check_shared_mem, is_nvidia_hopper
|
||||
|
||||
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
|
||||
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for BK in BKV_LIST
|
||||
# for BV in BKV_LIST
|
||||
# for num_warps in NUM_WARPS
|
||||
# for num_stages in [2, 3, 4]
|
||||
# ],
|
||||
# key=["H", "K", "V", "BT"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_fwd_kernel_o(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
scale,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
|
||||
if IS_VARLEN:
|
||||
i_tg = i_t
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
else:
|
||||
NT = tl.cdiv(T, BT)
|
||||
i_tg = i_b * NT + i_t
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
# offset calculation
|
||||
q += (bos * Hg + i_h // (H // Hg)) * K
|
||||
k += (bos * Hg + i_h // (H // Hg)) * K
|
||||
v += (bos * H + i_h) * V
|
||||
o += (bos * H + i_h) * V
|
||||
h += (i_tg * H + i_h).to(tl.int64) * K * V
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(
|
||||
q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
|
||||
)
|
||||
p_k = tl.make_block_ptr(
|
||||
k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
|
||||
)
|
||||
p_h = tl.make_block_ptr(
|
||||
h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
|
||||
)
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BK] @ [BK, BV] -> [BT, BV]
|
||||
b_o += tl.dot(b_q, b_h)
|
||||
# [BT, BK] @ [BK, BT] -> [BT, BT]
|
||||
b_A += tl.dot(b_q, b_k)
|
||||
|
||||
if USE_G:
|
||||
g += bos * H + i_h
|
||||
p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
|
||||
b_g = tl.load(p_g, boundary_check=(0,))
|
||||
b_o = b_o * exp(b_g)[:, None]
|
||||
b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
m_A = o_i[:, None] >= o_i[None, :]
|
||||
b_A = tl.where(m_A, b_A, 0)
|
||||
|
||||
p_v = tl.make_block_ptr(
|
||||
v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
|
||||
)
|
||||
p_o = tl.make_block_ptr(
|
||||
o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
|
||||
)
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
|
||||
# to fix mma -> mma layout conversion
|
||||
# already solved by triton v3.2 or higher
|
||||
b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_fwd_o(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None, # cumsum of log decay
|
||||
scale: Optional[float] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64,
|
||||
) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
)
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
if scale is None:
|
||||
scale = k.shape[-1] ** -0.5
|
||||
|
||||
o = torch.empty_like(v)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(V, meta["BV"]), NT, B * H)
|
||||
|
||||
chunk_fwd_kernel_o[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
scale,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=128,
|
||||
BV=64,
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
)
|
||||
return o
|
||||
151
python/sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py
Normal file
151
python/sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# Adapted from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/common/chunk_scaled_dot_kkt.py
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
|
||||
from sglang.srt.layers.attention.fla.op import safe_exp
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
"USE_G": lambda args: args["g_cumsum"] is not None,
|
||||
}
|
||||
)
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for BK in [32, 64, 128]
|
||||
# for num_warps in [2, 4, 8]
|
||||
# for num_stages in [2, 3, 4]
|
||||
# ],
|
||||
# key=["H", "K", "BT", "IS_VARLEN"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
k,
|
||||
beta,
|
||||
g_cumsum,
|
||||
A,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
o_t = tl.arange(0, BT)
|
||||
|
||||
p_beta = tl.make_block_ptr(
|
||||
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
|
||||
)
|
||||
b_beta = tl.load(p_beta, boundary_check=(0,))
|
||||
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(
|
||||
k + (bos * Hg + i_h // (H // Hg)) * K,
|
||||
(T, K),
|
||||
(Hg * K, 1),
|
||||
(i_t * BT, i_k * BK),
|
||||
(BT, BK),
|
||||
(1, 0),
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_kb = b_k * b_beta[:, None]
|
||||
b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
|
||||
|
||||
if USE_G:
|
||||
p_g = tl.make_block_ptr(
|
||||
g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
|
||||
)
|
||||
b_g = tl.load(p_g, boundary_check=(0,))
|
||||
b_g_diff = b_g[:, None] - b_g[None, :]
|
||||
b_A = b_A * safe_exp(b_g_diff)
|
||||
|
||||
b_A = tl.where(o_t[:, None] > o_t[None, :], b_A, 0)
|
||||
p_A = tl.make_block_ptr(
|
||||
A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
|
||||
)
|
||||
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_scaled_dot_kkt_fwd(
|
||||
k: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Compute beta * K * K^T.
|
||||
|
||||
Args:
|
||||
k (torch.Tensor):
|
||||
The key tensor of shape `[B, T, H, K]`.
|
||||
beta (torch.Tensor):
|
||||
The beta tensor of shape `[B, T, H]`.
|
||||
g_cumsum (torch.Tensor):
|
||||
The cumulative sum of the gate tensor of shape `[B, T, H]`.
|
||||
Default: None
|
||||
cu_seqlens (torch.LongTensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None
|
||||
chunk_size (int):
|
||||
The chunk size. Default: 64.
|
||||
output_dtype (torch.dtype):
|
||||
The dtype of the output tensor. Default: `torch.float32`
|
||||
|
||||
Returns:
|
||||
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
|
||||
"""
|
||||
|
||||
B, T, Hg, K = k.shape
|
||||
|
||||
H = beta.shape[-1]
|
||||
BT = chunk_size
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
)
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
|
||||
chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
beta=beta,
|
||||
g_cumsum=g_cumsum,
|
||||
A=A,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
BT=BT,
|
||||
BK=64,
|
||||
num_warps=8,
|
||||
num_stages=3,
|
||||
)
|
||||
return A
|
||||
300
python/sglang/srt/layers/attention/fla/cumsum.py
Normal file
300
python/sglang/srt/layers/attention/fla/cumsum.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/cumsum.py
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
|
||||
from sglang.srt.layers.attention.fla.utils import check_shared_mem, input_guard
|
||||
|
||||
BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"HAS_SCALE": lambda args: args["scale"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
# @triton.autotune(
|
||||
# configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
|
||||
# key=["B", "H", "BT", "IS_VARLEN", "REVERSE"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_local_cumsum_scalar_kernel(
|
||||
s,
|
||||
o,
|
||||
scale,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
HAS_SCALE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
HEAD_FIRST: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
if HEAD_FIRST:
|
||||
p_s = tl.make_block_ptr(
|
||||
s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
|
||||
)
|
||||
p_o = tl.make_block_ptr(
|
||||
o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
|
||||
)
|
||||
else:
|
||||
p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
|
||||
p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
|
||||
# [BT]
|
||||
b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
|
||||
b_o = tl.cumsum(b_s, axis=0)
|
||||
if REVERSE:
|
||||
b_z = tl.sum(b_s, axis=0)
|
||||
b_o = -b_o + b_z[None] + b_s
|
||||
if HAS_SCALE:
|
||||
b_o *= scale
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"HAS_SCALE": lambda args: args["scale"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({"BS": BS}, num_warps=num_warps)
|
||||
for BS in BS_LIST
|
||||
for num_warps in [2, 4, 8]
|
||||
],
|
||||
key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"],
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_local_cumsum_vector_kernel(
|
||||
s,
|
||||
o,
|
||||
scale,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
S: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BS: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
HAS_SCALE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
HEAD_FIRST: tl.constexpr,
|
||||
):
|
||||
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
if REVERSE:
|
||||
m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0)
|
||||
else:
|
||||
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0)
|
||||
|
||||
if HEAD_FIRST:
|
||||
p_s = tl.make_block_ptr(
|
||||
s + (bos * H + i_h * T) * S,
|
||||
(T, S),
|
||||
(S, 1),
|
||||
(i_t * BT, i_s * BS),
|
||||
(BT, BS),
|
||||
(1, 0),
|
||||
)
|
||||
p_o = tl.make_block_ptr(
|
||||
o + (bos * H + i_h * T) * S,
|
||||
(T, S),
|
||||
(S, 1),
|
||||
(i_t * BT, i_s * BS),
|
||||
(BT, BS),
|
||||
(1, 0),
|
||||
)
|
||||
else:
|
||||
p_s = tl.make_block_ptr(
|
||||
s + (bos * H + i_h) * S,
|
||||
(T, S),
|
||||
(H * S, 1),
|
||||
(i_t * BT, i_s * BS),
|
||||
(BT, BS),
|
||||
(1, 0),
|
||||
)
|
||||
p_o = tl.make_block_ptr(
|
||||
o + (bos * H + i_h) * S,
|
||||
(T, S),
|
||||
(H * S, 1),
|
||||
(i_t * BT, i_s * BS),
|
||||
(BT, BS),
|
||||
(1, 0),
|
||||
)
|
||||
# [BT, BS]
|
||||
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_o = tl.dot(m_s, b_s, allow_tf32=False)
|
||||
if HAS_SCALE:
|
||||
b_o *= scale
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_local_cumsum_scalar(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
scale: float = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float,
|
||||
) -> torch.Tensor:
|
||||
if head_first:
|
||||
B, H, T = g.shape
|
||||
else:
|
||||
B, T, H = g.shape
|
||||
assert chunk_size == 2 ** (
|
||||
chunk_size.bit_length() - 1
|
||||
), "chunk_size must be a power of 2"
|
||||
BT = chunk_size
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
)
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
grid = (NT, B * H)
|
||||
chunk_local_cumsum_scalar_kernel[grid](
|
||||
s=g_org,
|
||||
o=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
BT=BT,
|
||||
HEAD_FIRST=head_first,
|
||||
REVERSE=reverse,
|
||||
num_warps=8,
|
||||
num_stages=3,
|
||||
)
|
||||
return g
|
||||
|
||||
|
||||
def chunk_local_cumsum_vector(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
scale: float = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float,
|
||||
) -> torch.Tensor:
|
||||
if head_first:
|
||||
B, H, T, S = g.shape
|
||||
else:
|
||||
B, T, H, S = g.shape
|
||||
BT = chunk_size
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||
if cu_seqlens is not None
|
||||
else None
|
||||
)
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
assert chunk_size == 2 ** (
|
||||
chunk_size.bit_length() - 1
|
||||
), "chunk_size must be a power of 2"
|
||||
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)
|
||||
|
||||
# keep cumulative normalizer in fp32
|
||||
# this kernel is equivalent to
|
||||
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
|
||||
chunk_local_cumsum_vector_kernel[grid](
|
||||
s=g_org,
|
||||
o=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
S=S,
|
||||
BT=BT,
|
||||
HEAD_FIRST=head_first,
|
||||
REVERSE=reverse,
|
||||
)
|
||||
return g
|
||||
|
||||
|
||||
@input_guard
|
||||
def chunk_local_cumsum(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
scale: float = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if cu_seqlens is not None:
|
||||
assert (
|
||||
g.shape[0] == 1
|
||||
), "Only batch size 1 is supported when cu_seqlens are provided"
|
||||
if len(g.shape) == 3:
|
||||
return chunk_local_cumsum_scalar(
|
||||
g=g,
|
||||
chunk_size=chunk_size,
|
||||
reverse=reverse,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
head_first=head_first,
|
||||
output_dtype=output_dtype,
|
||||
)
|
||||
elif len(g.shape) == 4:
|
||||
return chunk_local_cumsum_vector(
|
||||
g=g,
|
||||
chunk_size=chunk_size,
|
||||
reverse=reverse,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
head_first=head_first,
|
||||
output_dtype=output_dtype,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported input shape {g.shape}, "
|
||||
f"which should be (B, T, H, D) if `head_first=False` "
|
||||
f"or (B, H, T, D) otherwise"
|
||||
)
|
||||
640
python/sglang/srt/layers/attention/fla/fused_recurrent.py
Normal file
640
python/sglang/srt/layers/attention/fla/fused_recurrent.py
Normal file
@@ -0,0 +1,640 @@
|
||||
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/fused_recurrent.py
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.fla.op import exp
|
||||
from sglang.srt.layers.attention.fla.utils import input_guard
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
|
||||
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
o,
|
||||
h0,
|
||||
ht,
|
||||
cu_seqlens,
|
||||
scale,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
HV: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
STORE_FINAL_STATE: tl.constexpr, # whether to store final state
|
||||
IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar,
|
||||
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_n, i_hv = i_nh // HV, i_nh % HV
|
||||
i_h = i_hv // (HV // H)
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int64)
|
||||
all = T
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
all = B * T
|
||||
o_k = i_k * BK + tl.arange(0, BK)
|
||||
o_v = i_v * BV + tl.arange(0, BV)
|
||||
|
||||
p_q = q + (bos * H + i_h) * K + o_k
|
||||
p_k = k + (bos * H + i_h) * K + o_k
|
||||
p_v = v + (bos * HV + i_hv) * V + o_v
|
||||
if IS_BETA_HEADWISE:
|
||||
p_beta = beta + (bos * HV + i_hv) * V + o_v
|
||||
else:
|
||||
p_beta = beta + bos * HV + i_hv
|
||||
p_g = g + bos * HV + i_hv
|
||||
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
|
||||
|
||||
mask_k = o_k < K
|
||||
mask_v = o_v < V
|
||||
mask_h = mask_k[:, None] & mask_v[None, :]
|
||||
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0 = h0 + i_nh * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
||||
|
||||
for _ in range(0, T):
|
||||
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
|
||||
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
|
||||
b_g = tl.load(p_g).to(tl.float32)
|
||||
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
|
||||
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
|
||||
b_q = b_q * scale
|
||||
# [BK, BV]
|
||||
b_h *= exp(b_g)
|
||||
# [BV]
|
||||
b_v -= tl.sum(b_h * b_k[:, None], 0)
|
||||
if IS_BETA_HEADWISE:
|
||||
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
|
||||
else:
|
||||
b_beta = tl.load(p_beta).to(tl.float32)
|
||||
b_v *= b_beta
|
||||
# [BK, BV]
|
||||
b_h += b_k[:, None] * b_v[None, :]
|
||||
# [BV]
|
||||
b_o = tl.sum(b_h * b_q[:, None], 0)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
||||
|
||||
p_q += H * K
|
||||
p_k += H * K
|
||||
p_o += HV * V
|
||||
p_v += HV * V
|
||||
p_g += HV
|
||||
p_beta += HV * (V if IS_BETA_HEADWISE else 1)
|
||||
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = ht + i_nh * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
HV = v.shape[2]
|
||||
N = B if cu_seqlens is None else len(cu_seqlens) - 1
|
||||
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
assert NK == 1, "NK > 1 is not supported yet"
|
||||
num_stages = 3
|
||||
num_warps = 1
|
||||
|
||||
o = q.new_empty(NK, *v.shape)
|
||||
if output_final_state:
|
||||
final_state = q.new_empty(N, HV, K, V, dtype=torch.float32)
|
||||
else:
|
||||
final_state = None
|
||||
|
||||
grid = (NK, NV, N * HV)
|
||||
fused_recurrent_gated_delta_rule_fwd_kernel[grid](
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
o=o,
|
||||
h0=initial_state,
|
||||
ht=final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
scale=scale,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
HV=HV,
|
||||
K=K,
|
||||
V=V,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
IS_BETA_HEADWISE=beta.ndim == v.ndim,
|
||||
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
o = o.squeeze(0)
|
||||
return o, final_state
|
||||
|
||||
|
||||
class FusedRecurrentFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@input_guard
|
||||
def forward(
|
||||
ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
o, final_state = fused_recurrent_gated_delta_rule_fwd(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
scale=scale,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
|
||||
return o, final_state
|
||||
|
||||
@staticmethod
|
||||
@input_guard
|
||||
def backward(ctx, do, dht):
|
||||
raise NotImplementedError(
|
||||
"Backward pass is not implemented yet and we do not have plans to implement it "
|
||||
"because we haven't figured out how to compute dg without materializing the full "
|
||||
"hidden states for all time steps."
|
||||
)
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor = None,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
queries of shape `[B, T, H, K]`.
|
||||
k (torch.Tensor):
|
||||
keys of shape `[B, T, H, K]`.
|
||||
v (torch.Tensor):
|
||||
values of shape `[B, T, HV, V]`.
|
||||
GVA is applied if `HV > H`.
|
||||
g (torch.Tensor):
|
||||
g (decays) of shape `[B, T, HV]`.
|
||||
beta (torch.Tensor):
|
||||
betas of shape `[B, T, HV]`.
|
||||
scale (Optional[int]):
|
||||
Scale factor for the RetNet attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `[N, HV, K, V]` for `N` input sequences.
|
||||
For equal-length input sequences, `N` equals the batch size `B`.
|
||||
Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `[N, HV, K, V]`. Default: `False`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
Returns:
|
||||
o (torch.Tensor):
|
||||
Outputs of shape `[B, T, HV, V]`.
|
||||
final_state (torch.Tensor):
|
||||
Final state of shape `[N, HV, K, V]` if `output_final_state=True` else `None`.
|
||||
Examples::
|
||||
>>> import torch
|
||||
>>> import torch.nn.functional as F
|
||||
>>> from einops import rearrange
|
||||
>>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
|
||||
# inputs with equal lengths
|
||||
>>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
|
||||
>>> q = torch.randn(B, T, H, K, device='cuda')
|
||||
>>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
|
||||
>>> v = torch.randn(B, T, HV, V, device='cuda')
|
||||
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
|
||||
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
|
||||
>>> h0 = torch.randn(B, HV, K, V, device='cuda')
|
||||
>>> o, ht = fused_gated_recurrent_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
output_final_state=True
|
||||
)
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> o_var, ht_var = fused_gated_recurrent_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
output_final_state=True,
|
||||
cu_seqlens=cu_seqlens
|
||||
)
|
||||
"""
|
||||
if cu_seqlens is not None:
|
||||
if q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing."
|
||||
)
|
||||
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
|
||||
raise ValueError(
|
||||
f"The number of initial states is expected to be equal to the number of input sequences, "
|
||||
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1] ** -0.5
|
||||
else:
|
||||
assert scale > 0, "scale must be positive"
|
||||
if beta is None:
|
||||
beta = torch.ones_like(q[..., 0])
|
||||
o, final_state = FusedRecurrentFunction.apply(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale,
|
||||
initial_state,
|
||||
output_final_state,
|
||||
cu_seqlens,
|
||||
use_qk_l2norm_in_kernel,
|
||||
)
|
||||
return o, final_state
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
"CACHE_INTERMEDIATE_STATES": lambda args: args["intermediate_states_buffer"]
|
||||
is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def fused_recurrent_gated_delta_rule_update_fwd_kernel(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
o,
|
||||
h0_source,
|
||||
h0_indices,
|
||||
cu_seqlens,
|
||||
scale,
|
||||
intermediate_states_buffer,
|
||||
cache_steps,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
HV: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar,
|
||||
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
DISABLE_STATE_UPDATE: tl.constexpr, # whether to disable final state update
|
||||
DISABLE_OUTPUT_CALCULATION: tl.constexpr, # whether to disable output calculation
|
||||
CACHE_INTERMEDIATE_STATES: tl.constexpr,
|
||||
):
|
||||
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_n, i_hv = i_nh // HV, i_nh % HV
|
||||
i_h = i_hv // (HV // H)
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int64)
|
||||
all = T
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
all = B * T
|
||||
o_k = i_k * BK + tl.arange(0, BK)
|
||||
o_v = i_v * BV + tl.arange(0, BV)
|
||||
|
||||
p_q = q + (bos * H + i_h) * K + o_k
|
||||
p_k = k + (bos * H + i_h) * K + o_k
|
||||
p_v = v + (bos * HV + i_hv) * V + o_v
|
||||
if IS_BETA_HEADWISE:
|
||||
p_beta = beta + (bos * HV + i_hv) * V + o_v
|
||||
else:
|
||||
p_beta = beta + bos * HV + i_hv
|
||||
p_g = g + bos * HV + i_hv
|
||||
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
|
||||
|
||||
mask_k = o_k < K
|
||||
mask_v = o_v < V
|
||||
mask_h = mask_k[:, None] & mask_v[None, :]
|
||||
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
idx = tl.load(h0_indices + i_n)
|
||||
# Add bounds checking for idx
|
||||
if idx >= 0: # Assuming negative indices are invalid
|
||||
p_h0 = (
|
||||
h0_source
|
||||
+ idx * HV * K * V
|
||||
+ i_hv * K * V
|
||||
+ o_k[:, None] * V
|
||||
+ o_v[None, :]
|
||||
)
|
||||
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
||||
|
||||
# Prepare intermediate state cache variables if enabled
|
||||
cache_idx = -1
|
||||
if CACHE_INTERMEDIATE_STATES:
|
||||
cache_idx = tl.load(h0_indices + i_n)
|
||||
|
||||
step_idx = 0
|
||||
for _ in range(0, T):
|
||||
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
|
||||
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
|
||||
b_g = tl.load(p_g).to(tl.float32)
|
||||
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
|
||||
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
|
||||
b_q = b_q * scale
|
||||
# [BK, BV]
|
||||
b_h *= exp(b_g)
|
||||
# [BV]
|
||||
b_v -= tl.sum(b_h * b_k[:, None], 0)
|
||||
if IS_BETA_HEADWISE:
|
||||
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
|
||||
else:
|
||||
b_beta = tl.load(p_beta).to(tl.float32)
|
||||
b_v *= b_beta
|
||||
# [BK, BV]
|
||||
b_h += b_k[:, None] * b_v[None, :]
|
||||
# [BV]
|
||||
if not DISABLE_OUTPUT_CALCULATION:
|
||||
b_o = tl.sum(b_h * b_q[:, None], 0)
|
||||
# core attn output
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
||||
|
||||
# store intermediate states if enabled
|
||||
if CACHE_INTERMEDIATE_STATES:
|
||||
if cache_idx >= 0:
|
||||
# Compute cache pointer for this step
|
||||
step_offset = step_idx * HV * K * V
|
||||
cache_ptr = (
|
||||
intermediate_states_buffer
|
||||
+ cache_idx * cache_steps * HV * K * V
|
||||
+ step_offset
|
||||
+ i_hv * K * V
|
||||
+ o_k[:, None] * V
|
||||
+ o_v[None, :]
|
||||
)
|
||||
tl.store(cache_ptr, b_h.to(cache_ptr.dtype.element_ty), mask=mask_h)
|
||||
|
||||
step_idx += 1
|
||||
|
||||
p_q += H * K
|
||||
p_k += H * K
|
||||
p_o += HV * V
|
||||
p_v += HV * V
|
||||
p_g += HV
|
||||
p_beta += HV * (V if IS_BETA_HEADWISE else 1)
|
||||
|
||||
# Store final state back to h0_source with bounds checking
|
||||
# ssm states
|
||||
if not DISABLE_STATE_UPDATE:
|
||||
idx = tl.load(h0_indices + i_n)
|
||||
if idx >= 0: # Add bounds checking
|
||||
p_h0 = (
|
||||
h0_source
|
||||
+ idx * HV * K * V
|
||||
+ i_hv * K * V
|
||||
+ o_k[:, None] * V
|
||||
+ o_v[None, :]
|
||||
)
|
||||
tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule_update_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state_source: torch.Tensor,
|
||||
initial_state_indices: torch.Tensor,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
disable_state_update: bool = False,
|
||||
disable_output_calculation: bool = False,
|
||||
intermediate_states_buffer: Optional[torch.Tensor] = None,
|
||||
cache_steps: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
HV = v.shape[2]
|
||||
N = B if cu_seqlens is None else len(cu_seqlens) - 1
|
||||
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
assert NK == 1, "NK > 1 is not supported yet"
|
||||
num_stages = 3
|
||||
num_warps = 1
|
||||
|
||||
if disable_output_calculation:
|
||||
# When output calculation is disabled, allocate minimal tensor
|
||||
o = q.new_empty(NK, 1, 1, 1, 1) # minimal allocation
|
||||
else:
|
||||
o = q.new_empty(NK, *v.shape)
|
||||
|
||||
grid = (NK, NV, N * HV)
|
||||
|
||||
fused_recurrent_gated_delta_rule_update_fwd_kernel[grid](
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
o=o,
|
||||
h0_source=initial_state_source,
|
||||
h0_indices=initial_state_indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
scale=scale,
|
||||
intermediate_states_buffer=intermediate_states_buffer,
|
||||
cache_steps=0 if cache_steps is None else cache_steps,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
HV=HV,
|
||||
K=K,
|
||||
V=V,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
IS_BETA_HEADWISE=beta.ndim == v.ndim,
|
||||
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
|
||||
DISABLE_STATE_UPDATE=disable_state_update,
|
||||
DISABLE_OUTPUT_CALCULATION=disable_output_calculation,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
o = o.squeeze(0)
|
||||
return o
|
||||
|
||||
|
||||
class FusedRecurrentUpdateFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@input_guard
|
||||
def forward(
|
||||
ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state_source: torch.Tensor,
|
||||
initial_state_indices: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
disable_state_update: bool = False,
|
||||
disable_output_calculation: bool = False,
|
||||
intermediate_states_buffer: Optional[torch.Tensor] = None,
|
||||
cache_steps: Optional[int] = None,
|
||||
):
|
||||
o = fused_recurrent_gated_delta_rule_update_fwd(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
scale=scale,
|
||||
initial_state_source=initial_state_source,
|
||||
initial_state_indices=initial_state_indices,
|
||||
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
||||
cu_seqlens=cu_seqlens,
|
||||
disable_state_update=disable_state_update,
|
||||
disable_output_calculation=disable_output_calculation,
|
||||
intermediate_states_buffer=intermediate_states_buffer,
|
||||
cache_steps=cache_steps,
|
||||
)
|
||||
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
@input_guard
|
||||
def backward(ctx, do, dht):
|
||||
raise NotImplementedError(
|
||||
"Backward pass is not implemented yet and we do not have plans to implement it "
|
||||
"because we haven't figured out how to compute dg without materializing the full "
|
||||
"hidden states for all time steps."
|
||||
)
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule_update(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor = None,
|
||||
scale: float = None,
|
||||
initial_state_source: torch.Tensor = None,
|
||||
initial_state_indices: torch.Tensor = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
disable_state_update: bool = False,
|
||||
disable_output_calculation: bool = False,
|
||||
intermediate_states_buffer: Optional[torch.Tensor] = None,
|
||||
cache_steps: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
if cu_seqlens is not None:
|
||||
if q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing."
|
||||
)
|
||||
if (
|
||||
initial_state_source is not None
|
||||
and initial_state_indices.shape[0] != len(cu_seqlens) - 1
|
||||
):
|
||||
raise ValueError(
|
||||
f"The number of initial states is expected to be equal to the number of input sequences, "
|
||||
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state_indices.shape[0]}."
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1] ** -0.5
|
||||
else:
|
||||
assert scale > 0, "scale must be positive"
|
||||
if beta is None:
|
||||
beta = torch.ones_like(q[..., 0])
|
||||
o = FusedRecurrentUpdateFunction.apply(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale,
|
||||
initial_state_source,
|
||||
initial_state_indices,
|
||||
cu_seqlens,
|
||||
use_qk_l2norm_in_kernel,
|
||||
disable_state_update,
|
||||
disable_output_calculation,
|
||||
intermediate_states_buffer,
|
||||
cache_steps,
|
||||
)
|
||||
return o
|
||||
@@ -0,0 +1,232 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.fla.utils import input_guard
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_INITIAL_STATE": lambda args: args["h0_source"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def fused_sigmoid_gating_delta_rule_update_kernel(
|
||||
A_log,
|
||||
a,
|
||||
dt_bias,
|
||||
softplus_beta,
|
||||
softplus_threshold,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
b,
|
||||
o,
|
||||
h0_source,
|
||||
h0_indices,
|
||||
cu_seqlens,
|
||||
scale,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
HV: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Fused kernel that combines sigmoid gating computation with recurrent delta rule update.
|
||||
"""
|
||||
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_n, i_hv = i_nh // HV, i_nh % HV
|
||||
i_h = i_hv // (HV // H)
|
||||
|
||||
if IS_VARLEN:
|
||||
bos, eos = (
|
||||
tl.load(cu_seqlens + i_n).to(tl.int64),
|
||||
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
|
||||
)
|
||||
all = T
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
all = B * T
|
||||
|
||||
o_k = i_k * BK + tl.arange(0, BK)
|
||||
o_v = i_v * BV + tl.arange(0, BV)
|
||||
|
||||
p_q = q + (bos * H + i_h) * K + o_k
|
||||
p_k = k + (bos * H + i_h) * K + o_k
|
||||
p_v = v + (bos * HV + i_hv) * V + o_v
|
||||
p_b = b + bos * HV + i_hv
|
||||
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
|
||||
|
||||
# Gating computation pointers
|
||||
p_A_log = A_log + i_hv
|
||||
p_a = a + bos * HV + i_hv
|
||||
p_dt_bias = dt_bias + i_hv
|
||||
|
||||
mask_k = o_k < K
|
||||
mask_v = o_v < V
|
||||
mask_h = mask_k[:, None] & mask_v[None, :]
|
||||
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
idx = tl.load(h0_indices + i_n)
|
||||
if idx >= 0:
|
||||
p_h0 = (
|
||||
h0_source
|
||||
+ idx * HV * K * V
|
||||
+ i_hv * K * V
|
||||
+ o_k[:, None] * V
|
||||
+ o_v[None, :]
|
||||
)
|
||||
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
||||
|
||||
for _ in range(0, T):
|
||||
# Load inputs
|
||||
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
|
||||
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
|
||||
b_b = tl.load(p_b).to(tl.float32)
|
||||
|
||||
# Compute sigmoid gating
|
||||
# Load gating parameters
|
||||
b_A_log = tl.load(p_A_log).to(tl.float32)
|
||||
b_a = tl.load(p_a).to(tl.float32)
|
||||
b_dt_bias = tl.load(p_dt_bias).to(tl.float32)
|
||||
|
||||
# Compute g = -exp(A_log) * softplus(a + dt_bias)
|
||||
x = b_a + b_dt_bias
|
||||
beta_x = softplus_beta * x
|
||||
# Apply softplus with numerical stability
|
||||
softplus_x = tl.where(
|
||||
beta_x <= softplus_threshold,
|
||||
(1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
|
||||
x,
|
||||
)
|
||||
b_g = -tl.exp(b_A_log) * softplus_x
|
||||
|
||||
# Compute beta = sigmoid(b)
|
||||
b_beta = 1.0 / (1.0 + tl.exp(-b_b))
|
||||
|
||||
# Apply L2 normalization if enabled
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
|
||||
b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
|
||||
|
||||
b_q = b_q * scale
|
||||
|
||||
# Apply gating to hidden state: h *= exp(g)
|
||||
b_h *= tl.exp(b_g)
|
||||
|
||||
# Delta rule: v -= sum(h * k, dim=0)
|
||||
b_v -= tl.sum(b_h * b_k[:, None], 0)
|
||||
|
||||
# Apply beta gating: v *= beta
|
||||
b_v *= b_beta
|
||||
|
||||
# Update hidden state: h += k[:, None] * v[None, :]
|
||||
b_h += b_k[:, None] * b_v[None, :]
|
||||
|
||||
# Compute output: o = sum(h * q, dim=0)
|
||||
b_o = tl.sum(b_h * b_q[:, None], 0)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
||||
|
||||
# Update pointers for next timestep
|
||||
p_q += H * K
|
||||
p_k += H * K
|
||||
p_o += HV * V
|
||||
p_v += HV * V
|
||||
p_b += HV
|
||||
p_a += HV
|
||||
|
||||
# Store final state back to h0_source with bounds checking
|
||||
if USE_INITIAL_STATE:
|
||||
idx = tl.load(h0_indices + i_n)
|
||||
if idx >= 0:
|
||||
p_h0 = (
|
||||
h0_source
|
||||
+ idx * HV * K * V
|
||||
+ i_hv * K * V
|
||||
+ o_k[:, None] * V
|
||||
+ o_v[None, :]
|
||||
)
|
||||
tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)
|
||||
|
||||
|
||||
@input_guard
|
||||
def fused_sigmoid_gating_delta_rule_update(
|
||||
A_log: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
dt_bias: torch.Tensor,
|
||||
softplus_beta: float,
|
||||
softplus_threshold: float,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
initial_state_source: torch.Tensor,
|
||||
initial_state_indices: torch.Tensor,
|
||||
scale: Optional[float] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Fused triton implementation of sigmoid gating delta rule update.
|
||||
This function uses a single fused kernel that combines both sigmoid gating computation
|
||||
and the recurrent delta rule update for better performance.
|
||||
"""
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
HV = v.shape[2]
|
||||
N = B if cu_seqlens is None else len(cu_seqlens) - 1
|
||||
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
assert NK == 1, "NK > 1 is not supported yet"
|
||||
num_stages = 3
|
||||
num_warps = 1
|
||||
|
||||
if scale is None:
|
||||
scale = k.shape[-1] ** -0.5
|
||||
else:
|
||||
assert scale > 0, "scale must be positive"
|
||||
|
||||
o = q.new_empty(NK, *v.shape)
|
||||
grid = (NK, NV, N * HV)
|
||||
|
||||
fused_sigmoid_gating_delta_rule_update_kernel[grid](
|
||||
A_log=A_log,
|
||||
a=a,
|
||||
dt_bias=dt_bias,
|
||||
softplus_beta=softplus_beta,
|
||||
softplus_threshold=softplus_threshold,
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
b=b,
|
||||
o=o,
|
||||
h0_source=initial_state_source,
|
||||
h0_indices=initial_state_indices,
|
||||
cu_seqlens=cu_seqlens,
|
||||
scale=scale,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
HV=HV,
|
||||
K=K,
|
||||
V=V,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
o = o.squeeze(0)
|
||||
return o
|
||||
37
python/sglang/srt/layers/attention/fla/index.py
Normal file
37
python/sglang/srt/layers/attention/fla/index.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/index.py
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.fla.utils import tensor_cache
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
||||
return cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_indices(
|
||||
cu_seqlens: torch.LongTensor, chunk_size: int
|
||||
) -> torch.LongTensor:
|
||||
indices = torch.cat(
|
||||
[
|
||||
torch.arange(n)
|
||||
for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
|
||||
]
|
||||
)
|
||||
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_offsets(
|
||||
cu_seqlens: torch.LongTensor, chunk_size: int
|
||||
) -> torch.LongTensor:
|
||||
return torch.cat(
|
||||
[cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
|
||||
).cumsum(-1)
|
||||
150
python/sglang/srt/layers/attention/fla/l2norm.py
Normal file
150
python/sglang/srt/layers/attention/fla/l2norm.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/l2norm.py
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.fla.utils import input_guard
|
||||
|
||||
BT_LIST = [8, 16, 32, 64, 128]
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
|
||||
# ],
|
||||
# key=["D"],
|
||||
# )
|
||||
@triton.jit
|
||||
def l2norm_fwd_kernel1(
|
||||
x,
|
||||
y,
|
||||
D,
|
||||
BD: tl.constexpr,
|
||||
eps,
|
||||
):
|
||||
i_t = tl.program_id(0)
|
||||
x += i_t * D
|
||||
y += i_t * D
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BD)
|
||||
mask = cols < D
|
||||
b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
|
||||
b_var = tl.sum(b_x * b_x, axis=0)
|
||||
b_rstd = 1 / tl.sqrt(b_var + eps)
|
||||
# tl.store(Rstd + i_t, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
b_y = b_x * b_rstd
|
||||
tl.store(y + cols, b_y, mask=mask)
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({"BT": BT}, num_warps=num_warps)
|
||||
# for num_warps in [1, 2, 4, 8, 16]
|
||||
# for BT in BT_LIST
|
||||
# ],
|
||||
# key=["D", "NB"],
|
||||
# )
|
||||
@triton.jit
|
||||
def l2norm_fwd_kernel(
|
||||
x,
|
||||
y,
|
||||
eps,
|
||||
NB: tl.constexpr,
|
||||
T: tl.constexpr,
|
||||
D: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BD: tl.constexpr,
|
||||
):
|
||||
i_t = tl.program_id(0)
|
||||
p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
|
||||
b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_var = tl.sum(b_x * b_x, axis=1)
|
||||
b_y = b_x / tl.sqrt(b_var + eps)[:, None]
|
||||
p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
|
||||
tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def l2norm_fwd(
|
||||
x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
|
||||
):
|
||||
x_shape_og = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
# allocate output
|
||||
if output_dtype is None:
|
||||
y = torch.empty_like(x)
|
||||
else:
|
||||
y = torch.empty_like(x, dtype=output_dtype)
|
||||
assert y.stride(-1) == 1
|
||||
T, D = x.shape[0], x.shape[-1]
|
||||
# rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
|
||||
if D > BD:
|
||||
raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
|
||||
|
||||
if D <= 512:
|
||||
NB = triton.cdiv(T, 2048)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(T, meta["BT"]),)
|
||||
|
||||
l2norm_fwd_kernel[grid](
|
||||
x,
|
||||
y,
|
||||
eps,
|
||||
NB=NB,
|
||||
T=T,
|
||||
D=D,
|
||||
BD=BD,
|
||||
BT=16,
|
||||
num_warps=8,
|
||||
num_stages=3,
|
||||
)
|
||||
else:
|
||||
l2norm_fwd_kernel1[(T,)](
|
||||
x,
|
||||
y,
|
||||
eps=eps,
|
||||
D=D,
|
||||
BD=BD,
|
||||
num_warps=8,
|
||||
num_stages=3,
|
||||
)
|
||||
|
||||
return y.view(x_shape_og)
|
||||
|
||||
|
||||
class L2NormFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@input_guard
|
||||
def forward(ctx, x, eps=1e-6, output_dtype=None):
|
||||
return l2norm_fwd(x, eps, output_dtype)
|
||||
|
||||
|
||||
def l2norm(
|
||||
x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
|
||||
) -> torch.Tensor:
|
||||
return L2NormFunction.apply(x, eps, output_dtype)
|
||||
|
||||
|
||||
l2_norm = l2norm
|
||||
|
||||
|
||||
class L2Norm(nn.Module):
|
||||
|
||||
def __init__(self, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.output_dtype = output_dtype
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return l2norm(x, self.eps, self.output_dtype)
|
||||
326
python/sglang/srt/layers/attention/fla/layernorm_gated.py
Normal file
326
python/sglang/srt/layers/attention/fla/layernorm_gated.py
Normal file
@@ -0,0 +1,326 @@
|
||||
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/modules/layernorm_gated.py
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
||||
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
||||
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def rms_norm_ref(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
upcast=True,
|
||||
):
|
||||
dtype = x.dtype
|
||||
N = x.shape[-1]
|
||||
weight = weight.float()
|
||||
bias = bias.float() if bias is not None else None
|
||||
if upcast:
|
||||
x = x.float()
|
||||
z = z.float() if z is not None else z
|
||||
if z is not None and not norm_before_gate:
|
||||
x = x * F.silu(z)
|
||||
if group_size is None:
|
||||
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
||||
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
|
||||
else:
|
||||
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
|
||||
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
|
||||
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
if z is not None and norm_before_gate:
|
||||
out *= F.silu(z)
|
||||
return out.to(dtype)
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
|
||||
@triton.jit
|
||||
def _layer_norm_fwd_1pass_kernel(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Z, # pointer to the other branch
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
stride_y_row,
|
||||
stride_z_row,
|
||||
M, # number of rows in X
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
group = tl.program_id(1)
|
||||
X += row * stride_x_row + group * N
|
||||
Y += row * stride_y_row + group * N
|
||||
if HAS_Z:
|
||||
Z += row * stride_z_row + group * N
|
||||
if not IS_RMS_NORM:
|
||||
Mean += group * M
|
||||
Rstd += group * M
|
||||
W += group * N
|
||||
if HAS_BIAS:
|
||||
B += group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=mask).to(tl.float32)
|
||||
y *= z * tl.sigmoid(z)
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=None,
|
||||
out=None,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
group_size = N
|
||||
assert N % group_size == 0
|
||||
ngroups = N // group_size
|
||||
assert x.stride(-1) == 1
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
assert weight.shape == (N,)
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N,)
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = (
|
||||
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
if not is_rms_norm
|
||||
else None
|
||||
)
|
||||
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_layer_norm_fwd_1pass_kernel[grid](
|
||||
x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
class LayerNormFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(
|
||||
ctx,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, mean, rstd = _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
|
||||
def layernorm_fn(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
return LayerNormFn.apply(
|
||||
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
|
||||
)
|
||||
|
||||
|
||||
def rmsnorm_fn(
|
||||
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
|
||||
):
|
||||
return LayerNormFn.apply(
|
||||
x, weight, bias, z, eps, group_size, norm_before_gate, True
|
||||
)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
eps=1e-5,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
"""
|
||||
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.group_size = group_size
|
||||
self.norm_before_gate = norm_before_gate
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
torch.nn.init.ones_(self.weight)
|
||||
torch.nn.init.zeros_(self.bias)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
return layernorm_fn(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
z=z,
|
||||
group_size=self.group_size,
|
||||
eps=self.eps,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
eps=1e-5,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.register_parameter("bias", None)
|
||||
self.group_size = group_size
|
||||
self.norm_before_gate = norm_before_gate
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
torch.nn.init.ones_(self.weight)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
return rmsnorm_fn(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
z=z,
|
||||
eps=self.eps,
|
||||
group_size=self.group_size,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
)
|
||||
66
python/sglang/srt/layers/attention/fla/op.py
Normal file
66
python/sglang/srt/layers/attention/fla/op.py
Normal file
@@ -0,0 +1,66 @@
|
||||
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/op.py
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import os
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import triton.language.extra.libdevice as tldevice
|
||||
|
||||
from sglang.srt.layers.attention.fla.utils import is_gather_supported
|
||||
|
||||
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
|
||||
exp = tldevice.fast_expf
|
||||
exp2 = tldevice.exp2
|
||||
log = tldevice.fast_logf
|
||||
log2 = tldevice.fast_log2f
|
||||
else:
|
||||
exp = tl.exp
|
||||
exp2 = tl.math.exp2
|
||||
log = tl.log
|
||||
log2 = tl.log2
|
||||
|
||||
|
||||
@triton.jit
|
||||
def safe_exp(x):
|
||||
return exp(tl.where(x <= 0, x, float("-inf")))
|
||||
|
||||
|
||||
if not is_gather_supported:
|
||||
|
||||
@triton.jit
|
||||
def gather(src, index, axis, _builder=None):
|
||||
"""
|
||||
Gather operation that works when tl.gather is not supported.
|
||||
This is a fallback implementation that returns None.
|
||||
Just to make triton compiler happy.
|
||||
"""
|
||||
return None
|
||||
|
||||
else:
|
||||
gather = tl.gather
|
||||
|
||||
|
||||
if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
|
||||
# For Triton 3.3.x
|
||||
make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
|
||||
elif hasattr(triton.language, "make_tensor_descriptor"):
|
||||
# For Triton 3.4.x and later
|
||||
make_tensor_descriptor = triton.language.make_tensor_descriptor
|
||||
else:
|
||||
"""
|
||||
Fallback implementation when TMA is not supported.
|
||||
Returns None to indicate TMA descriptors are unavailable.
|
||||
Just make triton compiler happy.
|
||||
"""
|
||||
|
||||
@triton.jit
|
||||
def make_tensor_descriptor(
|
||||
base,
|
||||
shape,
|
||||
strides,
|
||||
block_shape,
|
||||
_builder=None,
|
||||
):
|
||||
return None
|
||||
465
python/sglang/srt/layers/attention/fla/solve_tril.py
Normal file
465
python/sglang/srt/layers/attention/fla/solve_tril.py
Normal file
@@ -0,0 +1,465 @@
|
||||
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/utils/solve_tril.py
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
|
||||
from sglang.srt.layers.attention.fla.utils import input_guard
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for num_warps in [1, 2, 4, 8]
|
||||
# for num_stages in [2, 3, 4, 5]
|
||||
# ],
|
||||
# key=["BT"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def solve_tril_16x16_kernel(
|
||||
A,
|
||||
Ad,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
A = A + (bos * H + i_h) * BT
|
||||
Ad = Ad + (bos * H + i_h) * 16
|
||||
|
||||
offset = (i_t * 16) % BT
|
||||
p_A = tl.make_block_ptr(
|
||||
A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0))
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)
|
||||
|
||||
o_i = tl.arange(0, 16)
|
||||
for i in range(1, min(16, T - i_t * 16)):
|
||||
b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
|
||||
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
|
||||
mask = o_i == i
|
||||
b_A = tl.where(mask[:, None], b_a, b_A)
|
||||
b_A += o_i[:, None] == o_i[None, :]
|
||||
tl.store(
|
||||
p_Ai,
|
||||
b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for num_warps in [1, 2, 4, 8]
|
||||
# for num_stages in [2, 3, 4, 5]
|
||||
# ],
|
||||
# key=["H", "BT", "IS_VARLEN"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def merge_16x16_to_32x32_inverse_kernel(
|
||||
A,
|
||||
Ad,
|
||||
Ai,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
A += (bos * H + i_h) * 32
|
||||
Ad += (bos * H + i_h) * 16
|
||||
Ai += (bos * H + i_h) * 32
|
||||
|
||||
p_A_21 = tl.make_block_ptr(
|
||||
A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ad_11 = tl.make_block_ptr(
|
||||
Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ad_22 = tl.make_block_ptr(
|
||||
Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_11 = tl.make_block_ptr(
|
||||
Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_22 = tl.make_block_ptr(
|
||||
Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_21 = tl.make_block_ptr(
|
||||
Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)
|
||||
)
|
||||
|
||||
A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_21 = -tl.dot(
|
||||
tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee"
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_11,
|
||||
Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_22,
|
||||
Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_21,
|
||||
Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for num_warps in [2, 4, 8]
|
||||
# for num_stages in [2, 3, 4, 5]
|
||||
# ],
|
||||
# key=["H", "BT", "IS_VARLEN"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def merge_16x16_to_64x64_inverse_kernel(
|
||||
A,
|
||||
Ad,
|
||||
Ai,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
A += (bos * H + i_h) * 64
|
||||
Ad += (bos * H + i_h) * 16
|
||||
Ai += (bos * H + i_h) * 64
|
||||
|
||||
p_A_21 = tl.make_block_ptr(
|
||||
A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)
|
||||
)
|
||||
p_A_32 = tl.make_block_ptr(
|
||||
A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)
|
||||
)
|
||||
p_A_31 = tl.make_block_ptr(
|
||||
A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)
|
||||
)
|
||||
p_A_43 = tl.make_block_ptr(
|
||||
A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)
|
||||
)
|
||||
p_A_42 = tl.make_block_ptr(
|
||||
A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)
|
||||
)
|
||||
p_A_41 = tl.make_block_ptr(
|
||||
A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ad_11 = tl.make_block_ptr(
|
||||
Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ad_22 = tl.make_block_ptr(
|
||||
Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ad_33 = tl.make_block_ptr(
|
||||
Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ad_44 = tl.make_block_ptr(
|
||||
Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)
|
||||
)
|
||||
|
||||
A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
|
||||
A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
|
||||
A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
|
||||
A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
|
||||
A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
|
||||
A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
Ai_21 = -tl.dot(
|
||||
tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee"
|
||||
)
|
||||
Ai_32 = -tl.dot(
|
||||
tl.dot(Ai_33, A_32, input_precision="ieee"), Ai_22, input_precision="ieee"
|
||||
)
|
||||
Ai_43 = -tl.dot(
|
||||
tl.dot(Ai_44, A_43, input_precision="ieee"), Ai_33, input_precision="ieee"
|
||||
)
|
||||
|
||||
Ai_31 = -tl.dot(
|
||||
Ai_33,
|
||||
tl.dot(A_31, Ai_11, input_precision="ieee")
|
||||
+ tl.dot(A_32, Ai_21, input_precision="ieee"),
|
||||
input_precision="ieee",
|
||||
)
|
||||
Ai_42 = -tl.dot(
|
||||
Ai_44,
|
||||
tl.dot(A_42, Ai_22, input_precision="ieee")
|
||||
+ tl.dot(A_43, Ai_32, input_precision="ieee"),
|
||||
input_precision="ieee",
|
||||
)
|
||||
Ai_41 = -tl.dot(
|
||||
Ai_44,
|
||||
tl.dot(A_41, Ai_11, input_precision="ieee")
|
||||
+ tl.dot(A_42, Ai_21, input_precision="ieee")
|
||||
+ tl.dot(A_43, Ai_31, input_precision="ieee"),
|
||||
input_precision="ieee",
|
||||
)
|
||||
|
||||
p_Ai_11 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_22 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_33 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_44 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_21 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_31 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_32 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_41 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_42 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_43 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_11,
|
||||
Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_22,
|
||||
Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_33,
|
||||
Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_44,
|
||||
Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_21,
|
||||
Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_31,
|
||||
Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_32,
|
||||
Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_41,
|
||||
Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_42,
|
||||
Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_43,
|
||||
Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
|
||||
fill_zeros = tl.zeros((16, 16), dtype=tl.float32)
|
||||
p_Ai_12 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_13 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_14 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_23 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_24 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0)
|
||||
)
|
||||
p_Ai_34 = tl.make_block_ptr(
|
||||
Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0)
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_12,
|
||||
fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_13,
|
||||
fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_14,
|
||||
fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_23,
|
||||
fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_24,
|
||||
fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_34,
|
||||
fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
|
||||
|
||||
@input_guard
|
||||
def solve_tril(
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
output_dtype: torch.dtype = torch.float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the inverse of the lower triangular matrix
|
||||
A should be strictly lower triangular, i.e., A.triu() == 0.
|
||||
|
||||
Args:
|
||||
A (torch.Tensor):
|
||||
[B, T, H, K]
|
||||
cu_seqlens (torch.Tensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None.
|
||||
output_dtype (torch.dtype):
|
||||
The dtype of the output tensor. Default: `torch.float`
|
||||
|
||||
Returns:
|
||||
(I + A)^-1 with the same shape as A
|
||||
"""
|
||||
assert A.shape[-1] in [16, 32, 64]
|
||||
|
||||
B, T, H, BT = A.shape
|
||||
Ad = torch.empty(
|
||||
B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype
|
||||
)
|
||||
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None
|
||||
)
|
||||
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16)
|
||||
solve_tril_16x16_kernel[NT, B * H](
|
||||
A=A,
|
||||
Ad=Ad,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
BT=BT,
|
||||
num_warps=1,
|
||||
num_stages=4,
|
||||
)
|
||||
if BT == 16:
|
||||
return Ad
|
||||
|
||||
Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
|
||||
merge_fn = (
|
||||
merge_16x16_to_32x32_inverse_kernel
|
||||
if BT == 32
|
||||
else merge_16x16_to_64x64_inverse_kernel
|
||||
)
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
)
|
||||
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
|
||||
merge_fn[NT, B * H](
|
||||
A=A,
|
||||
Ad=Ad,
|
||||
Ai=Ai,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
BT=BT,
|
||||
num_warps=4,
|
||||
num_stages=3,
|
||||
)
|
||||
return Ai
|
||||
331
python/sglang/srt/layers/attention/fla/utils.py
Normal file
331
python/sglang/srt/layers/attention/fla/utils.py
Normal file
@@ -0,0 +1,331 @@
|
||||
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/utils.py
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import Any, Callable, Dict, Literal, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from packaging import version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
|
||||
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def check_environments():
|
||||
"""
|
||||
Checks the current operating system, Triton version, and Python version,
|
||||
issuing warnings if they don't meet recommendations.
|
||||
This function's body only runs once due to lru_cache.
|
||||
"""
|
||||
# Check Operating System
|
||||
if sys.platform == "win32":
|
||||
logger.warning(
|
||||
"Detected Windows operating system. Triton does not have an official Windows release, "
|
||||
"thus FLA will not be adapted for Windows, and any potential errors will not be fixed. "
|
||||
"Please consider using a Linux environment for compatibility."
|
||||
)
|
||||
|
||||
triton_version = version.parse(triton.__version__)
|
||||
required_triton_version = version.parse("3.2.0")
|
||||
|
||||
if triton_version < required_triton_version:
|
||||
logger.warning(
|
||||
f"Current Triton version {triton_version} is below the recommended 3.2.0 version. "
|
||||
"Errors may occur and these issues will not be fixed. "
|
||||
"Please consider upgrading Triton."
|
||||
)
|
||||
|
||||
# Check Python version
|
||||
py_version = version.parse(f"{sys.version_info.major}.{sys.version_info.minor}")
|
||||
required_py_version = version.parse("3.11")
|
||||
|
||||
if py_version < required_py_version:
|
||||
logger.warning(
|
||||
f"Current Python version {py_version} is below the recommended 3.11 version. "
|
||||
"It is recommended to upgrade to Python 3.11 or higher for the best experience."
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
check_environments()
|
||||
|
||||
|
||||
def get_abs_err(x, y):
|
||||
return (x.detach() - y.detach()).flatten().abs().max().item()
|
||||
|
||||
|
||||
def get_err_ratio(x, y):
|
||||
err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
|
||||
base = (x.detach()).flatten().square().mean().sqrt().item()
|
||||
return err / (base + 1e-8)
|
||||
|
||||
|
||||
def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
|
||||
abs_atol = get_abs_err(ref, tri)
|
||||
msg = f"{prefix} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
|
||||
logger.info(msg)
|
||||
error_rate = get_err_ratio(ref, tri)
|
||||
if abs_atol <= err_atol:
|
||||
return
|
||||
if warning or (FLA_CI_ENV and (error_rate < 0.01 or abs_atol <= 0.3)):
|
||||
if error_rate > ratio:
|
||||
import warnings
|
||||
|
||||
warnings.warn(msg)
|
||||
else:
|
||||
assert error_rate < ratio, msg
|
||||
|
||||
|
||||
SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
|
||||
|
||||
|
||||
def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
|
||||
"""
|
||||
A decorator that caches the most recent results of a function with tensor inputs.
|
||||
This decorator will store the output of the decorated function for the most recent set of input tensors.
|
||||
The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.
|
||||
Args:
|
||||
fn (Callable[..., torch.Tensor]):
|
||||
The function to be decorated. It should take tensor inputs and return tensor outputs.
|
||||
Returns:
|
||||
Callable[..., torch.Tensor]:
|
||||
A wrapped version of the input function with single-entry caching.
|
||||
"""
|
||||
|
||||
cache_entries: Tuple[Optional[Tuple], Optional[Dict], Any] = []
|
||||
cache_size = 4
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal cache_entries, cache_size
|
||||
for i, entry in enumerate(cache_entries):
|
||||
last_args, last_kwargs, last_result = entry
|
||||
if len(args) == len(last_args) and len(kwargs) == len(last_kwargs):
|
||||
if all(a is b for a, b in zip(args, last_args)) and all(
|
||||
k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
|
||||
):
|
||||
cache_entries = (
|
||||
cache_entries[:i]
|
||||
+ cache_entries[i + 1 :]
|
||||
+ [(args, kwargs, last_result)]
|
||||
)
|
||||
return last_result
|
||||
|
||||
result = fn(*args, **kwargs)
|
||||
|
||||
if len(cache_entries) >= cache_size:
|
||||
cache_entries = cache_entries[1:]
|
||||
cache_entries.append((args, kwargs, result))
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
|
||||
"""
|
||||
A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
contiguous_args = (
|
||||
i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
|
||||
)
|
||||
contiguous_kwargs = {
|
||||
k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
|
||||
tensor = None
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
tensor = arg
|
||||
break
|
||||
if tensor is None:
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
tensor = value
|
||||
break
|
||||
|
||||
if tensor is not None:
|
||||
ctx = custom_device_ctx(tensor.device.index)
|
||||
else:
|
||||
ctx = contextlib.nullcontext()
|
||||
|
||||
with ctx:
|
||||
return fn(*contiguous_args, **contiguous_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
contiguous = input_guard
|
||||
|
||||
|
||||
def require_version(version, hint):
|
||||
"""
|
||||
Perform a runtime check of the dependency versions, using the exact same syntax used by pip.
|
||||
"""
|
||||
|
||||
def decorator(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapper(ctx, *args, **kwargs):
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
require_version(version, hint)
|
||||
return fn(
|
||||
ctx,
|
||||
*(
|
||||
i if not isinstance(i, torch.Tensor) else i.contiguous()
|
||||
for i in args
|
||||
),
|
||||
**{
|
||||
k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
|
||||
for k, v in kwargs.items()
|
||||
},
|
||||
)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def checkpoint(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
return torch.utils.checkpoint.checkpoint(fn, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def check_pytorch_version(version_s: str = "2.4") -> bool:
|
||||
return version.parse(torch.__version__) >= version.parse(version_s)
|
||||
|
||||
|
||||
def _cpu_device_warning():
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
("Triton is not supported on current platform, roll back to CPU."), stacklevel=1
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_multiprocessor_count(tensor_idx: int = 0) -> int:
|
||||
try:
|
||||
return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)[
|
||||
"multiprocessor_count"
|
||||
]
|
||||
except BaseException:
|
||||
_cpu_device_warning()
|
||||
return -1
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_available_device() -> str:
|
||||
try:
|
||||
return triton.runtime.driver.active.get_current_target().backend
|
||||
except BaseException:
|
||||
_cpu_device_warning()
|
||||
return "cpu"
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
|
||||
device = get_available_device()
|
||||
if device == "cuda":
|
||||
return "nvidia"
|
||||
elif device == "hip":
|
||||
return "amd"
|
||||
elif device == "xpu":
|
||||
return "intel"
|
||||
else:
|
||||
return device
|
||||
|
||||
|
||||
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
|
||||
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
|
||||
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
|
||||
device = get_available_device() if get_available_device() != "hip" else "cuda"
|
||||
device_torch_lib = getattr(torch, device)
|
||||
device_platform = _check_platform()
|
||||
|
||||
is_amd = device_platform == "amd"
|
||||
is_intel = device_platform == "intel"
|
||||
is_nvidia = device_platform == "nvidia"
|
||||
is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
|
||||
is_nvidia_hopper = is_nvidia and (
|
||||
"NVIDIA H" in torch.cuda.get_device_name(0)
|
||||
or torch.cuda.get_device_capability()[0] >= 9
|
||||
)
|
||||
use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
|
||||
|
||||
# Nvidia Ampere or newer, haven't check AMD and intel yet.
|
||||
is_tf32_supported = is_nvidia and torch.cuda.get_device_capability(0)[0] >= 8
|
||||
is_gather_supported = hasattr(triton.language, "gather")
|
||||
|
||||
|
||||
def get_all_max_shared_mem():
|
||||
try:
|
||||
return [
|
||||
triton.runtime.driver.active.utils.get_device_properties(i)[
|
||||
"max_shared_mem"
|
||||
]
|
||||
for i in range(device_torch_lib.device_count())
|
||||
]
|
||||
except BaseException:
|
||||
_cpu_device_warning()
|
||||
return [-1]
|
||||
|
||||
|
||||
class Backend(Enum):
|
||||
ADA = 101376 # RTX 4090
|
||||
AMPERE = 166912 # A100
|
||||
HOPPER = 232448 # H100
|
||||
DEFAULT = 102400 # Default
|
||||
|
||||
@classmethod
|
||||
def get_shared_memory(cls, arch: str) -> int:
|
||||
try:
|
||||
return cls[arch.upper()].value
|
||||
except KeyError:
|
||||
return cls.DEFAULT.value
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
|
||||
try:
|
||||
device_shared_mem_list = get_all_max_shared_mem()
|
||||
max_shared_memory = device_shared_mem_list[tensor_idx]
|
||||
return max_shared_memory >= Backend.get_shared_memory(arch)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
if check_pytorch_version("2.4"):
|
||||
device = "cuda" if device == "cpu" else device
|
||||
autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device)
|
||||
autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device)
|
||||
|
||||
def custom_device_ctx(index: int):
|
||||
return device_torch_lib.device(index)
|
||||
|
||||
else:
|
||||
assert (
|
||||
device == "cuda"
|
||||
), "Only cuda device is supported for PyTorch version < 2.4.0."
|
||||
autocast_custom_fwd = device_torch_lib.amp.custom_fwd
|
||||
autocast_custom_bwd = device_torch_lib.amp.custom_bwd
|
||||
|
||||
def custom_device_ctx(index: int):
|
||||
return torch.cuda.device(index)
|
||||
158
python/sglang/srt/layers/attention/fla/wy_fast.py
Normal file
158
python/sglang/srt/layers/attention/fla/wy_fast.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Adapt from https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/gated_delta_rule/wy_fast.py
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.fla.index import prepare_chunk_indices
|
||||
from sglang.srt.layers.attention.fla.op import safe_exp
|
||||
from sglang.srt.layers.attention.fla.utils import check_shared_mem
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
# for num_warps in [2, 4, 8]
|
||||
# for num_stages in [2, 3, 4]
|
||||
# ],
|
||||
# key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
|
||||
# )
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def recompute_w_u_fwd_kernel(
|
||||
k,
|
||||
v,
|
||||
beta,
|
||||
w,
|
||||
u,
|
||||
A,
|
||||
g,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(
|
||||
chunk_indices + i_t * 2 + 1
|
||||
).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(
|
||||
cu_seqlens + i_n + 1
|
||||
).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
p_beta = tl.make_block_ptr(
|
||||
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
|
||||
)
|
||||
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
|
||||
p_A = tl.make_block_ptr(
|
||||
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
|
||||
)
|
||||
b_beta = tl.load(p_beta, boundary_check=(0,))
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1))
|
||||
b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
|
||||
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(
|
||||
v + (bos * H + i_h) * V,
|
||||
(T, V),
|
||||
(H * V, 1),
|
||||
(i_t * BT, i_v * BV),
|
||||
(BT, BV),
|
||||
(1, 0),
|
||||
)
|
||||
p_u = tl.make_block_ptr(
|
||||
u + (bos * H + i_h) * V,
|
||||
(T, V),
|
||||
(H * V, 1),
|
||||
(i_t * BT, i_v * BV),
|
||||
(BT, BV),
|
||||
(1, 0),
|
||||
)
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
|
||||
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
|
||||
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(
|
||||
k + (bos * Hg + i_h // (H // Hg)) * K,
|
||||
(T, K),
|
||||
(Hg * K, 1),
|
||||
(i_t * BT, i_k * BK),
|
||||
(BT, BK),
|
||||
(1, 0),
|
||||
)
|
||||
p_w = tl.make_block_ptr(
|
||||
w + (bos * H + i_h) * K,
|
||||
(T, K),
|
||||
(H * K, 1),
|
||||
(i_t * BT, i_k * BK),
|
||||
(BT, BK),
|
||||
(1, 0),
|
||||
)
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
|
||||
b_w = tl.dot(b_A, b_kb)
|
||||
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def recompute_w_u_fwd(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.LongTensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
BT = A.shape[-1]
|
||||
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
)
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
BK = 64
|
||||
BV = 64
|
||||
u = torch.empty_like(v)
|
||||
w = k.new_empty(B, T, H, K)
|
||||
recompute_w_u_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
w=w,
|
||||
u=u,
|
||||
A=A,
|
||||
g=g_cumsum,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
num_warps=4,
|
||||
num_stages=3,
|
||||
)
|
||||
return w, u
|
||||
|
||||
|
||||
fwd_recompute_w_u = recompute_w_u_fwd
|
||||
Reference in New Issue
Block a user