Upgrade to vllm 0.17.0 corex v4.1 overlay
This commit is contained in:
@@ -7,11 +7,17 @@
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
from .chunk import chunk_gated_delta_rule
|
||||
from .fused_recurrent import fused_recurrent_gated_delta_rule
|
||||
from .fused_recurrent import (
|
||||
fused_recurrent_gated_delta_rule,
|
||||
fused_recurrent_gated_delta_rule_packed_decode,
|
||||
)
|
||||
from .fused_sigmoid_gating import fused_sigmoid_gating_delta_rule_update
|
||||
from .layernorm_guard import RMSNormGated
|
||||
|
||||
__all__ = [
|
||||
"RMSNormGated",
|
||||
"chunk_gated_delta_rule",
|
||||
"fused_recurrent_gated_delta_rule",
|
||||
"fused_recurrent_gated_delta_rule_packed_decode",
|
||||
"fused_sigmoid_gating_delta_rule_update",
|
||||
]
|
||||
|
||||
@@ -30,7 +30,7 @@ def chunk_gated_delta_rule_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
# obtain WY representation. u is actually the new v.
|
||||
@@ -84,7 +84,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
@@ -117,7 +117,7 @@ def chunk_gated_delta_rule(
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
r"""
|
||||
@@ -141,7 +141,7 @@ def chunk_gated_delta_rule(
|
||||
Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
Returns:
|
||||
@@ -171,7 +171,7 @@ def chunk_gated_delta_rule(
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
|
||||
>>> o_var, ht_var = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
|
||||
@@ -288,7 +288,7 @@ def chunk_gated_delta_rule_fwd_h(
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# This kernel is slightly different from fla to support Q/K with different head numbers.
|
||||
# In fla, Q/K always have the same head number, so Hg is always equal to H.
|
||||
|
||||
@@ -89,7 +89,7 @@ def chunk_fwd_kernel_o(
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(
|
||||
q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
|
||||
@@ -145,7 +145,7 @@ def chunk_fwd_o(
|
||||
h: torch.Tensor,
|
||||
g: torch.Tensor | None = None, # cumsum of log decay
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
|
||||
@@ -102,7 +102,7 @@ def chunk_scaled_dot_kkt_fwd(
|
||||
k: torch.Tensor,
|
||||
g: torch.Tensor | None = None,
|
||||
beta: torch.Tensor | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
@@ -116,7 +116,7 @@ def chunk_scaled_dot_kkt_fwd(
|
||||
The beta tensor of shape `[B, T, H]`.
|
||||
g (torch.Tensor):
|
||||
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None
|
||||
chunk_size (int):
|
||||
|
||||
@@ -106,12 +106,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
|
||||
else:
|
||||
i_t = 0
|
||||
# Load state index and check for PAD_SLOT_ID (-1)
|
||||
# Load state index and check for invalid entries
|
||||
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
|
||||
tl.int64
|
||||
)
|
||||
# Skip if state index is invalid (PAD_SLOT_ID = -1)
|
||||
if state_idx < 0:
|
||||
# Skip if state index is invalid (NULL_BLOCK_ID=0)
|
||||
if state_idx <= 0:
|
||||
return
|
||||
p_h0 = h0 + state_idx * stride_init_state_token
|
||||
else:
|
||||
@@ -150,12 +150,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
|
||||
# keep the states for multi-query tokens
|
||||
if INPLACE_FINAL_STATE:
|
||||
# Load state index and check for PAD_SLOT_ID (-1)
|
||||
# Load state index and check for invalid entries
|
||||
final_state_idx = tl.load(
|
||||
ssm_state_indices + i_n * stride_indices_seq + i_t
|
||||
).to(tl.int64)
|
||||
# Only store if state index is valid (not PAD_SLOT_ID)
|
||||
if final_state_idx >= 0:
|
||||
# Only store if state index is valid (not NULL_BLOCK_ID=0)
|
||||
if final_state_idx > 0:
|
||||
p_ht = ht + final_state_idx * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
@@ -184,7 +184,7 @@ def fused_recurrent_gated_delta_rule_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -252,6 +252,232 @@ def fused_recurrent_gated_delta_rule_fwd(
|
||||
return o, final_state
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fused_recurrent_gated_delta_rule_packed_decode_kernel(
|
||||
mixed_qkv,
|
||||
a,
|
||||
b,
|
||||
A_log,
|
||||
dt_bias,
|
||||
o,
|
||||
h0,
|
||||
ht,
|
||||
ssm_state_indices,
|
||||
scale,
|
||||
stride_mixed_qkv_tok: tl.constexpr,
|
||||
stride_a_tok: tl.constexpr,
|
||||
stride_b_tok: tl.constexpr,
|
||||
stride_init_state_token: tl.constexpr,
|
||||
stride_final_state_token: tl.constexpr,
|
||||
stride_indices_seq: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
HV: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
SOFTPLUS_THRESHOLD: tl.constexpr,
|
||||
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
|
||||
):
|
||||
i_v, i_nh = tl.program_id(0), tl.program_id(1)
|
||||
i_n, i_hv = i_nh // HV, i_nh % HV
|
||||
i_h = i_hv // (HV // H)
|
||||
|
||||
o_k = tl.arange(0, BK)
|
||||
o_v = i_v * BV + tl.arange(0, BV)
|
||||
mask_k = o_k < K
|
||||
mask_v = o_v < V
|
||||
mask_h = mask_v[:, None] & mask_k[None, :]
|
||||
|
||||
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)
|
||||
p_o = o + (i_n * HV + i_hv) * V + o_v
|
||||
|
||||
# Skip if state index is invalid (NULL_BLOCK_ID=0)
|
||||
if state_idx <= 0:
|
||||
zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)
|
||||
tl.store(p_o, zero, mask=mask_v)
|
||||
return
|
||||
|
||||
p_h0 = h0 + state_idx * stride_init_state_token
|
||||
p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
b_h = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
||||
|
||||
p_mixed = mixed_qkv + i_n * stride_mixed_qkv_tok
|
||||
q_off = i_h * K + o_k
|
||||
k_off = (H * K) + i_h * K + o_k
|
||||
v_off = (2 * H * K) + i_hv * V + o_v
|
||||
b_q = tl.load(p_mixed + q_off, mask=mask_k, other=0).to(tl.float32)
|
||||
b_k = tl.load(p_mixed + k_off, mask=mask_k, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_mixed + v_off, mask=mask_v, other=0).to(tl.float32)
|
||||
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
|
||||
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
|
||||
b_q = b_q * scale
|
||||
|
||||
a_val = tl.load(a + i_n * stride_a_tok + i_hv).to(tl.float32)
|
||||
b_val = tl.load(b + i_n * stride_b_tok + i_hv).to(tl.float32)
|
||||
A_log_val = tl.load(A_log + i_hv).to(tl.float32)
|
||||
dt_bias_val = tl.load(dt_bias + i_hv).to(tl.float32)
|
||||
x = a_val + dt_bias_val
|
||||
softplus_x = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x)
|
||||
g_val = -tl.exp(A_log_val) * softplus_x
|
||||
beta_val = tl.sigmoid(b_val).to(b.dtype.element_ty).to(tl.float32)
|
||||
|
||||
b_h *= exp(g_val)
|
||||
b_v -= tl.sum(b_h * b_k[None, :], 1)
|
||||
b_v *= beta_val
|
||||
b_h += b_v[:, None] * b_k[None, :]
|
||||
b_o = tl.sum(b_h * b_q[None, :], 1)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
||||
|
||||
p_ht = ht + state_idx * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule_packed_decode(
|
||||
mixed_qkv: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
A_log: torch.Tensor,
|
||||
dt_bias: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
ssm_state_indices: torch.Tensor,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if mixed_qkv.ndim != 2:
|
||||
raise ValueError(
|
||||
f"`mixed_qkv` must be a 2D tensor (got ndim={mixed_qkv.ndim})."
|
||||
)
|
||||
if mixed_qkv.stride(-1) != 1:
|
||||
raise ValueError("`mixed_qkv` must be contiguous in the last dim.")
|
||||
if a.ndim != 2 or b.ndim != 2:
|
||||
raise ValueError(
|
||||
f"`a` and `b` must be 2D tensors (got a.ndim={a.ndim}, b.ndim={b.ndim})."
|
||||
)
|
||||
if a.stride(-1) != 1 or b.stride(-1) != 1:
|
||||
raise ValueError("`a`/`b` must be contiguous in the last dim.")
|
||||
if A_log.ndim != 1 or dt_bias.ndim != 1:
|
||||
raise ValueError("`A_log`/`dt_bias` must be 1D tensors.")
|
||||
if A_log.stride(0) != 1 or dt_bias.stride(0) != 1:
|
||||
raise ValueError("`A_log`/`dt_bias` must be contiguous.")
|
||||
if ssm_state_indices.ndim != 1:
|
||||
raise ValueError(
|
||||
f"`ssm_state_indices` must be 1D for packed decode (got ndim={ssm_state_indices.ndim})."
|
||||
)
|
||||
if not out.is_contiguous():
|
||||
raise ValueError("`out` must be contiguous.")
|
||||
|
||||
dev = mixed_qkv.device
|
||||
if (
|
||||
a.device != dev
|
||||
or b.device != dev
|
||||
or A_log.device != dev
|
||||
or dt_bias.device != dev
|
||||
or initial_state.device != dev
|
||||
or out.device != dev
|
||||
or ssm_state_indices.device != dev
|
||||
):
|
||||
raise ValueError("All inputs must be on the same device.")
|
||||
|
||||
B = mixed_qkv.shape[0]
|
||||
if a.shape[0] != B or b.shape[0] != B:
|
||||
raise ValueError(
|
||||
"Mismatched batch sizes: "
|
||||
f"mixed_qkv.shape[0]={B}, a.shape[0]={a.shape[0]}, b.shape[0]={b.shape[0]}."
|
||||
)
|
||||
if ssm_state_indices.shape[0] != B:
|
||||
raise ValueError(
|
||||
f"`ssm_state_indices` must have shape [B] (got {tuple(ssm_state_indices.shape)}; expected ({B},))."
|
||||
)
|
||||
|
||||
if initial_state.ndim != 4:
|
||||
raise ValueError(
|
||||
f"`initial_state` must be a 4D tensor (got ndim={initial_state.ndim})."
|
||||
)
|
||||
if initial_state.stride(-1) != 1:
|
||||
raise ValueError("`initial_state` must be contiguous in the last dim.")
|
||||
HV, V, K = initial_state.shape[-3:]
|
||||
if a.shape[1] != HV or b.shape[1] != HV:
|
||||
raise ValueError(
|
||||
f"`a`/`b` must have shape [B, HV] with HV={HV} (got a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)})."
|
||||
)
|
||||
if A_log.numel() != HV or dt_bias.numel() != HV:
|
||||
raise ValueError(
|
||||
f"`A_log` and `dt_bias` must have {HV} elements (got A_log.numel()={A_log.numel()}, dt_bias.numel()={dt_bias.numel()})."
|
||||
)
|
||||
if out.shape != (B, 1, HV, V):
|
||||
raise ValueError(
|
||||
f"`out` must have shape {(B, 1, HV, V)} (got out.shape={tuple(out.shape)})."
|
||||
)
|
||||
|
||||
qkv_dim = mixed_qkv.shape[1]
|
||||
qk_dim = qkv_dim - HV * V
|
||||
if qk_dim <= 0 or qk_dim % 2 != 0:
|
||||
raise ValueError(
|
||||
f"Invalid packed `mixed_qkv` last dim={qkv_dim} for HV={HV}, V={V}."
|
||||
)
|
||||
q_dim = qk_dim // 2
|
||||
if q_dim % K != 0:
|
||||
raise ValueError(f"Invalid packed Q size {q_dim}: must be divisible by K={K}.")
|
||||
H = q_dim // K
|
||||
if H <= 0 or HV % H != 0:
|
||||
raise ValueError(
|
||||
f"Invalid head config inferred from mixed_qkv: H={H}, HV={HV}."
|
||||
)
|
||||
|
||||
BK = triton.next_power_of_2(K)
|
||||
if triton.cdiv(K, BK) != 1:
|
||||
raise ValueError(
|
||||
f"Packed decode kernel only supports NK=1 (got K={K}, BK={BK})."
|
||||
)
|
||||
BV = min(triton.next_power_of_2(V), 32)
|
||||
num_stages = 3
|
||||
num_warps = 1
|
||||
|
||||
stride_mixed_qkv_tok = mixed_qkv.stride(0)
|
||||
stride_a_tok = a.stride(0)
|
||||
stride_b_tok = b.stride(0)
|
||||
stride_init_state_token = initial_state.stride(0)
|
||||
stride_final_state_token = initial_state.stride(0)
|
||||
stride_indices_seq = ssm_state_indices.stride(0)
|
||||
|
||||
NV = triton.cdiv(V, BV)
|
||||
grid = (NV, B * HV)
|
||||
fused_recurrent_gated_delta_rule_packed_decode_kernel[grid](
|
||||
mixed_qkv=mixed_qkv,
|
||||
a=a,
|
||||
b=b,
|
||||
A_log=A_log,
|
||||
dt_bias=dt_bias,
|
||||
o=out,
|
||||
h0=initial_state,
|
||||
ht=initial_state,
|
||||
ssm_state_indices=ssm_state_indices,
|
||||
scale=scale,
|
||||
stride_mixed_qkv_tok=stride_mixed_qkv_tok,
|
||||
stride_a_tok=stride_a_tok,
|
||||
stride_b_tok=stride_b_tok,
|
||||
stride_init_state_token=stride_init_state_token,
|
||||
stride_final_state_token=stride_final_state_token,
|
||||
stride_indices_seq=stride_indices_seq,
|
||||
H=H,
|
||||
HV=HV,
|
||||
K=K,
|
||||
V=V,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
SOFTPLUS_THRESHOLD=20.0,
|
||||
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
return out, initial_state
|
||||
|
||||
|
||||
class FusedRecurrentFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
@@ -264,7 +490,7 @@ class FusedRecurrentFunction(torch.autograd.Function):
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -296,7 +522,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -324,7 +550,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
inplace_final_state: bool:
|
||||
Whether to store the final state in-place to save memory.
|
||||
Default: `True`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
ssm_state_indices (Optional[torch.Tensor]):
|
||||
@@ -358,7 +584,7 @@ def fused_recurrent_gated_delta_rule(
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.int32)
|
||||
>>> o_var, ht_var = fused_gated_recurrent_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
|
||||
279
vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py
Normal file
279
vllm/model_executor/layers/fla/ops/fused_sigmoid_gating.py
Normal file
@@ -0,0 +1,279 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
"IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
|
||||
"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
|
||||
}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["N", "T"])
|
||||
def fused_sigmoid_gating_delta_rule_update_kernel(
|
||||
A_log,
|
||||
a,
|
||||
b,
|
||||
dt_bias,
|
||||
beta,
|
||||
threshold,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
o,
|
||||
h0,
|
||||
ht,
|
||||
cu_seqlens,
|
||||
ssm_state_indices,
|
||||
num_accepted_tokens,
|
||||
scale,
|
||||
N: tl.int64, # num of sequences
|
||||
T: tl.int64, # num of tokens
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
HV: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
stride_init_state_token: tl.constexpr,
|
||||
stride_final_state_token: tl.constexpr,
|
||||
stride_indices_seq: tl.constexpr,
|
||||
stride_indices_tok: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
|
||||
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
IS_KDA: tl.constexpr,
|
||||
):
|
||||
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_n, i_hv = i_nh // HV, i_nh % HV
|
||||
i_h = i_hv // (HV // H)
|
||||
if IS_VARLEN:
|
||||
bos, eos = (
|
||||
tl.load(cu_seqlens + i_n).to(tl.int64),
|
||||
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
|
||||
)
|
||||
all = T
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
all = B * T
|
||||
|
||||
if T == 0:
|
||||
# no tokens to process for this sequence
|
||||
return
|
||||
|
||||
o_k = i_k * BK + tl.arange(0, BK)
|
||||
o_v = i_v * BV + tl.arange(0, BV)
|
||||
|
||||
p_q = q + (bos * H + i_h) * K + o_k
|
||||
p_k = k + (bos * H + i_h) * K + o_k
|
||||
p_v = v + (bos * HV + i_hv) * V + o_v
|
||||
|
||||
p_A_log = A_log + i_hv
|
||||
if not IS_KDA:
|
||||
p_a = a + bos * HV + i_hv
|
||||
p_dt_bias = dt_bias + i_hv
|
||||
else:
|
||||
p_a = a + (bos * HV + i_hv) * K + o_k
|
||||
p_dt_bias = dt_bias + i_hv * K + o_k
|
||||
|
||||
p_b = b + bos * HV + i_hv
|
||||
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
|
||||
|
||||
mask_k = o_k < K
|
||||
mask_v = o_v < V
|
||||
mask_h = mask_v[:, None] & mask_k[None, :]
|
||||
|
||||
b_h = tl.zeros([BV, BK], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
if IS_SPEC_DECODING:
|
||||
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
|
||||
else:
|
||||
i_t = 0
|
||||
# Load state index and check for invalid entries
|
||||
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
|
||||
tl.int64
|
||||
)
|
||||
# Skip if state index is invalid (NULL_BLOCK_ID=0)
|
||||
if state_idx <= 0:
|
||||
return
|
||||
p_h0 = h0 + state_idx * stride_init_state_token
|
||||
else:
|
||||
p_h0 = h0 + bos * HV * V * K
|
||||
p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
||||
|
||||
for i_t in range(0, T):
|
||||
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
|
||||
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
|
||||
b_b = tl.load(p_b).to(tl.float32)
|
||||
|
||||
# If the model is loaded in fp16, without the .float() here, A might be -inf
|
||||
x = tl.load(p_a).to(tl.float32) + tl.load(p_dt_bias).to(tl.float32)
|
||||
softplus_x = tl.where(
|
||||
beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x
|
||||
)
|
||||
b_g = -tl.exp(tl.load(p_A_log).to(tl.float32)) * softplus_x
|
||||
|
||||
# compute beta_output = sigmoid(b)
|
||||
b_beta = tl.sigmoid(b_b.to(tl.float32))
|
||||
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q * (tl.rsqrt(tl.sum(b_q * b_q) + 1e-6))
|
||||
b_k = b_k * (tl.rsqrt(tl.sum(b_k * b_k) + 1e-6))
|
||||
b_q = b_q * scale
|
||||
# [BV, BK]
|
||||
if not IS_KDA:
|
||||
b_h *= tl.exp(b_g)
|
||||
else:
|
||||
b_h *= tl.exp(b_g[None, :])
|
||||
# [BV]
|
||||
b_v -= tl.sum(b_h * b_k[None, :], 1)
|
||||
b_v *= b_beta
|
||||
# [BV, BK]
|
||||
b_h += b_v[:, None] * b_k[None, :]
|
||||
# [BV]
|
||||
b_o = tl.sum(b_h * b_q[None, :], 1)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
||||
|
||||
# keep the states for multi-query tokens
|
||||
if INPLACE_FINAL_STATE:
|
||||
# Load state index and check for invalid entries
|
||||
final_state_idx = tl.load(
|
||||
ssm_state_indices + i_n * stride_indices_seq + i_t
|
||||
).to(tl.int64)
|
||||
# Only store if state index is valid (not NULL_BLOCK_ID=0)
|
||||
if final_state_idx > 0:
|
||||
p_ht = ht + final_state_idx * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
else:
|
||||
p_ht = ht + (bos + i_t) * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
|
||||
# Update pointers for next timestep
|
||||
p_q += H * K
|
||||
p_k += H * K
|
||||
p_o += HV * V
|
||||
p_v += HV * V
|
||||
p_b += HV
|
||||
p_a += HV
|
||||
|
||||
|
||||
def fused_sigmoid_gating_delta_rule_update(
|
||||
A_log: torch.Tensor,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
dt_bias: torch.Tensor,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
beta: float = 1.0,
|
||||
threshold: float = 20.0,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
is_kda: bool = False,
|
||||
):
|
||||
"""
|
||||
Fused triton implementation of sigmoid gating delta rule update.
|
||||
This function uses a single fused kernel that combines both sigmoid gating
|
||||
computation and the recurrent delta rule update for better performance.
|
||||
"""
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
HV = v.shape[2]
|
||||
N = B if cu_seqlens is None else len(cu_seqlens) - 1
|
||||
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
assert NK == 1, "NK > 1 is not supported yet"
|
||||
num_stages = 3
|
||||
num_warps = 4
|
||||
|
||||
if cu_seqlens is not None and q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]}"
|
||||
f" when using `cu_seqlens`. Please flatten variable-length"
|
||||
f" inputs before processing."
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1] ** -0.5
|
||||
else:
|
||||
assert scale > 0, "scale must be positive"
|
||||
|
||||
o = q.new_empty(NK, *v.shape)
|
||||
if inplace_final_state:
|
||||
final_state = initial_state
|
||||
else:
|
||||
final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
|
||||
|
||||
stride_init_state_token = initial_state.stride(0)
|
||||
stride_final_state_token = final_state.stride(0)
|
||||
|
||||
if ssm_state_indices is None:
|
||||
stride_indices_seq, stride_indices_tok = 1, 1
|
||||
elif ssm_state_indices.ndim == 1:
|
||||
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
|
||||
else:
|
||||
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
|
||||
|
||||
grid = (NK, NV, N * HV)
|
||||
fused_sigmoid_gating_delta_rule_update_kernel[grid](
|
||||
A_log=A_log,
|
||||
a=a.contiguous(),
|
||||
b=b.contiguous(),
|
||||
dt_bias=dt_bias,
|
||||
beta=beta,
|
||||
threshold=threshold,
|
||||
q=q.contiguous(),
|
||||
k=k.contiguous(),
|
||||
v=v.contiguous(),
|
||||
o=o,
|
||||
h0=initial_state,
|
||||
ht=final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
ssm_state_indices=ssm_state_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
scale=scale,
|
||||
N=N,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
HV=HV,
|
||||
K=K,
|
||||
V=V,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
stride_init_state_token=stride_init_state_token,
|
||||
stride_final_state_token=stride_final_state_token,
|
||||
stride_indices_seq=stride_indices_seq,
|
||||
stride_indices_tok=stride_indices_tok,
|
||||
INPLACE_FINAL_STATE=inplace_final_state,
|
||||
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
|
||||
IS_KDA=is_kda,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
o = o.squeeze(0)
|
||||
return o, final_state
|
||||
@@ -15,14 +15,12 @@ from .utils import tensor_cache
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
||||
def prepare_lens(cu_seqlens: torch.Tensor) -> torch.Tensor:
|
||||
return cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_indices(
|
||||
cu_seqlens: torch.LongTensor, chunk_size: int
|
||||
) -> torch.LongTensor:
|
||||
def prepare_chunk_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
|
||||
indices = torch.cat(
|
||||
[
|
||||
torch.arange(n)
|
||||
@@ -33,9 +31,7 @@ def prepare_chunk_indices(
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_offsets(
|
||||
cu_seqlens: torch.LongTensor, chunk_size: int
|
||||
) -> torch.LongTensor:
|
||||
def prepare_chunk_offsets(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
|
||||
return torch.cat(
|
||||
[cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
|
||||
).cumsum(-1)
|
||||
|
||||
@@ -37,7 +37,7 @@ def fused_recurrent_kda_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.Tensor | None = None,
|
||||
num_accepted_tokens: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
@@ -115,7 +115,7 @@ def fused_recurrent_kda(
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
use_qk_l2norm_in_kernel: bool = True,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
ssm_state_indices: torch.LongTensor | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -692,7 +692,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
|
||||
gk: torch.Tensor | None = None,
|
||||
beta: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -706,7 +706,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
|
||||
The beta tensor of shape `[B, T, H]`.
|
||||
gk (torch.Tensor):
|
||||
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
cu_seqlens (torch.Tensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None
|
||||
chunk_size (int):
|
||||
@@ -936,7 +936,7 @@ def recompute_w_u_fwd(
|
||||
A: torch.Tensor,
|
||||
q: torch.Tensor | None = None,
|
||||
gk: torch.Tensor | None = None,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
BT = A.shape[-1]
|
||||
@@ -1104,7 +1104,7 @@ def chunk_gla_fwd_o_gk(
|
||||
h: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
scale: float,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
):
|
||||
B, T, H, K, V = *q.shape, v.shape[-1]
|
||||
@@ -1148,7 +1148,7 @@ def chunk_kda_fwd(
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
):
|
||||
chunk_size = 64
|
||||
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
|
||||
@@ -1208,7 +1208,7 @@ def chunk_kda(
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
cu_seqlens: torch.LongTensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if scale is None:
|
||||
|
||||
@@ -84,6 +84,7 @@ def layer_norm_fwd_kernel(
|
||||
HAS_Z: tl.constexpr,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
ACTIVATION: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the starting row of X and Y it should compute.
|
||||
row_start = tl.program_id(0) * ROWS_PER_BLOCK
|
||||
@@ -112,7 +113,10 @@ def layer_norm_fwd_kernel(
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
|
||||
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if ACTIVATION == "swish" or ACTIVATION == "silu":
|
||||
x *= z * tl.sigmoid(z)
|
||||
elif ACTIVATION == "sigmoid":
|
||||
x *= tl.sigmoid(z)
|
||||
|
||||
# Compute mean and variance per row (reduce along axis 1)
|
||||
if not IS_RMS_NORM:
|
||||
@@ -155,7 +159,10 @@ def layer_norm_fwd_kernel(
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
|
||||
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
|
||||
y *= z * tl.sigmoid(z)
|
||||
if ACTIVATION == "swish" or ACTIVATION == "silu":
|
||||
y *= z * tl.sigmoid(z)
|
||||
elif ACTIVATION == "sigmoid":
|
||||
y *= tl.sigmoid(z)
|
||||
|
||||
# Write output
|
||||
tl.store(Y_base, y, mask=mask)
|
||||
@@ -178,6 +185,7 @@ def layer_norm_fwd(
|
||||
group_size: int = None,
|
||||
norm_before_gate: bool = True,
|
||||
is_rms_norm: bool = False,
|
||||
activation: str = "swish",
|
||||
):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
@@ -232,9 +240,12 @@ def layer_norm_fwd(
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
ROWS_PER_BLOCK=rows_per_block,
|
||||
HAS_BIAS=bias is not None,
|
||||
HAS_Z=z is not None,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps,
|
||||
ACTIVATION=activation,
|
||||
)
|
||||
return out, mean, rstd
|
||||
|
||||
@@ -252,6 +263,7 @@ class LayerNormFn(torch.autograd.Function):
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
activation: str = "swish",
|
||||
):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
|
||||
@@ -277,6 +289,7 @@ class LayerNormFn(torch.autograd.Function):
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
activation=activation,
|
||||
)
|
||||
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
|
||||
ctx.x_shape_og = x_shape_og
|
||||
@@ -284,6 +297,7 @@ class LayerNormFn(torch.autograd.Function):
|
||||
ctx.group_size = group_size
|
||||
ctx.norm_before_gate = norm_before_gate
|
||||
ctx.is_rms_norm = is_rms_norm
|
||||
ctx.activation = activation
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
|
||||
@@ -296,17 +310,25 @@ def layernorm_fn(
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
activation: str = "swish",
|
||||
):
|
||||
return LayerNormFn.apply(
|
||||
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
|
||||
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm, activation
|
||||
)
|
||||
|
||||
|
||||
def rmsnorm_fn(
|
||||
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
activation: str = "swish",
|
||||
):
|
||||
return LayerNormFn.apply(
|
||||
x, weight, bias, z, eps, group_size, norm_before_gate, True
|
||||
x, weight, bias, z, eps, group_size, norm_before_gate, True, activation
|
||||
)
|
||||
|
||||
|
||||
@@ -359,6 +381,7 @@ class RMSNormGated(nn.Module):
|
||||
norm_before_gate: bool = False,
|
||||
device: torch.device | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
activation: str = "swish",
|
||||
):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
@@ -366,6 +389,7 @@ class RMSNormGated(nn.Module):
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.activation = activation
|
||||
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.register_parameter("bias", None)
|
||||
self.group_size = group_size
|
||||
@@ -385,4 +409,5 @@ class RMSNormGated(nn.Module):
|
||||
eps=self.eps,
|
||||
group_size=self.group_size,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
activation=self.activation,
|
||||
)
|
||||
|
||||
@@ -122,7 +122,7 @@ def recompute_w_u_fwd(
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: torch.LongTensor | None,
|
||||
cu_seqlens: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
|
||||
Reference in New Issue
Block a user