### What this PR does / why we need it? qwen3-next suppot triton chunk_gated_delta_rule ops ### co-owners @OsirisDuan - vLLM version: v0.11.2 Signed-off-by: shiyuan680 <917935075@qq.com>
420 lines
15 KiB
Python
420 lines
15 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
|
#
|
|
# This file contains code copied from the flash-linear-attention project.
|
|
# The original source code was licensed under the MIT license and included
|
|
# the following copyright notice:
|
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
# ruff: noqa: E501
|
|
# mypy: ignore-errors
|
|
from typing import Optional
|
|
|
|
import torch
|
|
from vllm.triton_utils import tl, triton
|
|
|
|
from .utils import prepare_chunk_indices
|
|
|
|
|
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
|
@triton.jit(do_not_specialize=["T"])
|
|
def solve_tril_16x16_kernel(
|
|
A,
|
|
Ad,
|
|
cu_seqlens,
|
|
chunk_indices,
|
|
T,
|
|
H: tl.constexpr,
|
|
BT: tl.constexpr,
|
|
IS_VARLEN: tl.constexpr,
|
|
LARGE_BLOCK_T: tl.constexpr,
|
|
):
|
|
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
|
i_b, i_h = i_bh // H, i_bh % H
|
|
if IS_VARLEN:
|
|
i_n, i_t = (
|
|
tl.load(chunk_indices + i_t * 2).to(tl.int32),
|
|
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
|
|
)
|
|
bos, eos = (
|
|
tl.load(cu_seqlens + i_n).to(tl.int32),
|
|
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
|
|
)
|
|
T = eos - bos
|
|
else:
|
|
bos, eos = i_b * T, i_b * T + T
|
|
|
|
A = A + (bos * H + i_h) * BT
|
|
Ad = Ad + (bos * H + i_h) * 16
|
|
|
|
base_t = i_t * LARGE_BLOCK_T
|
|
|
|
NTASKS: tl.constexpr = 2
|
|
N_BLOCKS: tl.constexpr = LARGE_BLOCK_T // 16 // NTASKS
|
|
|
|
for taskid in range(0, NTASKS):
|
|
base_t += taskid * (LARGE_BLOCK_T // NTASKS)
|
|
|
|
# use make_block_ptr to reduce vector computation
|
|
b_A = tl.zeros((N_BLOCKS, 16, 16), dtype=tl.float32)
|
|
for blkid in range(0, N_BLOCKS):
|
|
row_start_o = base_t + blkid * 16
|
|
col_start_o = row_start_o % BT
|
|
|
|
# 1 Create in-block offset
|
|
offs_rows_in_block = tl.arange(0, 16)
|
|
offs_cols_in_block = tl.arange(0, 16)
|
|
|
|
# 2 Calculate the pointer of each element
|
|
ptr_A_subrec16 = (A + row_start_o * H * BT + col_start_o +
|
|
offs_rows_in_block[:, None] * H * BT +
|
|
offs_cols_in_block[None, :])
|
|
|
|
# 3 Create a mask to prevent out-of-bounds access
|
|
global_rows = row_start_o + offs_rows_in_block[:, None]
|
|
global_cols = col_start_o + offs_cols_in_block[None, :]
|
|
load_mask = (global_rows < T) & (global_cols < BT)
|
|
|
|
# 4 Use mask to safely load data
|
|
b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask,
|
|
other=0.0).to(tl.float32)
|
|
b_A = tl.insert_slice(
|
|
ful=b_A,
|
|
sub=b_A_subrec16[None, :, :], # (1, 16, 16)
|
|
offsets=[blkid, 0, 0],
|
|
sizes=[1, 16, 16],
|
|
strides=[1, 1, 1])
|
|
|
|
local_ori_A = tl.trans(b_A, (1, 0, 2))
|
|
local_ori_A = tl.reshape(local_ori_A, (16, 16 * N_BLOCKS))
|
|
|
|
# Convert mask into matrix multiplication to avoid for loops ub oom
|
|
tmp = tl.arange(0, 16).to(tl.float32)
|
|
rows = tmp[:, None]
|
|
cols = tmp[None, :]
|
|
is_lower = (rows > cols).to(b_A.dtype)
|
|
b_A = -b_A * is_lower
|
|
|
|
# for loop to update N_BLOCKS row vector
|
|
for i in range(1, 16):
|
|
nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0),
|
|
(1, 16 * N_BLOCKS),
|
|
(16 * N_BLOCKS, 1))
|
|
b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))
|
|
|
|
dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
|
|
dot_product = tl.sum(dot_tmp, 0)
|
|
b_a = b_a + dot_product
|
|
|
|
b_a_new_expanded = b_a[:, None, :]
|
|
b_A = tl.insert_slice(ful=b_A,
|
|
sub=b_a_new_expanded,
|
|
offsets=[0, i, 0],
|
|
sizes=[N_BLOCKS, 1, 16],
|
|
strides=[1, 1, 1])
|
|
|
|
on_diagonal = (rows == cols)
|
|
b_A = tl.where(on_diagonal, b_A + 1.0, b_A)
|
|
|
|
b_A = tl.reshape(b_A, (N_BLOCKS * 16, 16))
|
|
p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (base_t, 0),
|
|
(N_BLOCKS * 16, 16), (1, 0))
|
|
|
|
# 1 Create in-block offset
|
|
offs_rows_to_store = tl.arange(0, N_BLOCKS * 16)
|
|
offs_cols_to_store = tl.arange(0, 16)
|
|
|
|
# 2 Calculate the pointer of each element
|
|
p_Ai = (Ad + base_t * H * 16 + 0 +
|
|
offs_rows_to_store[:, None] * H * 16 +
|
|
offs_cols_to_store[None, :])
|
|
# 3 Create a mask to prevent out-of-bounds access, only check rows
|
|
global_store_rows = base_t + offs_rows_to_store[:, None]
|
|
store_mask = global_store_rows < T
|
|
# 4 use mask to save data safely
|
|
tl.store(p_Ai,
|
|
b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
|
mask=store_mask)
|
|
|
|
|
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
|
@triton.jit(do_not_specialize=["T"])
|
|
def merge_16x16_to_32x32_inverse_kernel(
|
|
A,
|
|
Ad,
|
|
Ai,
|
|
cu_seqlens,
|
|
chunk_indices,
|
|
T,
|
|
H: tl.constexpr,
|
|
BT: tl.constexpr,
|
|
IS_VARLEN: tl.constexpr,
|
|
):
|
|
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
|
i_b, i_h = i_bh // H, i_bh % H
|
|
if IS_VARLEN:
|
|
i_n, i_t = (
|
|
tl.load(chunk_indices + i_t * 2).to(tl.int32),
|
|
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
|
|
)
|
|
bos, eos = (
|
|
tl.load(cu_seqlens + i_n).to(tl.int32),
|
|
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
|
|
)
|
|
T = eos - bos
|
|
else:
|
|
bos, eos = i_b * T, i_b * T + T
|
|
|
|
A += (bos * H + i_h) * 32
|
|
Ad += (bos * H + i_h) * 16
|
|
Ai += (bos * H + i_h) * 32
|
|
|
|
p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0),
|
|
(16, 16), (1, 0))
|
|
p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0),
|
|
(16, 16), (1, 0))
|
|
p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0),
|
|
(16, 16), (1, 0))
|
|
p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0),
|
|
(16, 16), (1, 0))
|
|
p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16),
|
|
(16, 16), (1, 0))
|
|
p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0),
|
|
(16, 16), (1, 0))
|
|
|
|
A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
|
|
Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
|
|
Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
|
|
Ai_21 = -tl.dot(
|
|
tl.dot(Ai_22, A_21, input_precision="ieee"),
|
|
Ai_11,
|
|
input_precision="ieee",
|
|
)
|
|
tl.store(
|
|
p_Ai_11,
|
|
Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
|
|
boundary_check=(0, 1),
|
|
)
|
|
tl.store(
|
|
p_Ai_22,
|
|
Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
|
|
boundary_check=(0, 1),
|
|
)
|
|
tl.store(
|
|
p_Ai_21,
|
|
Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
|
|
boundary_check=(0, 1),
|
|
)
|
|
|
|
|
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
|
|
@triton.jit(do_not_specialize=["T"])
|
|
def merge_16x16_to_64x64_inverse_kernel(
|
|
A,
|
|
Ad,
|
|
Ai,
|
|
cu_seqlens,
|
|
chunk_indices,
|
|
T,
|
|
H: tl.constexpr,
|
|
BT: tl.constexpr,
|
|
IS_VARLEN: tl.constexpr,
|
|
):
|
|
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
|
i_b, i_h = i_bh // H, i_bh % H
|
|
if IS_VARLEN:
|
|
i_n, i_t_val = (
|
|
tl.load(chunk_indices + i_t * 2).to(tl.int32),
|
|
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
|
|
)
|
|
bos, eos = (
|
|
tl.load(cu_seqlens + i_n).to(tl.int32),
|
|
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
|
|
)
|
|
T = eos - bos
|
|
i_t = i_t_val
|
|
else:
|
|
bos, eos = i_b * T, i_b * T + T
|
|
|
|
# Base pointers (already offset by batch and head)
|
|
A += (bos * H + i_h) * 64
|
|
Ad += (bos * H + i_h) * 16
|
|
Ai += (bos * H + i_h) * 64
|
|
|
|
# load Ai_22 (Ad block at row i_t * 64 + 16, col 0, 16 * 16)
|
|
offs_m = i_t * 64 + 16 + tl.arange(0, 16)
|
|
offs_n = tl.arange(0, 16)
|
|
mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
|
|
ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
|
|
Ai_22 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)
|
|
|
|
# load A_21 (A block at row i_t * 64 + 16, col 0, 16 * 16)
|
|
mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
|
ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
|
A_21 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
|
|
tmp = tl.dot(Ai_22, A_21, input_precision="ieee")
|
|
|
|
# load Ai_11 (Ad block at row i_t * 64, col 0, 16 * 16)
|
|
offs_m = i_t * 64 + tl.arange(0, 16)
|
|
offs_n = tl.arange(0, 16)
|
|
mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
|
|
ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
|
|
Ai_11 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)
|
|
|
|
Ai_21 = -tl.dot(tmp, Ai_11, input_precision="ieee")
|
|
|
|
# load Ai_44 (Ad block at row i_t * 64 + 48, col 0, 16 * 16)
|
|
offs_m = i_t * 64 + 48 + tl.arange(0, 16)
|
|
offs_n = tl.arange(0, 16)
|
|
mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
|
|
ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
|
|
Ai_44 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)
|
|
|
|
# load A_43 (Ad block at row i_t * 64 + 48, col 32, 16 * 16)
|
|
offs_n = 32 + tl.arange(0, 16)
|
|
mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
|
ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
|
A_43 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
|
|
tmp = tl.dot(Ai_44, A_43, input_precision="ieee")
|
|
|
|
# load Ai_33 (Ad block at row i_t * 64 + 32, col 0, 16 * 16)
|
|
offs_m = i_t * 64 + 32 + tl.arange(0, 16)
|
|
offs_n = tl.arange(0, 16)
|
|
mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
|
|
ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
|
|
Ai_33 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)
|
|
|
|
Ai_43 = -tl.dot(tmp, Ai_33, input_precision="ieee")
|
|
|
|
# build Ai_22_32 (32 * 32)
|
|
Ai_22_32 = tl.zeros((32, 32), tl.float32)
|
|
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
|
|
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
|
|
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))
|
|
|
|
# load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32)
|
|
offs_m = i_t * 64 + 32 + tl.arange(0, 32)
|
|
offs_n = tl.arange(0, 32)
|
|
mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
|
ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
|
A_21_32 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
|
|
tmp = tl.dot(Ai_22_32, A_21_32, input_precision="ieee")
|
|
|
|
# build Ai_11_32 (32 * 32)
|
|
Ai_11_32 = tl.zeros((32, 32), tl.float32)
|
|
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
|
|
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
|
|
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))
|
|
|
|
Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee")
|
|
|
|
# store Ai_11_32 to (i_t * 64, 0)
|
|
offs_m = i_t * 64 + tl.arange(0, 32)
|
|
offs_n = tl.arange(0, 32)
|
|
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
|
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
|
tl.store(ptr_Ai,
|
|
Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
|
mask=mask_store)
|
|
|
|
# store Ai_22_32 to (i_t * 64 + 32, 32)
|
|
offs_m = i_t * 64 + 32 + tl.arange(0, 32)
|
|
offs_n = 32 + tl.arange(0, 32)
|
|
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
|
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
|
tl.store(ptr_Ai,
|
|
Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
|
mask=mask_store)
|
|
|
|
# store Ai_21_32 to (i_t * 64 + 32, 32)
|
|
offs_n = tl.arange(0, 32)
|
|
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
|
|
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
|
|
tl.store(ptr_Ai,
|
|
Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
|
mask=mask_store)
|
|
|
|
# zero out the upper-right 32 * 32 block (rows 0 ~ 31, cols 32 ~ 63)
|
|
offs_m = i_t * 64 + tl.arange(0, 32)
|
|
offs_n = 32 + tl.arange(0, 32)
|
|
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < BT)
|
|
ptr_Ai = Ai + offs_m[:, None] * (H * BT) + offs_n[None, :]
|
|
zero_block = tl.zeros((32, 32), dtype=ptr_Ai.dtype.element_ty)
|
|
tl.store(ptr_Ai, zero_block, mask=mask_store)
|
|
|
|
|
|
def solve_tril(
|
|
A: torch.Tensor,
|
|
cu_seqlens: Optional[torch.Tensor] = None,
|
|
output_dtype: torch.dtype = torch.float,
|
|
) -> torch.Tensor:
|
|
"""
|
|
Compute the inverse of the matrix I + A
|
|
A should be strictly lower triangular, i.e., A.triu() == 0.
|
|
|
|
Args:
|
|
A (torch.Tensor):
|
|
[B, T, H, BT], where BT should only be 16, 32, or 64.
|
|
cu_seqlens (torch.Tensor):
|
|
The cumulative sequence lengths of the input tensor. Default: `None`.
|
|
output_dtype (torch.dtype):
|
|
The dtype of the output tensor. Default: `torch.float`.
|
|
If `None`, the output dtype will be the same as the input dtype.
|
|
|
|
Returns:
|
|
(I + A)^-1 with the same shape as A
|
|
"""
|
|
assert A.shape[-1] in [16, 32, 64]
|
|
|
|
B, T, H, BT = A.shape
|
|
Ad = torch.empty(B,
|
|
T,
|
|
H,
|
|
16,
|
|
device=A.device,
|
|
dtype=torch.float if BT != 16 else output_dtype)
|
|
|
|
LARGE_BLOCK_T = 608 * 2
|
|
|
|
chunk_indices = (prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T)
|
|
if cu_seqlens is not None else None)
|
|
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(
|
|
T, LARGE_BLOCK_T)
|
|
|
|
solve_tril_16x16_kernel[NT, B * H](
|
|
A=A,
|
|
Ad=Ad,
|
|
cu_seqlens=cu_seqlens,
|
|
chunk_indices=chunk_indices,
|
|
T=T,
|
|
H=H,
|
|
BT=BT,
|
|
LARGE_BLOCK_T=LARGE_BLOCK_T,
|
|
num_warps=1,
|
|
num_stages=4,
|
|
)
|
|
|
|
if BT == 16:
|
|
return Ad
|
|
|
|
Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
|
|
merge_fn = (merge_16x16_to_32x32_inverse_kernel
|
|
if BT == 32 else merge_16x16_to_64x64_inverse_kernel)
|
|
chunk_indices = (prepare_chunk_indices(cu_seqlens, BT)
|
|
if cu_seqlens is not None else None)
|
|
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
|
|
|
|
merge_fn[NT, B * H](
|
|
A=A,
|
|
Ad=Ad,
|
|
Ai=Ai,
|
|
cu_seqlens=cu_seqlens,
|
|
chunk_indices=chunk_indices,
|
|
T=T,
|
|
H=H,
|
|
BT=BT,
|
|
num_warps=4,
|
|
num_stages=3,
|
|
)
|
|
return Ai
|