【OPS】qwen3-next support triton chunk_gated_delta_rule ops (#4070)
### What this PR does / why we need it?
Adds Triton `chunk_gated_delta_rule` ops so Qwen3-Next can run the chunked gated delta rule on Ascend.
### Co-owners
@OsirisDuan
- vLLM version: v0.11.2
Signed-off-by: shiyuan680 <917935075@qq.com>
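For reviewers unfamiliar with the op, below is a minimal pure-PyTorch sketch of the (non-chunked) gated delta rule recurrence that the new kernel computes in chunked form. It follows the `[B, T, H, K]` / `[B, T, H, V]` layout used by the kernel, but it is an intuition/testing aid only: the scaling and L2-normalization conventions are assumptions, equal numbers of Q/K and V heads are assumed (the kernel also supports `Hg != H`), and the chunked Triton path in `vllm_ascend/ops/triton/fla/chunk.py` remains the authoritative implementation.

```python
import torch
import torch.nn.functional as F


def gated_delta_rule_reference(q, k, v, g, beta, initial_state=None):
    """Naive per-timestep gated delta rule (assumed formulation, O(T) sequential).

    q, k: [B, T, H, K]; v: [B, T, H, V]; g (log-space gate), beta: [B, T, H].
    Running state S: [B, H, K, V]; returns (out [B, T, H, V], S).
    """
    B, T, H, K = q.shape
    V = v.shape[-1]
    # Assumed: q/k are L2-normalized and q carries the 1/sqrt(K) scale,
    # mirroring use_qk_l2norm_in_kernel=True plus the default scale.
    q = F.normalize(q.float(), p=2, dim=-1) * K ** -0.5
    k = F.normalize(k.float(), p=2, dim=-1)
    v, g, beta = v.float(), g.float(), beta.float()
    S = (torch.zeros(B, H, K, V, dtype=torch.float32, device=q.device)
         if initial_state is None else initial_state.float().clone())
    out = torch.zeros(B, T, H, V, dtype=torch.float32, device=q.device)
    for t in range(T):
        k_t = k[:, t]                                   # [B, H, K]
        v_t = v[:, t]                                   # [B, H, V]
        b_t = beta[:, t].unsqueeze(-1)                  # [B, H, 1]
        S = S * g[:, t].exp()[..., None, None]          # decay the running state
        delta = v_t - torch.einsum('bhk,bhkv->bhv', k_t, S)      # prediction error
        S = S + torch.einsum('bhk,bhv->bhkv', k_t, b_t * delta)  # rank-1 delta update
        out[:, t] = torch.einsum('bhk,bhkv->bhv', q[:, t], S)
    return out, S
```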
.github/workflows/_e2e_test.yaml (2 lines changed)
@@ -276,7 +276,7 @@ jobs:
       shell: bash -l {0}
       run: |
         . /usr/local/Ascend/ascend-toolkit/8.3.RC2/bisheng_toolkit/set_env.sh
-        python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev20250914-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl"
+        python3 -m pip install "https://vllm-ascend.obs.cn-north-4.myhuaweicloud.com/vllm-ascend/triton_ascend-3.2.0.dev2025110717-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl"

     - name: Run vllm-project/vllm-ascend Qwen3 Next test
       working-directory: ./vllm-ascend
tests/e2e/multicard/test_chunk_gated_delta_rule.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import torch

from tests.ut.base import PytestBase
from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule


class TestChunkGatedDeltaRule(PytestBase):

    def test_triton_fusion_ops(self, mock_moe_env):
        q = torch.randn(1, 17, 4, 128, dtype=torch.bfloat16).npu()
        k = torch.randn(1, 17, 4, 128, dtype=torch.bfloat16).npu()
        v = torch.randn(1, 17, 8, 128, dtype=torch.bfloat16).npu()
        g = torch.randn(1, 17, 8, dtype=torch.float32).npu()
        beta = torch.randn(1, 17, 8, dtype=torch.bfloat16).npu()
        initial_state = torch.randn(3, 8, 128, 128, dtype=torch.bfloat16).npu()
        q_start_loc = torch.range(0, 3, dtype=torch.int).npu()

        (
            core_attn_out_non_spec,
            last_recurrent_state,
        ) = chunk_gated_delta_rule(q=q,
                                   k=k,
                                   v=v,
                                   g=g,
                                   beta=beta,
                                   initial_state=initial_state,
                                   output_final_state=True,
                                   cu_seqlens=q_start_loc,
                                   head_first=False,
                                   use_qk_l2norm_in_kernel=True)

        assert core_attn_out_non_spec.shape == (1, 17, 8, 128)
        assert last_recurrent_state.shape == (3, 8, 128, 128)
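The test exercises the variable-length path via `cu_seqlens`. A hedged sketch of how cumulative sequence boundaries are typically built from per-request lengths, consistent with the `cu_seqlens` contract documented in `chunk.py` (batch size 1, flattened tokens, `N+1` boundaries), is shown below. The helper name is illustrative and not part of this PR.

```python
import torch


def build_cu_seqlens(seq_lens: list[int]) -> torch.Tensor:
    """[len_0, len_1, ...] -> cumulative boundaries [0, len_0, len_0 + len_1, ...]."""
    lens = torch.tensor(seq_lens, dtype=torch.int32)
    return torch.cat([lens.new_zeros(1), torch.cumsum(lens, dim=0, dtype=torch.int32)])


# Example: three packed requests of 5, 6 and 6 tokens (17 tokens total), which is
# consistent with q/k/v of shape [1, 17, ...] and an initial_state carrying N = 3 states:
# build_cu_seqlens([5, 6, 6])  ->  tensor([0, 5, 11, 17], dtype=torch.int32)
# The tensor would be moved to the NPU (e.g. with .npu()) before being passed as cu_seqlens.
```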
@@ -423,50 +423,20 @@ class CustomQwen3NextGatedDeltaNet(Qwen3NextGatedDeltaNet, MambaBase):
             non_spec_state_indices_tensor].contiguous()
         initial_state[~has_initial_state, ...] = 0

-        batch_size = initial_state.shape[0]
-        core_attn_out = []
-        last_recurrent_state = []
-
-        for b_idx in range(batch_size):
-            start, end = non_spec_query_start_loc[
-                b_idx], non_spec_query_start_loc[b_idx + 1]
-            cur_q = query_non_spec[:, start:end, ...]
-            cur_k = key_non_spec[:, start:end, ...]
-            cur_v = value_non_spec[:, start:end, ...]
-            cur_g = g_non_spec[:, start:end, ...]
-            cur_b = beta_non_spec[:, start:end, ...]
-            cur_state = initial_state[b_idx].unsqueeze(0)
-
-            (
-                cur_core_attn_out_non_spec,
-                cur_last_recurrent_state,
-            ) = chunk.chunk_gated_delta_rule(
-                query=cur_q,
-                key=cur_k,
-                value=cur_v,
-                g=cur_g,
-                beta=cur_b,
-                initial_state=cur_state,
-                output_final_state=True,
-                use_qk_l2norm_in_kernel=True,
-            )
-
-            core_attn_out.append(cur_core_attn_out_non_spec)
-            last_recurrent_state.append(cur_last_recurrent_state)
-
-        tar_dtype = core_attn_out[0].dtype
-        tar_device = core_attn_out[0].device
-        tar_shape = list(core_attn_out[0].shape)
-        tar_shape[1] = non_spec_query_start_loc[-1]
-        core_attn_out_non_spec = torch.empty(tar_shape,
-                                             dtype=tar_dtype,
-                                             device=tar_device)
-        for b_idx in range(batch_size):
-            cur_core_attn_out = core_attn_out[b_idx]
-            start, end = non_spec_query_start_loc[
-                b_idx], non_spec_query_start_loc[b_idx + 1]
-            core_attn_out_non_spec[:, start:end, ...] = cur_core_attn_out
-        last_recurrent_state = torch.cat(last_recurrent_state, dim=0)
+        (
+            core_attn_out_non_spec,
+            last_recurrent_state,
+        ) = chunk.chunk_gated_delta_rule(
+            q=query_non_spec,
+            k=key_non_spec,
+            v=value_non_spec,
+            g=g_non_spec,
+            beta=beta_non_spec,
+            initial_state=initial_state,
+            output_final_state=True,
+            cu_seqlens=non_spec_query_start_loc,
+            head_first=False,
+            use_qk_l2norm_in_kernel=True)

         # Init cache
         ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to(
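This hunk replaces the per-sequence Python loop with a single variable-length call. A hedged, standalone sanity-check sketch (not part of the diff) that exercises both paths on the same inputs and compares them is shown below; the input names and tolerances are illustrative only.

```python
import torch

from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule


def check_varlen_matches_loop(q, k, v, g, beta, initial_state, cu_seqlens):
    """Compare one packed varlen call against per-sequence calls (assumed-equivalent paths)."""
    out_var, state_var = chunk_gated_delta_rule(
        q=q, k=k, v=v, g=g, beta=beta, initial_state=initial_state,
        output_final_state=True, cu_seqlens=cu_seqlens,
        head_first=False, use_qk_l2norm_in_kernel=True)

    outs, states = [], []
    for i in range(len(cu_seqlens) - 1):
        s, e = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        o_i, h_i = chunk_gated_delta_rule(
            q=q[:, s:e], k=k[:, s:e], v=v[:, s:e], g=g[:, s:e], beta=beta[:, s:e],
            initial_state=initial_state[i:i + 1], output_final_state=True,
            head_first=False, use_qk_l2norm_in_kernel=True)
        outs.append(o_i)
        states.append(h_i)

    # Loose tolerances: bfloat16 inputs and chunked accumulation differ slightly per path.
    torch.testing.assert_close(out_var, torch.cat(outs, dim=1), rtol=2e-2, atol=2e-2)
    torch.testing.assert_close(state_var, torch.cat(states, dim=0), rtol=2e-2, atol=2e-2)
```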
vllm_ascend/ops/triton/fla/chunk.py (new file, 226 lines)
@@ -0,0 +1,226 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd
|
||||
from vllm.model_executor.layers.fla.ops.utils import SUPPRESS_LEVEL
|
||||
|
||||
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
|
||||
from .chunk_o import chunk_fwd_o
|
||||
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
|
||||
from .cumsum import chunk_local_cumsum
|
||||
from .solve_tril import solve_tril
|
||||
from .utils import input_guard
|
||||
from .wy_fast import recompute_w_u_fwd
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
# obtain WY representation. u is actually the new v.
|
||||
A = chunk_scaled_dot_kkt_fwd(k=k,
|
||||
beta=beta,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
output_dtype=torch.float32)
|
||||
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
|
||||
w, u = recompute_w_u_fwd(
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
A=A,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
||||
k=k,
|
||||
w=w,
|
||||
u=u,
|
||||
g=g,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
o = chunk_fwd_o(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_new,
|
||||
h=h,
|
||||
g=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
if SUPPRESS_LEVEL < 3:
|
||||
return g, o, A, final_state, None, None, None
|
||||
elif SUPPRESS_LEVEL >= 3:
|
||||
return g, o, A, final_state, w, h, v_new
|
||||
|
||||
|
||||
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@input_guard
|
||||
def forward(ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
q = l2norm_fwd(q)
|
||||
k = l2norm_fwd(k)
|
||||
|
||||
g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
scale=scale,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
ctx.scale = scale
|
||||
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
|
||||
@torch.compiler.disable
|
||||
def chunk_gated_delta_rule(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
head_first: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
|
||||
k (torch.Tensor):
|
||||
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
|
||||
v (torch.Tensor):
|
||||
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
||||
g (torch.Tensor):
|
||||
(forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
|
||||
beta (torch.Tensor):
|
||||
betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
|
||||
scale (Optional[int]):
|
||||
Scale factor for the RetNet attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `[N, H, K, V]` for `N` input sequences.
|
||||
For equal-length input sequences, `N` equals the batch size `B`.
|
||||
Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
head_first (Optional[bool]):
|
||||
Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
|
||||
Default: `False`.
|
||||
|
||||
Returns:
|
||||
o (torch.Tensor):
|
||||
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
||||
final_state (torch.Tensor):
|
||||
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
|
||||
|
||||
Examples::
|
||||
>>> import torch
|
||||
>>> import torch.nn.functional as F
|
||||
>>> from einops import rearrange
|
||||
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
|
||||
# inputs with equal lengths
|
||||
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
|
||||
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
|
||||
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
|
||||
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
|
||||
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
|
||||
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
|
||||
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
|
||||
>>> o, ht = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
output_final_state=True
|
||||
)
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> o_var, ht_var = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
output_final_state=True,
|
||||
cu_seqlens=cu_seqlens
|
||||
)
|
||||
"""
|
||||
assert q.dtype == k.dtype == v.dtype
|
||||
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
|
||||
assert len(
|
||||
beta.shape
|
||||
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
||||
|
||||
    if head_first:
        warnings.warn(
            "head_first is deprecated and will be removed in a future version. "
            "Please use head_first=False instead.",
            DeprecationWarning,
            stacklevel=2)
        q, k, v, beta, g = map(
            lambda x: rearrange(x, 'b h t ... -> b t h ...'),
            (q, k, v, beta, g))
|
||||
if not head_first and q.shape[1] < q.shape[2]:
|
||||
warnings.warn(
|
||||
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
|
||||
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||
"when head_first=False was specified. "
|
||||
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
|
||||
stacklevel=2)
|
||||
if cu_seqlens is not None:
|
||||
if q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing.")
|
||||
if initial_state is not None and initial_state.shape[0] != len(
|
||||
cu_seqlens) - 1:
|
||||
raise ValueError(
|
||||
f"The number of initial states is expected to be equal to the number of input sequences, "
|
||||
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
|
||||
use_qk_l2norm_in_kernel)
|
||||
if head_first:
|
||||
o = rearrange(o, 'b t h ... -> b h t ...')
|
||||
return o, final_state
|
||||
vllm_ascend/ops/triton/fla/chunk_delta_h.py (new file, 259 lines)
@@ -0,0 +1,259 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .utils import prepare_chunk_indices, prepare_chunk_offsets, safe_exp
|
||||
|
||||
_CONDITIONS = ("seq7168", )
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"USE_G": lambda args: args["g"] is not None,
|
||||
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
|
||||
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
|
||||
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
|
||||
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
|
||||
})
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
k,
|
||||
v,
|
||||
w,
|
||||
v_new,
|
||||
g,
|
||||
h,
|
||||
h0,
|
||||
ht,
|
||||
cu_seqlens,
|
||||
chunk_offsets,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr,
|
||||
SAVE_NEW_VALUE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_nh = tl.program_id(1)
|
||||
i_n, i_h = i_nh // H, i_nh % H
|
||||
T_max = 1 * T
|
||||
if IS_VARLEN:
|
||||
bos, eos = (
|
||||
tl.load(cu_seqlens + i_n).to(tl.int32),
|
||||
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
|
||||
)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = i_n * NT
|
||||
|
||||
stride_v = H * V
|
||||
stride_k = Hg * K
|
||||
stride_w = H * K
|
||||
|
||||
b_h1_bv1 = tl.zeros([128, 64], dtype=tl.float32)
|
||||
b_h1_bv2 = tl.zeros([128, 64], dtype=tl.float32)
|
||||
|
||||
v_start1 = 0
|
||||
v_start2 = 64
|
||||
|
||||
offs_k = tl.arange(0, 128)[:, None]
|
||||
offs_v1 = v_start1 + tl.arange(0, 64)[None, :]
|
||||
offs_v2 = v_start2 + tl.arange(0, 64)[None, :]
|
||||
mask_kv1 = (offs_k < K) & (offs_v1 < V)
|
||||
mask_kv2 = (offs_k < K) & (offs_v2 < V)
|
||||
|
||||
# load initial state
|
||||
if USE_INITIAL_STATE:
|
||||
h0_ptr = h0 + i_nh * K * V
|
||||
ptr_h0_bv1 = h0_ptr + offs_k * V + offs_v1 * 1
|
||||
b_h1_bv1 += tl.load(ptr_h0_bv1, mask=mask_kv1,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
ptr_h0_bv2 = h0_ptr + offs_k * V + offs_v2 * 1
|
||||
b_h1_bv2 += tl.load(ptr_h0_bv2, mask=mask_kv2,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
# main recurrence
|
||||
for i_t in range(NT):
|
||||
h_base = h + (boh + i_t) * H * K * V + i_h * K * V
|
||||
|
||||
p_h1_bv1 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start1),
|
||||
(128, 64), (1, 0))
|
||||
tl.store(p_h1_bv1,
|
||||
b_h1_bv1.to(p_h1_bv1.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
p_h1_bv2 = tl.make_block_ptr(h_base, (K, V), (V, 1), (0, v_start2),
|
||||
(128, 64), (1, 0))
|
||||
tl.store(p_h1_bv2,
|
||||
b_h1_bv2.to(p_h1_bv2.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
offs_t_wv = (i_t * BT + tl.arange(0, BT))[:, None]
|
||||
offs_k_wv = tl.arange(0, 128)[None, :]
|
||||
mask_w = (offs_t_wv < T) & (offs_k_wv < K)
|
||||
|
||||
w_base = w + bos * H * K + i_h * K
|
||||
ptr_w = w_base + offs_t_wv * stride_w + offs_k_wv * 1
|
||||
b_w = tl.load(ptr_w, mask=mask_w, other=0.0)
|
||||
|
||||
k_base = k + bos * Hg * K + (i_h // (H // Hg)) * K
|
||||
p_k = tl.make_block_ptr(k_base, (K, T), (1, stride_k), (0, i_t * BT),
|
||||
(128, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
|
||||
v_new_base = v_new + bos * H * V + i_h * V
|
||||
|
||||
last_idx = min((i_t + 1) * BT, T) - 1
|
||||
b_g_last = tl.load(g + bos + i_h * T_max + last_idx)
|
||||
|
||||
offs_t = i_t * BT + tl.arange(0, BT)
|
||||
mask_t = offs_t < T
|
||||
g_ptr = g + bos + i_h * T_max
|
||||
b_g = tl.load(g_ptr + offs_t, mask=mask_t, other=0.0)
|
||||
|
||||
b_g = safe_exp(b_g_last - b_g)
|
||||
b_g_last = tl.exp(b_g_last)
|
||||
|
||||
offs_t_v = (i_t * BT + tl.arange(0, BT))[:, None]
|
||||
mask_v1 = (offs_t_v < T) & (offs_v1 < V)
|
||||
|
||||
v_base = v + bos * H * V + i_h * V
|
||||
ptr_v1 = v_base + offs_t_v * stride_v + offs_v1 * 1
|
||||
b_v1 = tl.load(ptr_v1, mask=mask_v1, other=0.0)
|
||||
b_v_new1 = b_v1.to(tl.float32)
|
||||
b_v_new1 -= tl.dot(b_w, b_h1_bv1.to(b_w.dtype))
|
||||
|
||||
if SAVE_NEW_VALUE:
|
||||
p_v_new1 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1),
|
||||
(i_t * BT, v_start1), (BT, 64),
|
||||
(1, 0))
|
||||
tl.store(p_v_new1,
|
||||
b_v_new1.to(p_v_new1.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
if USE_G:
|
||||
b_v_new1 = b_v_new1 * b_g[:, None]
|
||||
b_h1_bv1 = b_h1_bv1 * b_g_last
|
||||
|
||||
b_v_new1 = b_v_new1.to(k.dtype.element_ty)
|
||||
b_h1_bv1 += tl.dot(b_k, b_v_new1)
|
||||
|
||||
mask_v2 = (offs_t_v < T) & (offs_v2 < V)
|
||||
ptr_v2 = v_base + offs_t_v * stride_v + offs_v2 * 1
|
||||
b_v2 = tl.load(ptr_v2, mask=mask_v2, other=0.0)
|
||||
b_v_new2 = b_v2.to(tl.float32)
|
||||
b_v_new2 -= tl.dot(b_w, b_h1_bv2.to(b_w.dtype))
|
||||
|
||||
if SAVE_NEW_VALUE:
|
||||
p_v_new2 = tl.make_block_ptr(v_new_base, (T, V), (stride_v, 1),
|
||||
(i_t * BT, v_start2), (BT, 64),
|
||||
(1, 0))
|
||||
tl.store(p_v_new2,
|
||||
b_v_new2.to(p_v_new2.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
if USE_G:
|
||||
b_v_new2 = b_v_new2 * b_g[:, None]
|
||||
b_h1_bv2 = b_h1_bv2 * b_g_last
|
||||
|
||||
b_v_new2 = b_v_new2.to(k.dtype.element_ty)
|
||||
b_h1_bv2 += tl.dot(b_k, b_v_new2)
|
||||
|
||||
# epilogue
|
||||
if STORE_FINAL_STATE:
|
||||
ht_ptr = ht + i_nh * K * V
|
||||
|
||||
p_ht1_bv1 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start1),
|
||||
(128, 64), (1, 0))
|
||||
tl.store(p_ht1_bv1,
|
||||
b_h1_bv1.to(p_ht1_bv1.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
p_ht1_bv2 = tl.make_block_ptr(ht_ptr, (K, V), (V, 1), (0, v_start2),
|
||||
(128, 64), (1, 0))
|
||||
tl.store(p_ht1_bv2,
|
||||
b_h1_bv2.to(p_ht1_bv2.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd_h(
|
||||
k: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# This kernel is slightly different from fla to support Q/K with different head numbers.
|
||||
# In fla, Q/K always have the same head number, so Hg is always equal to H.
|
||||
B, T, Hg, K, V = *k.shape, u.shape[-1]
|
||||
H = u.shape[-2]
|
||||
BT = chunk_size
|
||||
|
||||
chunk_indices = (prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||
if cu_seqlens is not None else None)
|
||||
# N: the actual number of sequences in the batch with either equal or variable lengths
|
||||
if cu_seqlens is None:
|
||||
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
|
||||
else:
|
||||
N, NT, chunk_offsets = (
|
||||
len(cu_seqlens) - 1,
|
||||
len(chunk_indices),
|
||||
prepare_chunk_offsets(cu_seqlens, BT),
|
||||
)
|
||||
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
||||
|
||||
h = k.new_empty(B, NT, H, K, V)
|
||||
final_state = (k.new_empty(N, H, K, V, dtype=torch.float32)
|
||||
if output_final_state else None)
|
||||
|
||||
v_new = torch.empty_like(u) if save_new_value else None
|
||||
g = g.transpose(1, 2).contiguous()
|
||||
|
||||
def grid(meta):
|
||||
return (1, N * H)
|
||||
|
||||
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
|
||||
k=k,
|
||||
v=u,
|
||||
w=w,
|
||||
v_new=v_new,
|
||||
g=g,
|
||||
h=h,
|
||||
h0=initial_state,
|
||||
ht=final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_offsets=chunk_offsets,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
)
|
||||
return h, v_new, final_state
|
||||
vllm_ascend/ops/triton/fla/chunk_o.py (new file, 168 lines)
@@ -0,0 +1,168 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .utils import prepare_chunk_offsets, safe_exp
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'USE_G': lambda args: args['g'] is not None,
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
|
||||
})
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_fwd_kernel_o(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_offsets,
|
||||
scale,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_v, i_nh = tl.program_id(0), tl.program_id(1)
|
||||
i_n, i_h = i_nh // H, i_nh % H
|
||||
T_max = T
|
||||
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = tl.load(chunk_offsets + i_n).to(tl.int64)
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = i_n * NT
|
||||
|
||||
# offset calculation
|
||||
q += (bos * Hg + i_h // (H // Hg)) * K
|
||||
k += (bos * Hg + i_h // (H // Hg)) * K
|
||||
v += (bos * H + i_h) * V
|
||||
o += (bos * H + i_h) * V
|
||||
|
||||
for i_t in range(NT):
|
||||
i_tg = boh + i_t
|
||||
h_base = h + (i_tg * H + i_h).to(tl.int64) * K * V
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1),
|
||||
(i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K),
|
||||
(i_k * BK, i_t * BT), (BK, BT), (0, 1))
|
||||
p_h = tl.make_block_ptr(h_base, (K, V), (V, 1),
|
||||
(i_k * BK, i_v * BV), (BK, BV), (1, 0))
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BK] @ [BK, BV] -> [BT, BV]
|
||||
b_o += tl.dot(b_q, b_h)
|
||||
# [BT, BK] @ [BK, BT] -> [BT, BT]
|
||||
b_A += tl.dot(b_q, b_k)
|
||||
|
||||
if USE_G:
|
||||
offs_t = i_t * BT + tl.arange(0, BT)
|
||||
mask_t = offs_t < T
|
||||
g_ptr = g + bos + i_h * T_max
|
||||
b_g = tl.load(g_ptr + offs_t, mask=mask_t, other=0.0)
|
||||
|
||||
b_o = b_o * tl.exp(b_g)[:, None]
|
||||
b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
|
||||
|
||||
o_i = tl.arange(0, BT).to(tl.float32)
|
||||
m_A = o_i[:, None] >= o_i[None, :]
|
||||
b_A = tl.where(m_A, b_A, 0)
|
||||
|
||||
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
# to fix mma -> mma layout conversion
|
||||
# already solved by fla v3.2 or higher
|
||||
b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_fwd_o(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
scale: Optional[float] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64,
|
||||
) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
BT = chunk_size
|
||||
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
|
||||
o = torch.empty_like(v)
|
||||
if cu_seqlens is None:
|
||||
N, chunk_offsets = B, None
|
||||
else:
|
||||
N, chunk_offsets = (
|
||||
len(cu_seqlens) - 1,
|
||||
prepare_chunk_offsets(cu_seqlens, BT),
|
||||
)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(V, meta['BV']), N * H)
|
||||
|
||||
g = g.transpose(1, 2).contiguous()
|
||||
chunk_fwd_kernel_o[grid](
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
h=h,
|
||||
g=g,
|
||||
o=o,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_offsets=chunk_offsets,
|
||||
scale=scale,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=128,
|
||||
BV=128,
|
||||
num_warps=4,
|
||||
num_stages=2,
|
||||
)
|
||||
return o
|
||||
vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py (new file, 147 lines)
@@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .utils import prepare_chunk_indices, safe_exp
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
|
||||
'USE_G': lambda args: args['g_cumsum'] is not None,
|
||||
})
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
k,
|
||||
beta, # [H, B, T]
|
||||
g_cumsum, # [H, B, T]
|
||||
A,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
B,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
):
|
||||
bt_stride = B * T
|
||||
i_t_i, _ = tl.program_id(0), tl.program_id(1)
|
||||
|
||||
for i_bh in range(B * H):
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t_i * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t_i * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
i_t = i_t_i
|
||||
o_t = tl.arange(0, BT)
|
||||
o_t_fp32 = o_t.to(tl.float32)
|
||||
|
||||
p_beta = tl.make_block_ptr(beta + i_h * bt_stride + bos, (T, ), (1, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
b_beta = tl.load(p_beta, boundary_check=(0, ))
|
||||
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K,
|
||||
(T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
|
||||
(BT, BK), (1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_A += tl.dot(b_k, tl.trans(b_k))
|
||||
|
||||
if USE_G:
|
||||
p_g = tl.make_block_ptr(g_cumsum + i_h * bt_stride + bos, (T, ),
|
||||
(1, ), (i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
b_g_diff = b_g[:, None] - b_g[None, :]
|
||||
b_A *= safe_exp(b_g_diff)
|
||||
|
||||
b_A *= b_beta[:, None]
|
||||
b_A = tl.where(o_t_fp32[:, None] > o_t_fp32[None, :], b_A, 0)
|
||||
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
|
||||
(i_t * BT, 0), (BT, BT), (1, 0))
|
||||
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_scaled_dot_kkt_fwd(
|
||||
k: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
||||
r"""
|
||||
Compute beta * K * K^T.
|
||||
|
||||
Args:
|
||||
k (torch.Tensor):
|
||||
The key tensor of shape `[B, T, H, K]`.
|
||||
beta (torch.Tensor):
|
||||
The beta tensor of shape `[B, T, H]`.
|
||||
        g_cumsum (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None
|
||||
chunk_size (int):
|
||||
The chunk size. Default: 64.
|
||||
output_dtype (torch.dtype):
|
||||
The dtype of the output tensor. Default: `torch.float32`
|
||||
|
||||
Returns:
|
||||
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
|
||||
"""
|
||||
B, T, Hg, K = k.shape
|
||||
|
||||
H = beta.shape[-1]
|
||||
BT = chunk_size
|
||||
if cu_seqlens is not None:
|
||||
cu_seqlens = cu_seqlens.cpu()
|
||||
chunk_indices = (prepare_chunk_indices(cu_seqlens, BT)
|
||||
if cu_seqlens is not None else None)
|
||||
chunk_indices = chunk_indices.npu()
|
||||
cu_seqlens = cu_seqlens.npu()
|
||||
else:
|
||||
chunk_indices = None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
|
||||
|
||||
chunk_scaled_dot_kkt_fwd_kernel[(NT, 1)](
|
||||
k=k,
|
||||
beta=torch.permute(beta, (2, 0, 1)).contiguous(),
|
||||
g_cumsum=torch.permute(g_cumsum, (2, 0, 1)).contiguous(),
|
||||
A=A,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
BT=BT,
|
||||
BK=128,
|
||||
num_warps=8,
|
||||
num_stages=3,
|
||||
multibuffer=True,
|
||||
)
|
||||
return A
|
||||
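For reference, a hedged pure-PyTorch equivalent of what this kernel computes per chunk (a strictly lower-triangular, beta-scaled, gate-decayed K Kᵀ) is sketched below, written from the kernel body above. It assumes `Hg == H` and `T` divisible by the chunk size, and is intended for testing rather than production use.

```python
import torch


def chunk_scaled_dot_kkt_reference(k, beta, g_cumsum, chunk_size=64):
    """k: [B, T, H, K]; beta, g_cumsum: [B, T, H] -> A: [B, T, H, BT] (float32)."""
    B, T, H, K = k.shape
    BT = chunk_size
    assert T % BT == 0, "reference assumes T is a multiple of chunk_size"
    NT = T // BT
    # Split T into (NT, BT) chunks and move heads before the chunk axes.
    k_c = k.float().reshape(B, NT, BT, H, K).permute(0, 1, 3, 2, 4)       # [B, NT, H, BT, K]
    b_c = beta.float().reshape(B, NT, BT, H).permute(0, 1, 3, 2)          # [B, NT, H, BT]
    g_c = g_cumsum.float().reshape(B, NT, BT, H).permute(0, 1, 3, 2)      # [B, NT, H, BT]

    A = k_c @ k_c.transpose(-1, -2)                                       # [B, NT, H, BT, BT]
    decay = g_c[..., :, None] - g_c[..., None, :]
    # Mirrors safe_exp in the kernel: positive exponents are zeroed out.
    A = A * torch.where(decay <= 0, decay.exp(), torch.zeros_like(decay))
    A = A * b_c[..., :, None]                                             # row-wise beta scaling
    A = A.tril(-1)                                                        # strictly lower triangle
    return A.permute(0, 1, 3, 2, 4).reshape(B, T, H, BT)
```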
vllm_ascend/ops/triton/fla/cumsum.py (new file, 145 lines)
@@ -0,0 +1,145 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .utils import prepare_chunk_indices
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'HAS_SCALE': lambda args: args['scale'] is not None,
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
||||
})
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_local_cumsum_scalar_kernel(
|
||||
s,
|
||||
o,
|
||||
scale,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
BLOCK_T: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
HAS_SCALE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
HEAD_FIRST: tl.constexpr,
|
||||
CHUNK_SIZE: tl.constexpr = 64,
|
||||
):
|
||||
i_block, i_b = tl.program_id(0), tl.program_id(1)
|
||||
N_CHUNKS: tl.constexpr = BLOCK_T // CHUNK_SIZE
|
||||
|
||||
if IS_VARLEN:
|
||||
i_s, i_block = tl.load(chunk_indices + i_block * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_block * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_s).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_s + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
if HEAD_FIRST:
|
||||
ptr_s = tl.make_block_ptr(s + bos * H, (H, T), (T, 1),
|
||||
(0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0))
|
||||
ptr_o = tl.make_block_ptr(o + bos * H, (H, T), (T, 1),
|
||||
(0, i_block * BLOCK_T), (H, BLOCK_T), (1, 0))
|
||||
b_s = tl.load(ptr_s, boundary_check=(0, )).to(tl.float32)
|
||||
b_s = tl.reshape(b_s, (H, N_CHUNKS, CHUNK_SIZE))
|
||||
b_s = tl.trans(b_s, (2, 0, 1))
|
||||
b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE)
|
||||
if HAS_SCALE:
|
||||
b_o *= scale
|
||||
b_o = tl.trans(b_o, (2, 0, 1))
|
||||
b_o = tl.reshape(b_o, (H, BLOCK_T))
|
||||
else:
|
||||
ptr_s = tl.make_block_ptr(s + bos * H, (T, H), (H, 1),
|
||||
(i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
|
||||
ptr_o = tl.make_block_ptr(o + bos * H, (T, H), (H, 1),
|
||||
(i_block * BLOCK_T, 0), (BLOCK_T, H), (1, 0))
|
||||
b_s = tl.load(ptr_s, boundary_check=(0, )).to(tl.float32)
|
||||
b_s = tl.reshape(b_s, (N_CHUNKS, CHUNK_SIZE, H))
|
||||
b_s = tl.trans(b_s, (1, 0, 2))
|
||||
b_o = tl.cumsum(b_s, axis=0, reverse=REVERSE)
|
||||
if HAS_SCALE:
|
||||
b_o *= scale
|
||||
b_o = tl.trans(b_o, (1, 0, 2))
|
||||
b_o = tl.reshape(b_o, (BLOCK_T, H))
|
||||
|
||||
tl.store(ptr_o, b_o.to(s.dtype.element_ty), boundary_check=(0, ))
|
||||
return
|
||||
|
||||
|
||||
def chunk_local_cumsum_scalar(
|
||||
g,
|
||||
chunk_size,
|
||||
reverse: bool = False,
|
||||
scale: float = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.Tensor] = torch.float,
|
||||
):
|
||||
if head_first:
|
||||
B, H, T = g.shape
|
||||
else:
|
||||
B, T, H = g.shape
|
||||
assert chunk_size == 2**(chunk_size.bit_length() -
|
||||
1), "chunk_size must be a power of 2"
|
||||
OPTIM_BLOCK_SIZE = triton.next_power_of_2((2**18) // (H * chunk_size))
|
||||
block_indices = prepare_chunk_indices(
|
||||
cu_seqlens,
|
||||
chunk_size=OPTIM_BLOCK_SIZE) if cu_seqlens is not None else None
|
||||
num_blocks = len(block_indices) if cu_seqlens is not None else triton.cdiv(
|
||||
T, OPTIM_BLOCK_SIZE)
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
grid = (num_blocks, B)
|
||||
chunk_local_cumsum_scalar_kernel[grid](s=g_org,
|
||||
o=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=block_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
BLOCK_T=OPTIM_BLOCK_SIZE,
|
||||
CHUNK_SIZE=chunk_size,
|
||||
HEAD_FIRST=head_first,
|
||||
REVERSE=reverse,
|
||||
num_warps=8,
|
||||
num_stages=3)
|
||||
return g
|
||||
|
||||
|
||||
def chunk_local_cumsum(g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
scale: float = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float,
|
||||
**kwargs) -> torch.Tensor:
|
||||
if cu_seqlens is not None:
|
||||
assert g.shape[
|
||||
0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
|
||||
if len(g.shape) == 3:
|
||||
return chunk_local_cumsum_scalar(g=g,
|
||||
chunk_size=chunk_size,
|
||||
reverse=reverse,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
head_first=head_first,
|
||||
output_dtype=output_dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported input shape {g.shape}, "
|
||||
f"which should be (B, T, H, D) if `head_first=False` "
|
||||
f"or (B, H, T, D) otherwise")
|
||||
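For intuition, a minimal pure-PyTorch equivalent of the scalar local cumsum (fixed-length case, no `cu_seqlens`, `head_first=False`) might look like the sketch below; it assumes `T` is a multiple of `chunk_size` and always returns float32.

```python
import torch


def chunk_local_cumsum_reference(g: torch.Tensor, chunk_size: int = 64) -> torch.Tensor:
    """g: [B, T, H] log-gates -> inclusive cumulative sum restarted at every chunk boundary."""
    B, T, H = g.shape
    assert T % chunk_size == 0, "reference assumes T is a multiple of chunk_size"
    g = g.float().reshape(B, T // chunk_size, chunk_size, H)
    return g.cumsum(dim=2).reshape(B, T, H)
```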
@@ -7,7 +7,6 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
MAX_CORES = 65535
|
||||
@@ -200,100 +199,3 @@ class LayerNormFn(torch.autograd.Function):
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
|
||||
def torch_chunk_gated_delta_rule(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
g,
|
||||
beta,
|
||||
chunk_size=64,
|
||||
initial_state=None,
|
||||
output_final_state=False,
|
||||
use_qk_l2norm_in_kernel=False,
|
||||
):
|
||||
initial_dtype = query.dtype
|
||||
if use_qk_l2norm_in_kernel:
|
||||
query = F.normalize(query, p=2, dim=-1)
|
||||
key = F.normalize(key, p=2, dim=-1)
|
||||
query, key, value, beta, g = [
|
||||
x.transpose(1, 2).contiguous().to(torch.float32)
|
||||
for x in (query, key, value, beta, g)
|
||||
]
|
||||
|
||||
batch_size, sequence_length, num_heads, k_head_dim = key.shape
|
||||
v_head_dim = value.shape[-1]
|
||||
pad_size = (chunk_size - num_heads % chunk_size) % chunk_size
|
||||
query = F.pad(query, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
|
||||
key = F.pad(key, (0, 0, 0, pad_size)).repeat_interleave(2, dim=1)
|
||||
value = F.pad(value, (0, 0, 0, pad_size))
|
||||
beta = F.pad(beta, (0, pad_size))
|
||||
g = F.pad(g, (0, pad_size))
|
||||
tot_heads = num_heads + pad_size
|
||||
scale = 1 / (query.shape[-1]**0.5)
|
||||
query = query * scale
|
||||
|
||||
v_beta = value * beta.unsqueeze(-1)
|
||||
k_beta = key * beta.unsqueeze(-1)
|
||||
# reshape to chunks
|
||||
query, key, value, k_beta, v_beta = [
|
||||
x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1])
|
||||
for x in (query, key, value, k_beta, v_beta)
|
||||
]
|
||||
g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size)
|
||||
mask = torch.triu(torch.ones(chunk_size,
|
||||
chunk_size,
|
||||
dtype=torch.bool,
|
||||
device=query.device),
|
||||
diagonal=0)
|
||||
|
||||
# chunk decay
|
||||
g = g.cumsum(dim=-1)
|
||||
decay_mask = ((g.unsqueeze(-1) -
|
||||
g.unsqueeze(-2)).tril().exp().float()).tril()
|
||||
attn = -(
|
||||
(k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0)
|
||||
for i in range(1, chunk_size):
|
||||
row = attn[..., i, :i].clone()
|
||||
sub = attn[..., :i, :i].clone()
|
||||
attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2)
|
||||
attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device)
|
||||
value = attn @ v_beta
|
||||
k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1))
|
||||
|
||||
last_recurrent_state = (torch.zeros(batch_size, sequence_length,
|
||||
k_head_dim, v_head_dim).to(value) if
|
||||
initial_state is None else initial_state.to(value))
|
||||
|
||||
core_attn_out = torch.zeros_like(value)
|
||||
mask = torch.triu(torch.ones(chunk_size,
|
||||
chunk_size,
|
||||
dtype=torch.bool,
|
||||
device=query.device),
|
||||
diagonal=1)
|
||||
|
||||
# for each chunk
|
||||
for i in range(0, tot_heads // chunk_size):
|
||||
q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i]
|
||||
attn = (q_i @ k_i.transpose(-1, -2) *
|
||||
decay_mask[:, :, i]).masked_fill_(mask, 0)
|
||||
v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state
|
||||
v_new = v_i - v_prime
|
||||
attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state
|
||||
core_attn_out[:, :, i] = attn_inter + attn @ v_new
|
||||
last_recurrent_state = (
|
||||
last_recurrent_state * g[:, :, i, -1, None, None].exp() +
|
||||
(k_i *
|
||||
(g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(
|
||||
-1, -2) @ v_new)
|
||||
|
||||
if not output_final_state:
|
||||
last_recurrent_state = None
|
||||
core_attn_out = core_attn_out.reshape(core_attn_out.shape[0],
|
||||
core_attn_out.shape[1], -1,
|
||||
core_attn_out.shape[-1])
|
||||
core_attn_out = core_attn_out[:, :, :num_heads]
|
||||
core_attn_out = core_attn_out.transpose(1,
|
||||
2).contiguous().to(initial_dtype)
|
||||
return core_attn_out, last_recurrent_state
|
||||
vllm_ascend/ops/triton/fla/solve_tril.py (new file, 419 lines)
@@ -0,0 +1,419 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
# mypy: ignore-errors
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .utils import prepare_chunk_indices
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def solve_tril_16x16_kernel(
|
||||
A,
|
||||
Ad,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
LARGE_BLOCK_T: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = (
|
||||
tl.load(chunk_indices + i_t * 2).to(tl.int32),
|
||||
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
|
||||
)
|
||||
bos, eos = (
|
||||
tl.load(cu_seqlens + i_n).to(tl.int32),
|
||||
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
|
||||
)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
A = A + (bos * H + i_h) * BT
|
||||
Ad = Ad + (bos * H + i_h) * 16
|
||||
|
||||
base_t = i_t * LARGE_BLOCK_T
|
||||
|
||||
NTASKS: tl.constexpr = 2
|
||||
N_BLOCKS: tl.constexpr = LARGE_BLOCK_T // 16 // NTASKS
|
||||
|
||||
for taskid in range(0, NTASKS):
|
||||
base_t += taskid * (LARGE_BLOCK_T // NTASKS)
|
||||
|
||||
# use make_block_ptr to reduce vector computation
|
||||
b_A = tl.zeros((N_BLOCKS, 16, 16), dtype=tl.float32)
|
||||
for blkid in range(0, N_BLOCKS):
|
||||
row_start_o = base_t + blkid * 16
|
||||
col_start_o = row_start_o % BT
|
||||
|
||||
# 1 Create in-block offset
|
||||
offs_rows_in_block = tl.arange(0, 16)
|
||||
offs_cols_in_block = tl.arange(0, 16)
|
||||
|
||||
# 2 Calculate the pointer of each element
|
||||
ptr_A_subrec16 = (A + row_start_o * H * BT + col_start_o +
|
||||
offs_rows_in_block[:, None] * H * BT +
|
||||
offs_cols_in_block[None, :])
|
||||
|
||||
# 3 Create a mask to prevent out-of-bounds access
|
||||
global_rows = row_start_o + offs_rows_in_block[:, None]
|
||||
global_cols = col_start_o + offs_cols_in_block[None, :]
|
||||
load_mask = (global_rows < T) & (global_cols < BT)
|
||||
|
||||
# 4 Use mask to safely load data
|
||||
b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask,
|
||||
other=0.0).to(tl.float32)
|
||||
b_A = tl.insert_slice(
|
||||
ful=b_A,
|
||||
sub=b_A_subrec16[None, :, :], # (1, 16, 16)
|
||||
offsets=[blkid, 0, 0],
|
||||
sizes=[1, 16, 16],
|
||||
strides=[1, 1, 1])
|
||||
|
||||
local_ori_A = tl.trans(b_A, (1, 0, 2))
|
||||
local_ori_A = tl.reshape(local_ori_A, (16, 16 * N_BLOCKS))
|
||||
|
||||
# Convert mask into matrix multiplication to avoid for loops ub oom
|
||||
tmp = tl.arange(0, 16).to(tl.float32)
|
||||
rows = tmp[:, None]
|
||||
cols = tmp[None, :]
|
||||
is_lower = (rows > cols).to(b_A.dtype)
|
||||
b_A = -b_A * is_lower
|
||||
|
||||
# for loop to update N_BLOCKS row vector
|
||||
for i in range(1, 16):
|
||||
nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0),
|
||||
(1, 16 * N_BLOCKS),
|
||||
(16 * N_BLOCKS, 1))
|
||||
b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))
|
||||
|
||||
dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
|
||||
dot_product = tl.sum(dot_tmp, 0)
|
||||
b_a = b_a + dot_product
|
||||
|
||||
b_a_new_expanded = b_a[:, None, :]
|
||||
b_A = tl.insert_slice(ful=b_A,
|
||||
sub=b_a_new_expanded,
|
||||
offsets=[0, i, 0],
|
||||
sizes=[N_BLOCKS, 1, 16],
|
||||
strides=[1, 1, 1])
|
||||
|
||||
on_diagonal = (rows == cols)
|
||||
b_A = tl.where(on_diagonal, b_A + 1.0, b_A)
|
||||
|
||||
b_A = tl.reshape(b_A, (N_BLOCKS * 16, 16))
|
||||
p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (base_t, 0),
|
||||
(N_BLOCKS * 16, 16), (1, 0))
|
||||
|
||||
# 1 Create in-block offset
|
||||
offs_rows_to_store = tl.arange(0, N_BLOCKS * 16)
|
||||
offs_cols_to_store = tl.arange(0, 16)
|
||||
|
||||
# 2 Calculate the pointer of each element
|
||||
p_Ai = (Ad + base_t * H * 16 + 0 +
|
||||
offs_rows_to_store[:, None] * H * 16 +
|
||||
offs_cols_to_store[None, :])
|
||||
# 3 Create a mask to prevent out-of-bounds access, only check rows
|
||||
global_store_rows = base_t + offs_rows_to_store[:, None]
|
||||
store_mask = global_store_rows < T
|
||||
# 4 use mask to save data safely
|
||||
tl.store(p_Ai,
|
||||
b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
mask=store_mask)
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def merge_16x16_to_32x32_inverse_kernel(
|
||||
A,
|
||||
Ad,
|
||||
Ai,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = (
|
||||
tl.load(chunk_indices + i_t * 2).to(tl.int32),
|
||||
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
|
||||
)
|
||||
bos, eos = (
|
||||
tl.load(cu_seqlens + i_n).to(tl.int32),
|
||||
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
|
||||
)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
A += (bos * H + i_h) * 32
|
||||
Ad += (bos * H + i_h) * 16
|
||||
Ai += (bos * H + i_h) * 32
|
||||
|
||||
p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0),
|
||||
(16, 16), (1, 0))
|
||||
|
||||
A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_21 = -tl.dot(
|
||||
tl.dot(Ai_22, A_21, input_precision="ieee"),
|
||||
Ai_11,
|
||||
input_precision="ieee",
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_11,
|
||||
Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_22,
|
||||
Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
tl.store(
|
||||
p_Ai_21,
|
||||
Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1),
|
||||
)
|
||||
|
||||
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
||||
@triton.jit(do_not_specialize=["T"])
|
||||
def merge_16x16_to_64x64_inverse_kernel(
|
||||
A,
|
||||
Ad,
|
||||
Ai,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t_val = (
|
||||
tl.load(chunk_indices + i_t * 2).to(tl.int32),
|
||||
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
|
||||
)
|
||||
bos, eos = (
|
||||
tl.load(cu_seqlens + i_n).to(tl.int32),
|
||||
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
|
||||
)
|
||||
T = eos - bos
|
||||
i_t = i_t_val
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
# Base pointers (already offset by batch and head)
|
||||
A += (bos * H + i_h) * 64
|
||||
Ad += (bos * H + i_h) * 16
|
||||
Ai += (bos * H + i_h) * 64
|
||||
|
||||
# load Ai_22 (Ad block at row i_t * 64 + 16, col 0, 16 * 16)
|
||||
offs_m = i_t * 64 + 16 + tl.arange(0, 16)
|
||||
offs_n = tl.arange(0, 16)
|
||||
mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
|
||||
ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
|
||||
Ai_22 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)
|
||||
|
||||
# load A_21 (A block at row i_t * 64 + 16, col 0, 16 * 16)
|
||||
mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
||||
ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
||||
A_21 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
|
||||
tmp = tl.dot(Ai_22, A_21, input_precision="ieee")
|
||||
|
||||
# load Ai_11 (Ad block at row i_t * 64, col 0, 16 * 16)
|
||||
offs_m = i_t * 64 + tl.arange(0, 16)
|
||||
offs_n = tl.arange(0, 16)
|
||||
mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
|
||||
ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
|
||||
Ai_11 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)
|
||||
|
||||
Ai_21 = -tl.dot(tmp, Ai_11, input_precision="ieee")
|
||||
|
||||
# load Ai_44 (Ad block at row i_t * 64 + 48, col 0, 16 * 16)
|
||||
offs_m = i_t * 64 + 48 + tl.arange(0, 16)
|
||||
offs_n = tl.arange(0, 16)
|
||||
mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
|
||||
ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
|
||||
Ai_44 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)
|
||||
|
||||
# load A_43 (Ad block at row i_t * 64 + 48, col 32, 16 * 16)
|
||||
offs_n = 32 + tl.arange(0, 16)
|
||||
mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
||||
ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
||||
A_43 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
|
||||
tmp = tl.dot(Ai_44, A_43, input_precision="ieee")
|
||||
|
||||
# load Ai_33 (Ad block at row i_t * 64 + 32, col 0, 16 * 16)
|
||||
offs_m = i_t * 64 + 32 + tl.arange(0, 16)
|
||||
offs_n = tl.arange(0, 16)
|
||||
mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
|
||||
ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
|
||||
Ai_33 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)
|
||||
|
||||
Ai_43 = -tl.dot(tmp, Ai_33, input_precision="ieee")
|
||||
|
||||
# build Ai_22_32 (32 * 32)
|
||||
Ai_22_32 = tl.zeros((32, 32), tl.float32)
|
||||
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
|
||||
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
|
||||
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))
|
||||
|
||||
# load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32)
|
||||
offs_m = i_t * 64 + 32 + tl.arange(0, 32)
|
||||
offs_n = tl.arange(0, 32)
|
||||
mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
||||
ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
||||
A_21_32 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
|
||||
tmp = tl.dot(Ai_22_32, A_21_32, input_precision="ieee")
|
||||
|
||||
# build Ai_11_32 (32 * 32)
|
||||
Ai_11_32 = tl.zeros((32, 32), tl.float32)
|
||||
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
|
||||
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
|
||||
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))
|
||||
|
||||
Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee")
|
||||
|
||||
# store Ai_11_32 to (i_t * 64, 0)
|
||||
offs_m = i_t * 64 + tl.arange(0, 32)
|
||||
offs_n = tl.arange(0, 32)
|
||||
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
||||
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
||||
tl.store(ptr_Ai,
|
||||
Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
mask=mask_store)
|
||||
|
||||
# store Ai_22_32 to (i_t * 64 + 32, 32)
|
||||
offs_m = i_t * 64 + 32 + tl.arange(0, 32)
|
||||
offs_n = 32 + tl.arange(0, 32)
|
||||
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
||||
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
||||
tl.store(ptr_Ai,
|
||||
Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
mask=mask_store)
|
||||
|
||||
    # store Ai_21_32 to (i_t * 64 + 32, 0)
|
||||
offs_n = tl.arange(0, 32)
|
||||
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
||||
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
||||
tl.store(ptr_Ai,
|
||||
Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
mask=mask_store)
|
||||
|
||||
# zero out the upper-right 32 * 32 block (rows 0 ~ 31, cols 32 ~ 63)
|
||||
offs_m = i_t * 64 + tl.arange(0, 32)
|
||||
offs_n = 32 + tl.arange(0, 32)
|
||||
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < BT)
|
||||
ptr_Ai = Ai + offs_m[:, None] * (H * BT) + offs_n[None, :]
|
||||
zero_block = tl.zeros((32, 32), dtype=ptr_Ai.dtype.element_ty)
|
||||
tl.store(ptr_Ai, zero_block, mask=mask_store)
|
||||
|
||||
|
||||
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor] = None,
    output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
    """
    Compute the inverse of the matrix I + A.
    A should be strictly lower triangular, i.e., A.triu() == 0.

    Args:
        A (torch.Tensor):
            [B, T, H, BT], where BT should only be 16, 32, or 64.
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor. Default: `None`.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`.
            If `None`, the output dtype will be the same as the input dtype.

    Returns:
        (I + A)^-1 with the same shape as A
    """
    assert A.shape[-1] in [16, 32, 64]

    B, T, H, BT = A.shape
    Ad = torch.empty(B,
                     T,
                     H,
                     16,
                     device=A.device,
                     dtype=torch.float if BT != 16 else output_dtype)

    LARGE_BLOCK_T = 608 * 2

    chunk_indices = (prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T)
                     if cu_seqlens is not None else None)
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(
        T, LARGE_BLOCK_T)

    solve_tril_16x16_kernel[NT, B * H](
        A=A,
        Ad=Ad,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        LARGE_BLOCK_T=LARGE_BLOCK_T,
        num_warps=1,
        num_stages=4,
    )

    if BT == 16:
        return Ad

    Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
    merge_fn = (merge_16x16_to_32x32_inverse_kernel
                if BT == 32 else merge_16x16_to_64x64_inverse_kernel)
    chunk_indices = (prepare_chunk_indices(cu_seqlens, BT)
                     if cu_seqlens is not None else None)
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)

    merge_fn[NT, B * H](
        A=A,
        Ad=Ad,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        num_warps=4,
        num_stages=3,
    )
    return Ai
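A minimal sanity check for solve_tril (a sketch, not part of the diff; it assumes a torch_npu environment and that solve_tril is importable from vllm_ascend.ops.triton.fla.chunk, which is an assumed path):

import torch
from vllm_ascend.ops.triton.fla.chunk import solve_tril  # assumed import path

B, T, H, BT = 1, 64, 2, 64                    # one 64 x 64 chunk per head
A = torch.randn(B, T, H, BT).npu()
# make each per-head chunk strictly lower triangular over the (T, BT) dims
A = A.permute(0, 2, 1, 3).tril(-1).permute(0, 2, 1, 3).contiguous()
Ai = solve_tril(A)
for h in range(H):
    M = torch.eye(BT).npu() + A[0, :, h, :]
    assert torch.allclose(M @ Ai[0, :, h, :], torch.eye(BT).npu(), atol=1e-3)
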
79 vllm_ascend/ops/triton/fla/utils.py Normal file
@@ -0,0 +1,79 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import contextlib
import functools
from typing import Callable

import torch
from vllm.triton_utils import tl, triton


def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    return cu_seqlens[1:] - cu_seqlens[:-1]


def prepare_chunk_indices(cu_seqlens: torch.LongTensor,
                          chunk_size: int) -> torch.LongTensor:
    indices = torch.cat([
        torch.arange(n)
        for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
    ])
    return torch.stack([indices.eq(0).cumsum(0) - 1, indices],
                       1).to(cu_seqlens)


def prepare_chunk_offsets(cu_seqlens: torch.LongTensor,
                          chunk_size: int) -> torch.LongTensor:
    return torch.cat([
        cu_seqlens.new_tensor([0]),
        triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
    ]).cumsum(-1)
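As a worked illustration of the two helpers (not part of the file; the values were computed by hand for chunk_size = 4):

import torch
from vllm_ascend.ops.triton.fla.utils import (prepare_chunk_indices,
                                               prepare_chunk_offsets)

cu_seqlens = torch.tensor([0, 5, 12])       # two sequences of lengths 5 and 7
print(prepare_chunk_indices(cu_seqlens, 4))
# tensor([[0, 0], [0, 1], [1, 0], [1, 1]])  -> (sequence idx, chunk idx) for each chunk
print(prepare_chunk_offsets(cu_seqlens, 4))
# tensor([0, 2, 4])                         -> cumulative chunk counts per sequence
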
def input_guard(
        fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (i if not isinstance(i, torch.Tensor) else
                           i.contiguous() for i in args)
        contiguous_kwargs = {
            k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
            for k, v in kwargs.items()
        }

        tensor = None
        for arg in args:
            if isinstance(arg, torch.Tensor):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, torch.Tensor):
                    tensor = value
                    break

        if tensor is not None:
            ctx = torch.npu.device(tensor.device.index)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper
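A sketch of how the decorator is meant to be used (illustrative only; my_launcher is a made-up name, not a function from this diff):

@input_guard
def my_launcher(x: torch.Tensor, scale: float) -> torch.Tensor:
    # x arrives contiguous here, and the current NPU device is set from x.device
    return x * scale
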
@triton.jit
def safe_exp(x):
    # exp(x) for x <= 0; positive inputs map to exp(-inf) = 0, avoiding overflow
    return tl.exp(tl.where(x <= 0, x, float("-inf")))
131 vllm_ascend/ops/triton/fla/wy_fast.py Normal file
@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# ruff: noqa: E501
# mypy: ignore-errors
from typing import Optional, Tuple

import torch
from vllm.triton_utils import tl, triton

from .utils import prepare_chunk_indices


@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
@triton.jit(do_not_specialize=['T'])
def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices,
                             T, H: tl.constexpr, Hg: tl.constexpr,
                             K: tl.constexpr, V: tl.constexpr,
                             BT: tl.constexpr, BK: tl.constexpr,
                             BV: tl.constexpr, IS_VARLEN: tl.constexpr):
    T_max = T
    i_t_o = tl.program_id(0)

    for i_bh in range(H):
        i_b, i_h = i_bh // H, i_bh % H
        if IS_VARLEN:
            i_n, i_t = tl.load(chunk_indices + i_t_o * 2).to(
                tl.int32), tl.load(chunk_indices + i_t_o * 2 + 1).to(tl.int32)
            bos, eos = tl.load(cu_seqlens + i_n).to(
                tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
            T = eos - bos
        else:
            bos, eos = i_b * T, i_b * T + T

        offs_t = tl.arange(0, BT)
        global_offs_t = i_t * BT + offs_t
        mask_t = global_offs_t < T

        offs_t_2d = global_offs_t[:, None]
        offs_bt = tl.arange(0, BT)[None, :]
        ptr_A = (A + (bos * H + i_h) * BT + offs_t_2d * (H * BT) + offs_bt * 1)
        mask_A = mask_t[:, None]
        b_A = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)

        ptr_g = g + bos + i_h * T_max + global_offs_t
        b_g = tl.exp(tl.load(ptr_g, mask=mask_t, other=0.0)).to(tl.float32)

        ptr_beta = beta + bos + i_h * T_max + global_offs_t
        b_beta = tl.load(ptr_beta, mask=mask_t, other=0.0).to(tl.float32)

        for i_v in range(tl.cdiv(V, BV)):
            offs_v = i_v * BV + tl.arange(0, BV)[None, :]
            mask_v = (mask_t[:, None]) & (offs_v < V)

            ptr_v = (v + (bos * H + i_h) * V + offs_t_2d * (H * V) +
                     offs_v * 1)
            b_v = tl.load(ptr_v, mask=mask_v, other=0.0).to(tl.float32)

            b_vb = (b_v * b_beta[:, None])
            b_u = tl.dot(b_A, b_vb, allow_tf32=False)

            ptr_u = (u + (bos * H + i_h) * V + offs_t_2d * (H * V) +
                     offs_v * 1)
            tl.store(ptr_u, b_u.to(ptr_u.dtype.element_ty), mask=mask_v)

        for i_k in range(tl.cdiv(K, BK)):
            offs_k = i_k * BK + tl.arange(0, BK)[None, :]
            mask_k = (mask_t[:, None]) & (offs_k < K)
            ptr_k = (k + (bos * Hg + i_h // (H // Hg)) * K + offs_t_2d *
                     (Hg * K) + offs_k * 1)
            b_k = tl.load(ptr_k, mask=mask_k, other=0.0).to(tl.float32)

            b_kb = (b_k * b_beta[:, None] * b_g[:, None])
            b_w = tl.dot(b_A, b_kb)

            ptr_w = (w + (bos * H + i_h) * K + offs_t_2d * (H * K) +
                     offs_k * 1)
            tl.store(ptr_w, b_w.to(ptr_w.dtype.element_ty), mask=mask_k)


def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    B, T, Hg, K, V = *k.shape, v.shape[-1]
    H = v.shape[-2]
    BT = A.shape[-1]

    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) \
        if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    BK = 64
    BV = 64

    u = torch.empty_like(v)
    w = k.new_empty(B, T, H, K)
    beta = beta.transpose(1, 2).contiguous()
    g_cumsum = g_cumsum.transpose(1, 2).contiguous()
    recompute_w_u_fwd_kernel[(NT, B)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g_cumsum,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        num_warps=4,
        num_stages=3,
    )
    return w, u
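For orientation, a rough shape sketch of the wrapper above (my reading of the code, with illustrative sizes; not part of the file):

# k:            [B, T, Hg, K]   keys; Hg may be smaller than H for grouped heads
# v:            [B, T, H,  V]   values
# beta, g_cumsum: [B, T, H]     per-token gates, transposed to [B, H, T] inside
# A:            [B, T, H, BT]   per-chunk (I + A)^-1 produced by solve_tril
# w, u = recompute_w_u_fwd(k, v, beta, g_cumsum, A, cu_seqlens)
# w: [B, T, H, K]  ~ per chunk, A @ (k * beta * exp(g))
# u: [B, T, H, V]  ~ per chunk, A @ (v * beta)
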
@@ -1,10 +1,7 @@
import vllm.model_executor.layers.fla.ops.chunk
import vllm.model_executor.layers.fla.ops.fused_recurrent
import vllm.model_executor.layers.fla.ops.layernorm_guard
import vllm.model_executor.layers.mamba.ops.causal_conv1d

from vllm_ascend.ops.triton.fla.fla import (LayerNormFn,
                                            torch_chunk_gated_delta_rule)
from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule
from vllm_ascend.ops.triton.fla.layernorm_guard import LayerNormFn
from vllm_ascend.ops.triton.fla.sigmoid_gating import \
    fused_recurrent_gated_delta_rule_fwd_kernel
from vllm_ascend.ops.triton.mamba.casual_conv1d import (
@@ -14,4 +11,4 @@ vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_update = causal
vllm.model_executor.layers.mamba.ops.causal_conv1d.causal_conv1d_fn = causal_conv1d_fn
vllm.model_executor.layers.fla.ops.fused_recurrent.fused_recurrent_gated_delta_rule_fwd_kernel = fused_recurrent_gated_delta_rule_fwd_kernel
vllm.model_executor.layers.fla.ops.layernorm_guard.LayerNormFn = LayerNormFn
vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = torch_chunk_gated_delta_rule
vllm.model_executor.layers.fla.ops.chunk.chunk_gated_delta_rule = chunk_gated_delta_rule
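One way to sanity-check that the patch took effect (a sketch; it assumes the vllm_ascend patch module containing the assignments above has already been imported, so the Triton implementation has replaced the torch fallback):

import vllm.model_executor.layers.fla.ops.chunk as vllm_fla_chunk
from vllm_ascend.ops.triton.fla.chunk import chunk_gated_delta_rule

assert vllm_fla_chunk.chunk_gated_delta_rule is chunk_gated_delta_rule
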