Upgrade to vLLM 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions

View File

@@ -7,11 +7,17 @@
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from .chunk import chunk_gated_delta_rule
from .fused_recurrent import fused_recurrent_gated_delta_rule
from .fused_recurrent import (
fused_recurrent_gated_delta_rule,
fused_recurrent_gated_delta_rule_packed_decode,
)
from .fused_sigmoid_gating import fused_sigmoid_gating_delta_rule_update
from .layernorm_guard import RMSNormGated
__all__ = [
"RMSNormGated",
"chunk_gated_delta_rule",
"fused_recurrent_gated_delta_rule",
"fused_recurrent_gated_delta_rule_packed_decode",
"fused_sigmoid_gating_delta_rule_update",
]
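
The overlay adds a dedicated packed-decode entry point next to the general recurrent kernel. A minimal import sketch (the package path is assumed from this file's role as the fla ops __init__; it is not shown in the hunk):

    # Assumed package path for this overlay's fla ops.
    from vllm.model_executor.layers.fla.ops import (
        fused_recurrent_gated_delta_rule,                # general varlen path
        fused_recurrent_gated_delta_rule_packed_decode,  # fused single-token decode
    )
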

View File

@@ -30,7 +30,7 @@ def chunk_gated_delta_rule_fwd(
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
# obtain WY representation. u is actually the new v.
@@ -84,7 +84,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
@@ -117,7 +117,7 @@ def chunk_gated_delta_rule(
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
@@ -141,7 +141,7 @@ def chunk_gated_delta_rule(
Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
cu_seqlens (torch.LongTensor):
cu_seqlens (torch.Tensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
@@ -171,7 +171,7 @@ def chunk_gated_delta_rule(
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
>>> o_var, ht_var = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
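
Relaxing the annotation from `torch.LongTensor` to `torch.Tensor` lets callers pass int32 offsets, as the updated docstring example shows. A minimal sketch of building such offsets from per-sequence lengths (names are illustrative):

    import torch

    # Token counts for a packed batch of 4 sequences.
    seq_lens = torch.tensor([2048, 2048, 2048, 2048])

    # Cumulative offsets of shape [N+1]; int32 now passes alongside int64.
    cu_seqlens = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)]).to(torch.int32)
    # tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
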

View File

@@ -288,7 +288,7 @@ def chunk_gated_delta_rule_fwd_h(
output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
save_new_value: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# This kernel is slightly different from fla to support Q/K with different head numbers.
# In fla, Q/K always have the same head number, so Hg is always equal to H.

View File

@@ -89,7 +89,7 @@ def chunk_fwd_kernel_o(
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(
    q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
@@ -145,7 +145,7 @@ def chunk_fwd_o(
h: torch.Tensor,
g: torch.Tensor | None = None, # cumsum of log decay
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]

View File

@@ -102,7 +102,7 @@ def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
g: torch.Tensor | None = None,
beta: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
@@ -116,7 +116,7 @@ def chunk_scaled_dot_kkt_fwd(
The beta tensor of shape `[B, T, H]`.
g (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
cu_seqlens (torch.LongTensor):
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):

View File

@@ -106,12 +106,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
else:
i_t = 0
# Load state index and check for PAD_SLOT_ID (-1)
# Load state index and check for invalid entries
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
tl.int64
)
# Skip if state index is invalid (PAD_SLOT_ID = -1)
if state_idx < 0:
# Skip if state index is invalid (NULL_BLOCK_ID=0)
if state_idx <= 0:
return
p_h0 = h0 + state_idx * stride_init_state_token
else:
@@ -150,12 +150,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
# keep the states for multi-query tokens
if INPLACE_FINAL_STATE:
# Load state index and check for PAD_SLOT_ID (-1)
# Load state index and check for invalid entries
final_state_idx = tl.load(
ssm_state_indices + i_n * stride_indices_seq + i_t
).to(tl.int64)
# Only store if state index is valid (not PAD_SLOT_ID)
if final_state_idx >= 0:
# Only store if state index is valid (not NULL_BLOCK_ID=0)
if final_state_idx > 0:
p_ht = ht + final_state_idx * stride_final_state_token
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
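
The sentinel convention changes here: upstream marks padded slots with PAD_SLOT_ID = -1 and skips on `state_idx < 0`, whereas this overlay reserves slot 0 of the state pool as a null block (NULL_BLOCK_ID = 0) and skips on `state_idx <= 0`. A hedged host-side sketch of that convention (pool size and variable names are assumptions):

    import torch

    NULL_BLOCK_ID = 0  # assumed: slot 0 never holds a live sequence state

    num_blocks, HV, V, K = 8, 2, 4, 4
    state_pool = torch.zeros(num_blocks, HV, V, K)

    # Padded sequences point at the null block instead of carrying -1.
    ssm_state_indices = torch.tensor([3, 1, NULL_BLOCK_ID, 5])

    for i_n, idx in enumerate(ssm_state_indices.tolist()):
        if idx <= 0:              # mirrors `if state_idx <= 0: return` above
            continue              # skip: no state is read or written
        h0 = state_pool[idx]      # live initial state for sequence i_n
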
@@ -184,7 +184,7 @@ def fused_recurrent_gated_delta_rule_fwd(
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
@@ -252,6 +252,232 @@ def fused_recurrent_gated_delta_rule_fwd(
return o, final_state
@triton.jit
def fused_recurrent_gated_delta_rule_packed_decode_kernel(
mixed_qkv,
a,
b,
A_log,
dt_bias,
o,
h0,
ht,
ssm_state_indices,
scale,
stride_mixed_qkv_tok: tl.constexpr,
stride_a_tok: tl.constexpr,
stride_b_tok: tl.constexpr,
stride_init_state_token: tl.constexpr,
stride_final_state_token: tl.constexpr,
stride_indices_seq: tl.constexpr,
H: tl.constexpr,
HV: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
SOFTPLUS_THRESHOLD: tl.constexpr,
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
):
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_hv = i_nh // HV, i_nh % HV
i_h = i_hv // (HV // H)
o_k = tl.arange(0, BK)
o_v = i_v * BV + tl.arange(0, BV)
mask_k = o_k < K
mask_v = o_v < V
mask_h = mask_v[:, None] & mask_k[None, :]
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)
p_o = o + (i_n * HV + i_hv) * V + o_v
# Skip if state index is invalid (NULL_BLOCK_ID=0)
if state_idx <= 0:
zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)
tl.store(p_o, zero, mask=mask_v)
return
p_h0 = h0 + state_idx * stride_init_state_token
p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
b_h = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
p_mixed = mixed_qkv + i_n * stride_mixed_qkv_tok
q_off = i_h * K + o_k
k_off = (H * K) + i_h * K + o_k
v_off = (2 * H * K) + i_hv * V + o_v
b_q = tl.load(p_mixed + q_off, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_mixed + k_off, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_mixed + v_off, mask=mask_v, other=0).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
b_q = b_q * scale
a_val = tl.load(a + i_n * stride_a_tok + i_hv).to(tl.float32)
b_val = tl.load(b + i_n * stride_b_tok + i_hv).to(tl.float32)
A_log_val = tl.load(A_log + i_hv).to(tl.float32)
dt_bias_val = tl.load(dt_bias + i_hv).to(tl.float32)
x = a_val + dt_bias_val
softplus_x = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x)
g_val = -tl.exp(A_log_val) * softplus_x
beta_val = tl.sigmoid(b_val).to(b.dtype.element_ty).to(tl.float32)
b_h *= tl.exp(g_val)
b_v -= tl.sum(b_h * b_k[None, :], 1)
b_v *= beta_val
b_h += b_v[:, None] * b_k[None, :]
b_o = tl.sum(b_h * b_q[None, :], 1)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
p_ht = ht + state_idx * stride_final_state_token
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
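
The kernel reads all three projections from a single fused row per token via the `q_off`/`k_off`/`v_off` offsets above. A plain-PyTorch sketch of that packed layout (shapes follow the kernel's H/HV/K/V arguments):

    import torch

    H, HV, K, V = 4, 8, 32, 64  # HV is a multiple of H (GQA-style sharing)
    B = 2                       # one token per sequence at decode time

    # One packed row per token: [ q (H*K) | k (H*K) | v (HV*V) ]
    mixed_qkv = torch.randn(B, 2 * H * K + HV * V)

    q = mixed_qkv[:, : H * K].view(B, H, K)
    k = mixed_qkv[:, H * K : 2 * H * K].view(B, H, K)
    v = mixed_qkv[:, 2 * H * K :].view(B, HV, V)

    # Value head i_hv reads q/k head i_hv // (HV // H), so each q/k head is
    # shared by HV // H value heads, matching `i_h = i_hv // (HV // H)` above.
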
def fused_recurrent_gated_delta_rule_packed_decode(
mixed_qkv: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
A_log: torch.Tensor,
dt_bias: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
out: torch.Tensor,
ssm_state_indices: torch.Tensor,
use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
if mixed_qkv.ndim != 2:
raise ValueError(
f"`mixed_qkv` must be a 2D tensor (got ndim={mixed_qkv.ndim})."
)
if mixed_qkv.stride(-1) != 1:
raise ValueError("`mixed_qkv` must be contiguous in the last dim.")
if a.ndim != 2 or b.ndim != 2:
raise ValueError(
f"`a` and `b` must be 2D tensors (got a.ndim={a.ndim}, b.ndim={b.ndim})."
)
if a.stride(-1) != 1 or b.stride(-1) != 1:
raise ValueError("`a`/`b` must be contiguous in the last dim.")
if A_log.ndim != 1 or dt_bias.ndim != 1:
raise ValueError("`A_log`/`dt_bias` must be 1D tensors.")
if A_log.stride(0) != 1 or dt_bias.stride(0) != 1:
raise ValueError("`A_log`/`dt_bias` must be contiguous.")
if ssm_state_indices.ndim != 1:
raise ValueError(
f"`ssm_state_indices` must be 1D for packed decode (got ndim={ssm_state_indices.ndim})."
)
if not out.is_contiguous():
raise ValueError("`out` must be contiguous.")
dev = mixed_qkv.device
if (
a.device != dev
or b.device != dev
or A_log.device != dev
or dt_bias.device != dev
or initial_state.device != dev
or out.device != dev
or ssm_state_indices.device != dev
):
raise ValueError("All inputs must be on the same device.")
B = mixed_qkv.shape[0]
if a.shape[0] != B or b.shape[0] != B:
raise ValueError(
"Mismatched batch sizes: "
f"mixed_qkv.shape[0]={B}, a.shape[0]={a.shape[0]}, b.shape[0]={b.shape[0]}."
)
if ssm_state_indices.shape[0] != B:
raise ValueError(
f"`ssm_state_indices` must have shape [B] (got {tuple(ssm_state_indices.shape)}; expected ({B},))."
)
if initial_state.ndim != 4:
raise ValueError(
f"`initial_state` must be a 4D tensor (got ndim={initial_state.ndim})."
)
if initial_state.stride(-1) != 1:
raise ValueError("`initial_state` must be contiguous in the last dim.")
HV, V, K = initial_state.shape[-3:]
if a.shape[1] != HV or b.shape[1] != HV:
raise ValueError(
f"`a`/`b` must have shape [B, HV] with HV={HV} (got a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)})."
)
if A_log.numel() != HV or dt_bias.numel() != HV:
raise ValueError(
f"`A_log` and `dt_bias` must have {HV} elements (got A_log.numel()={A_log.numel()}, dt_bias.numel()={dt_bias.numel()})."
)
if out.shape != (B, 1, HV, V):
raise ValueError(
f"`out` must have shape {(B, 1, HV, V)} (got out.shape={tuple(out.shape)})."
)
qkv_dim = mixed_qkv.shape[1]
qk_dim = qkv_dim - HV * V
if qk_dim <= 0 or qk_dim % 2 != 0:
raise ValueError(
f"Invalid packed `mixed_qkv` last dim={qkv_dim} for HV={HV}, V={V}."
)
q_dim = qk_dim // 2
if q_dim % K != 0:
raise ValueError(f"Invalid packed Q size {q_dim}: must be divisible by K={K}.")
H = q_dim // K
if H <= 0 or HV % H != 0:
raise ValueError(
f"Invalid head config inferred from mixed_qkv: H={H}, HV={HV}."
)
BK = triton.next_power_of_2(K)
if triton.cdiv(K, BK) != 1:
raise ValueError(
f"Packed decode kernel only supports NK=1 (got K={K}, BK={BK})."
)
BV = min(triton.next_power_of_2(V), 32)
num_stages = 3
num_warps = 1
stride_mixed_qkv_tok = mixed_qkv.stride(0)
stride_a_tok = a.stride(0)
stride_b_tok = b.stride(0)
stride_init_state_token = initial_state.stride(0)
stride_final_state_token = initial_state.stride(0)
stride_indices_seq = ssm_state_indices.stride(0)
NV = triton.cdiv(V, BV)
grid = (NV, B * HV)
fused_recurrent_gated_delta_rule_packed_decode_kernel[grid](
mixed_qkv=mixed_qkv,
a=a,
b=b,
A_log=A_log,
dt_bias=dt_bias,
o=out,
h0=initial_state,
ht=initial_state,
ssm_state_indices=ssm_state_indices,
scale=scale,
stride_mixed_qkv_tok=stride_mixed_qkv_tok,
stride_a_tok=stride_a_tok,
stride_b_tok=stride_b_tok,
stride_init_state_token=stride_init_state_token,
stride_final_state_token=stride_final_state_token,
stride_indices_seq=stride_indices_seq,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
SOFTPLUS_THRESHOLD=20.0,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
num_warps=num_warps,
num_stages=num_stages,
)
return out, initial_state
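
A toy invocation that satisfies the validation above (GPU required; sizes are arbitrary, and slot 0 of the state pool is left as the null block):

    import torch

    H, HV, K, V, B = 2, 4, 32, 32, 3
    dev = "cuda"

    mixed_qkv = torch.randn(B, 2 * H * K + HV * V, device=dev)
    a = torch.randn(B, HV, device=dev)
    b = torch.randn(B, HV, device=dev)
    A_log = torch.randn(HV, device=dev)
    dt_bias = torch.randn(HV, device=dev)

    state_pool = torch.zeros(8, HV, V, K, device=dev)  # slot 0 = null block
    indices = torch.tensor([1, 2, 3], device=dev)      # one live slot per seq
    out = torch.empty(B, 1, HV, V, device=dev)

    o, state = fused_recurrent_gated_delta_rule_packed_decode(
        mixed_qkv, a, b, A_log, dt_bias,
        scale=K**-0.5,
        initial_state=state_pool,  # doubles as h0 and ht (updated in place)
        out=out,
        ssm_state_indices=indices,
    )
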
class FusedRecurrentFunction(torch.autograd.Function):
@staticmethod
def forward(
@@ -264,7 +490,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
@@ -296,7 +522,7 @@ def fused_recurrent_gated_delta_rule(
scale: float = None,
initial_state: torch.Tensor = None,
inplace_final_state: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
@@ -324,7 +550,7 @@ def fused_recurrent_gated_delta_rule(
inplace_final_state: bool:
Whether to store the final state in-place to save memory.
Default: `True`.
cu_seqlens (torch.LongTensor):
cu_seqlens (torch.Tensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
ssm_state_indices (Optional[torch.Tensor]):
@@ -358,7 +584,7 @@ def fused_recurrent_gated_delta_rule(
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
>>> o_var, ht_var = fused_recurrent_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,

View File

@@ -0,0 +1,279 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import torch
from vllm.triton_utils import tl, triton
@triton.heuristics(
{
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
"IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
}
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_sigmoid_gating_delta_rule_update_kernel(
A_log,
a,
b,
dt_bias,
beta,
threshold,
q,
k,
v,
o,
h0,
ht,
cu_seqlens,
ssm_state_indices,
num_accepted_tokens,
scale,
N: tl.int64, # num of sequences
T: tl.int64, # num of tokens
B: tl.constexpr,
H: tl.constexpr,
HV: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
stride_init_state_token: tl.constexpr,
stride_final_state_token: tl.constexpr,
stride_indices_seq: tl.constexpr,
stride_indices_tok: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
IS_KDA: tl.constexpr,
):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV
i_h = i_hv // (HV // H)
if IS_VARLEN:
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int64),
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
)
all = T
T = eos - bos
else:
bos, eos = i_n * T, i_n * T + T
all = B * T
if T == 0:
# no tokens to process for this sequence
return
o_k = i_k * BK + tl.arange(0, BK)
o_v = i_v * BV + tl.arange(0, BV)
p_q = q + (bos * H + i_h) * K + o_k
p_k = k + (bos * H + i_h) * K + o_k
p_v = v + (bos * HV + i_hv) * V + o_v
p_A_log = A_log + i_hv
if not IS_KDA:
p_a = a + bos * HV + i_hv
p_dt_bias = dt_bias + i_hv
else:
p_a = a + (bos * HV + i_hv) * K + o_k
p_dt_bias = dt_bias + i_hv * K + o_k
p_b = b + bos * HV + i_hv
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
mask_k = o_k < K
mask_v = o_v < V
mask_h = mask_v[:, None] & mask_k[None, :]
b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
if IS_CONTINUOUS_BATCHING:
if IS_SPEC_DECODING:
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
else:
i_t = 0
# Load state index and check for invalid entries
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
tl.int64
)
# Skip if state index is invalid (NULL_BLOCK_ID=0)
if state_idx <= 0:
return
p_h0 = h0 + state_idx * stride_init_state_token
else:
p_h0 = h0 + bos * HV * V * K
p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
for i_t in range(0, T):
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
b_b = tl.load(p_b).to(tl.float32)
# If the model is loaded in fp16, without the .float() here, A might be -inf
x = tl.load(p_a).to(tl.float32) + tl.load(p_dt_bias).to(tl.float32)
softplus_x = tl.where(
beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
)
b_g = -tl.exp(tl.load(p_A_log).to(tl.float32)) * softplus_x
# compute beta_output = sigmoid(b)
b_beta = tl.sigmoid(b_b.to(tl.float32))
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q * (tl.rsqrt(tl.sum(b_q * b_q) + 1e-6))
b_k = b_k * (tl.rsqrt(tl.sum(b_k * b_k) + 1e-6))
b_q = b_q * scale
# [BV, BK]
if not IS_KDA:
b_h *= tl.exp(b_g)
else:
b_h *= tl.exp(b_g[None, :])
# [BV]
b_v -= tl.sum(b_h * b_k[None, :], 1)
b_v *= b_beta
# [BV, BK]
b_h += b_v[:, None] * b_k[None, :]
# [BV]
b_o = tl.sum(b_h * b_q[None, :], 1)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
# keep the states for multi-query tokens
if INPLACE_FINAL_STATE:
# Load state index and check for invalid entries
final_state_idx = tl.load(
ssm_state_indices + i_n * stride_indices_seq + i_t
).to(tl.int64)
# Only store if state index is valid (not NULL_BLOCK_ID=0)
if final_state_idx > 0:
p_ht = ht + final_state_idx * stride_final_state_token
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
else:
p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
# Update pointers for next timestep
p_q += H * K
p_k += H * K
p_o += HV * V
p_v += HV * V
p_b += HV
p_a += HV
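
Per token the kernel forms the decay g = -exp(A_log) * softplus(a + dt_bias) (with temperature `beta` and a linear tail past `threshold`) and the write gate sigmoid(b), then applies the delta-rule update. A plain-PyTorch reference for one step of a single head, ignoring the threshold/temperature handling, which may help when checking the kernel numerically:

    import torch
    import torch.nn.functional as F

    def gated_delta_step(h, q, k, v, a, b, A_log, dt_bias, scale):
        """One token of the sigmoid-gating delta rule; h is [V, K]."""
        g = -torch.exp(A_log) * F.softplus(a + dt_bias)  # log-space decay
        beta = torch.sigmoid(b)                          # write strength
        h = h * torch.exp(g)                             # decay old state
        delta = (v - h @ k) * beta                       # gated prediction error
        h = h + torch.outer(delta, k)                    # rank-1 state update
        o = h @ (q * scale)                              # readout
        return h, o
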
def fused_sigmoid_gating_delta_rule_update(
A_log: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
dt_bias: torch.Tensor,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
beta: float = 1.0,
threshold: float = 20.0,
scale: float = None,
initial_state: torch.Tensor = None,
inplace_final_state: bool = True,
cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
is_kda: bool = False,
):
"""
Fused Triton implementation of the sigmoid-gating delta rule update.
This function uses a single fused kernel that combines both sigmoid gating
computation and the recurrent delta rule update for better performance.
"""
B, T, H, K, V = *k.shape, v.shape[-1]
HV = v.shape[2]
N = B if cu_seqlens is None else len(cu_seqlens) - 1
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, "NK > 1 is not supported yet"
num_stages = 3
num_warps = 4
if cu_seqlens is not None and q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]}"
f" when using `cu_seqlens`. Please flatten variable-length"
f" inputs before processing."
)
if scale is None:
scale = k.shape[-1] ** -0.5
else:
assert scale > 0, "scale must be positive"
o = q.new_empty(NK, *v.shape)
if inplace_final_state:
final_state = initial_state
else:
final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
stride_init_state_token = initial_state.stride(0)
stride_final_state_token = final_state.stride(0)
if ssm_state_indices is None:
stride_indices_seq, stride_indices_tok = 1, 1
elif ssm_state_indices.ndim == 1:
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
else:
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
grid = (NK, NV, N * HV)
fused_sigmoid_gating_delta_rule_update_kernel[grid](
A_log=A_log,
a=a.contiguous(),
b=b.contiguous(),
dt_bias=dt_bias,
beta=beta,
threshold=threshold,
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
o=o,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
ssm_state_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens,
scale=scale,
N=N,
T=T,
B=B,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
stride_init_state_token=stride_init_state_token,
stride_final_state_token=stride_final_state_token,
stride_indices_seq=stride_indices_seq,
stride_indices_tok=stride_indices_tok,
INPLACE_FINAL_STATE=inplace_final_state,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
IS_KDA=is_kda,
num_warps=num_warps,
num_stages=num_stages,
)
o = o.squeeze(0)
return o, final_state
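
A minimal single-token decode call under the same assumptions (GPU required; shapes follow the `[B, T, H, K]` / `[B, T, HV, V]` convention unpacked from `k` and `v` above):

    import torch

    B, T, H, HV, K, V = 1, 1, 2, 4, 32, 32
    dev = "cuda"

    q = torch.randn(B, T, H, K, device=dev)
    k = torch.randn(B, T, H, K, device=dev)
    v = torch.randn(B, T, HV, V, device=dev)
    a = torch.randn(B, T, HV, device=dev)
    b = torch.randn(B, T, HV, device=dev)
    A_log = torch.randn(HV, device=dev)
    dt_bias = torch.randn(HV, device=dev)

    state = torch.zeros(4, HV, V, K, device=dev)  # slot 0 = null block
    idx = torch.tensor([[1]], device=dev)         # state slot per sequence

    o, ht = fused_sigmoid_gating_delta_rule_update(
        A_log, a, b, dt_bias, q, k, v,
        initial_state=state,
        ssm_state_indices=idx,
        use_qk_l2norm_in_kernel=True,
    )
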

View File

@@ -15,14 +15,12 @@ from .utils import tensor_cache
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
def prepare_lens(cu_seqlens: torch.Tensor) -> torch.Tensor:
return cu_seqlens[1:] - cu_seqlens[:-1]
@tensor_cache
def prepare_chunk_indices(
cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
def prepare_chunk_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
indices = torch.cat(
    [
        torch.arange(n)
        for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
    ]
)
@@ -33,9 +31,7 @@ def prepare_chunk_indices(
@tensor_cache
def prepare_chunk_offsets(
cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
def prepare_chunk_offsets(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
return torch.cat(
[cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
).cumsum(-1)
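
A worked example of the relaxed helpers (chunk_size 64, lengths arbitrary):

    import torch

    cu_seqlens = torch.tensor([0, 100, 228, 292], dtype=torch.int32)

    lens = cu_seqlens[1:] - cu_seqlens[:-1]  # prepare_lens -> [100, 128, 64]
    chunks = (lens + 63) // 64               # ceil-div, as triton.cdiv computes
    offsets = torch.cat([lens.new_zeros(1), chunks]).cumsum(-1)
    # prepare_chunk_offsets -> [0, 2, 4, 5]
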

View File

@@ -37,7 +37,7 @@ def fused_recurrent_kda_fwd(
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
@@ -115,7 +115,7 @@ def fused_recurrent_kda(
initial_state: torch.Tensor = None,
inplace_final_state: bool = True,
use_qk_l2norm_in_kernel: bool = True,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
ssm_state_indices: torch.LongTensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -692,7 +692,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
gk: torch.Tensor | None = None,
beta: torch.Tensor | None = None,
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
@@ -706,7 +706,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
The beta tensor of shape `[B, T, H]`.
gk (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
cu_seqlens (torch.LongTensor):
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
@@ -936,7 +936,7 @@ def recompute_w_u_fwd(
A: torch.Tensor,
q: torch.Tensor | None = None,
gk: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
BT = A.shape[-1]
@@ -1104,7 +1104,7 @@ def chunk_gla_fwd_o_gk(
h: torch.Tensor,
o: torch.Tensor,
scale: float,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
):
B, T, H, K, V = *q.shape, v.shape[-1]
@@ -1148,7 +1148,7 @@ def chunk_kda_fwd(
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
):
chunk_size = 64
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
@@ -1208,7 +1208,7 @@ def chunk_kda(
initial_state: torch.Tensor = None,
output_final_state: bool = False,
use_qk_l2norm_in_kernel: bool = False,
cu_seqlens: torch.LongTensor | None = None,
cu_seqlens: torch.Tensor | None = None,
**kwargs,
):
if scale is None:

View File

@@ -84,6 +84,7 @@ def layer_norm_fwd_kernel(
HAS_Z: tl.constexpr,
NORM_BEFORE_GATE: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
ACTIVATION: tl.constexpr,
):
# Map the program id to the starting row of X and Y it should compute.
row_start = tl.program_id(0) * ROWS_PER_BLOCK
@@ -112,7 +113,10 @@ def layer_norm_fwd_kernel(
if HAS_Z and not NORM_BEFORE_GATE:
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
x *= z * tl.sigmoid(z)
if ACTIVATION == "swish" or ACTIVATION == "silu":
x *= z * tl.sigmoid(z)
elif ACTIVATION == "sigmoid":
x *= tl.sigmoid(z)
# Compute mean and variance per row (reduce along axis 1)
if not IS_RMS_NORM:
@@ -155,7 +159,10 @@ def layer_norm_fwd_kernel(
if HAS_Z and NORM_BEFORE_GATE:
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
y *= z * tl.sigmoid(z)
if ACTIVATION == "swish" or ACTIVATION == "silu":
y *= z * tl.sigmoid(z)
elif ACTIVATION == "sigmoid":
y *= tl.sigmoid(z)
# Write output
tl.store(Y_base, y, mask=mask)
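
The new ACTIVATION switch selects the gate nonlinearity in both the pre-norm and post-norm branches. A plain-PyTorch statement of the intended semantics (RMS-norm case, reference only):

    import torch
    import torch.nn.functional as F

    def gated_rmsnorm_ref(x, z, weight, eps=1e-6,
                          norm_before_gate=True, activation="swish"):
        gate = F.silu(z) if activation in ("swish", "silu") else torch.sigmoid(z)
        rms = lambda t: t * torch.rsqrt(t.pow(2).mean(-1, keepdim=True) + eps)
        if norm_before_gate:
            return rms(x) * weight * gate  # norm(x) * act(z)
        return rms(x * gate) * weight      # norm(x * act(z))
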
@@ -178,6 +185,7 @@ def layer_norm_fwd(
group_size: int = None,
norm_before_gate: bool = True,
is_rms_norm: bool = False,
activation: str = "swish",
):
M, N = x.shape
if group_size is None:
@@ -232,9 +240,12 @@ def layer_norm_fwd(
eps,
BLOCK_N=BLOCK_N,
ROWS_PER_BLOCK=rows_per_block,
HAS_BIAS=bias is not None,
HAS_Z=z is not None,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
ACTIVATION=activation,
)
return out, mean, rstd
@@ -252,6 +263,7 @@ class LayerNormFn(torch.autograd.Function):
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
activation: str = "swish",
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
@@ -277,6 +289,7 @@ class LayerNormFn(torch.autograd.Function):
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
activation=activation,
)
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
ctx.x_shape_og = x_shape_og
@@ -284,6 +297,7 @@ class LayerNormFn(torch.autograd.Function):
ctx.group_size = group_size
ctx.norm_before_gate = norm_before_gate
ctx.is_rms_norm = is_rms_norm
ctx.activation = activation
return y.reshape(x_shape_og)
@@ -296,17 +310,25 @@ def layernorm_fn(
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
activation: str = "swish",
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm, activation
)
def rmsnorm_fn(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
activation: str = "swish",
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, True
x, weight, bias, z, eps, group_size, norm_before_gate, True, activation
)
@@ -359,6 +381,7 @@ class RMSNormGated(nn.Module):
norm_before_gate: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
activation: str = "swish",
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
@@ -366,6 +389,7 @@ class RMSNormGated(nn.Module):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.activation = activation
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.group_size = group_size
@@ -385,4 +409,5 @@ class RMSNormGated(nn.Module):
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
activation=self.activation,
)
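
With `activation` plumbed through LayerNormFn, the gate nonlinearity becomes a construction-time choice. A minimal sketch (the forward(x, z) signature is assumed from the call above; GPU required):

    import torch

    norm = RMSNormGated(hidden_size=256, eps=1e-6, activation="sigmoid",
                        device="cuda", dtype=torch.float32)
    x = torch.randn(4, 256, device="cuda")
    z = torch.randn(4, 256, device="cuda")  # gate input
    y = norm(x, z)  # norm_before_gate=False default: norm(x * sigmoid(z))
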

View File

@@ -122,7 +122,7 @@ def recompute_w_u_fwd(
beta: torch.Tensor,
g_cumsum: torch.Tensor,
A: torch.Tensor,
cu_seqlens: torch.LongTensor | None,
cu_seqlens: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]