Files
xc-llm-ascend/vllm_ascend/ops/triton/fla/solve_tril.py
cvSoldier 2db33868a4 [kernel] Recompilation optimization triggered by triton function parameter optimization (#7645)
### What this PR does / why we need it?
Some parameters of the Triton operators are unnecessarily annotated with the
`constexpr` modifier. Whenever the value of such a parameter changes, the kernel
is recompiled, which significantly hurts model performance. These parameters
therefore need to be corrected.

Main branch PR: https://github.com/vllm-project/vllm-ascend/pull/7483
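
To illustrate (this sketch is not part of the change; the kernel and argument
names are made up): a value such as a sequence length that varies between calls
should be a plain runtime argument rather than a `tl.constexpr`, since every
distinct `constexpr` value is compiled into a separate kernel binary.

```python
import triton
import triton.language as tl


@triton.jit
def scale_recompiles(x_ptr, y_ptr, n: tl.constexpr, BLOCK: tl.constexpr):
    # `n` is baked into the compiled kernel: every new value of `n`
    # triggers a fresh compilation at runtime.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)


@triton.jit(do_not_specialize=["n"])
def scale_stable(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    # `n` is an ordinary runtime argument; only BLOCK (a true compile-time
    # tile size) is part of the compilation cache key.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask) * 2.0, mask=mask)
```

With this split, calling `scale_stable` repeatedly with different `n` reuses one
compiled kernel, while `scale_recompiles` compiles once per distinct `n`.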

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

---------

Signed-off-by: cvSoldier <610496306@qq.com>
2026-03-26 16:31:34 +08:00


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
import torch
from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import extract_slice, insert_slice
from .utils import prepare_chunk_indices


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "H"])
def solve_tril_16x16_kernel(
A,
Ad,
cu_seqlens,
chunk_indices,
T,
H,
BT: tl.constexpr,
IS_VARLEN: tl.constexpr,
LARGE_BLOCK_T: tl.constexpr,
):
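    # Inverts (I + A_blk) for every 16x16 block on the diagonal of each BT-sized
    # chunk of A (column offset = row offset % BT) and writes the block inverses
    # to Ad. Each program covers LARGE_BLOCK_T rows, split into NTASKS sub-ranges
    # whose N_BLOCKS 16x16 blocks are processed together as a (N_BLOCKS, 16, 16) batch.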
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
A = A + (bos * H + i_h) * BT
Ad = Ad + (bos * H + i_h) * 16
base_t = i_t * LARGE_BLOCK_T
NTASKS: tl.constexpr = 2
N_BLOCKS: tl.constexpr = LARGE_BLOCK_T // 16 // NTASKS
for taskid in range(0, NTASKS):
base_t += taskid * (LARGE_BLOCK_T // NTASKS)
        # Gather this sub-range's N_BLOCKS 16x16 diagonal blocks into one batch
        b_A = tl.zeros((N_BLOCKS, 16, 16), dtype=tl.float32)
for blkid in range(0, N_BLOCKS):
row_start_o = base_t + blkid * 16
col_start_o = row_start_o % BT
# 1 Create in-block offset
offs_rows_in_block = tl.arange(0, 16)
offs_cols_in_block = tl.arange(0, 16)
# 2 Calculate the pointer of each element
ptr_A_subrec16 = (
A
+ row_start_o * H * BT
+ col_start_o
+ offs_rows_in_block[:, None] * H * BT
+ offs_cols_in_block[None, :]
)
# 3 Create a mask to prevent out-of-bounds access
global_rows = row_start_o + offs_rows_in_block[:, None]
global_cols = col_start_o + offs_cols_in_block[None, :]
load_mask = (global_rows < T) & (global_cols < BT)
# 4 Use mask to safely load data
b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(tl.float32)
b_A = insert_slice(
ful=b_A,
sub=b_A_subrec16[None, :, :], # (1, 16, 16)
offsets=[blkid, 0, 0],
sizes=[1, 16, 16],
strides=[1, 1, 1],
)
local_ori_A = tl.trans(b_A, (1, 0, 2))
local_ori_A = tl.reshape(local_ori_A, (16, 16 * N_BLOCKS))
        # Apply the strictly-lower-triangular mask as an elementwise multiply with a
        # 0/1 matrix (instead of per-element loops) to avoid UB (unified buffer) OOM
tmp = tl.arange(0, 16).to(tl.float32)
rows = tmp[:, None]
cols = tmp[None, :]
is_lower = (rows > cols).to(b_A.dtype)
b_A = -b_A * is_lower
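        # Forward substitution: with M = (I + L)^-1 and L strictly lower triangular,
        # (M - I) satisfies row_i(M - I) = -L_i - L_i @ (M - I), which only touches
        # rows 0..i-1, so rows can be filled in order. b_A starts as -L and is
        # overwritten row by row with (M - I).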
        # Forward-substitution loop: row i is updated across all N_BLOCKS blocks at once
for i in range(1, 16):
nblks_vec16 = -extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))
dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
dot_product = tl.sum(dot_tmp, 0)
b_a = b_a + dot_product
b_a_new_expanded = b_a[:, None, :]
b_A = insert_slice(
ful=b_A, sub=b_a_new_expanded, offsets=[0, i, 0], sizes=[N_BLOCKS, 1, 16], strides=[1, 1, 1]
)
on_diagonal = rows == cols
b_A = tl.where(on_diagonal, b_A + 1.0, b_A)
b_A = tl.reshape(b_A, (N_BLOCKS * 16, 16))
# 1 Create in-block offset
offs_rows_to_store = tl.arange(0, N_BLOCKS * 16)
offs_cols_to_store = tl.arange(0, 16)
# 2 Calculate the pointer of each element
p_Ai = Ad + base_t * H * 16 + 0 + offs_rows_to_store[:, None] * H * 16 + offs_cols_to_store[None, :]
# 3 Create a mask to prevent out-of-bounds access, only check rows
global_store_rows = base_t + offs_rows_to_store[:, None]
store_mask = global_store_rows < T
        # 4 Use mask to safely store data
        tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=store_mask)


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "H"])
def merge_16x16_to_32x32_inverse_kernel(
A,
Ad,
Ai,
cu_seqlens,
chunk_indices,
T,
H,
BT: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
A += (bos * H + i_h) * 32
Ad += (bos * H + i_h) * 16
Ai += (bos * H + i_h) * 32
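    # 2x2 block inverse of a lower-triangular block matrix:
    #   [[A11, 0 ],  ^-1    [[ A11^-1,                  0      ],
    #    [A21, A22]]     =   [ -A22^-1 @ A21 @ A11^-1,  A22^-1 ]]
    # Ad already holds the 16x16 diagonal-block inverses A11^-1 and A22^-1.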
p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0))
p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0))
p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0))
p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
Ai_21 = -tl.dot(
tl.dot(Ai_22, A_21, input_precision="ieee"),
Ai_11,
input_precision="ieee",
)
tl.store(
p_Ai_11,
Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_22,
Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_21,
Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
    )


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "H"])
def merge_16x16_to_64x64_inverse_kernel(
A,
Ad,
Ai,
cu_seqlens,
chunk_indices,
T,
H,
BT: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t_val = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
i_t = i_t_val
else:
bos, eos = i_b * T, i_b * T + T
# Base pointers (already offset by batch and head)
A += (bos * H + i_h) * 64
Ad += (bos * H + i_h) * 16
Ai += (bos * H + i_h) * 64
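    # Recursive merge: the four 16x16 diagonal-block inverses from Ad are first
    # combined into two 32x32 inverses (Ai_11_32, Ai_22_32) using the same
    # lower-triangular 2x2 block-inverse formula as the 32x32 kernel, and those
    # are then combined into the full 64x64 inverse written to Ai.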
# load Ai_22 (Ad block at row i_t * 64 + 16, col 0, 16 * 16)
offs_m = i_t * 64 + 16 + tl.arange(0, 16)
offs_n = tl.arange(0, 16)
mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
Ai_22 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)
# load A_21 (A block at row i_t * 64 + 16, col 0, 16 * 16)
mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
A_21 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
tmp = tl.dot(Ai_22, A_21, input_precision="ieee")
# load Ai_11 (Ad block at row i_t * 64, col 0, 16 * 16)
offs_m = i_t * 64 + tl.arange(0, 16)
offs_n = tl.arange(0, 16)
mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
Ai_11 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)
Ai_21 = -tl.dot(tmp, Ai_11, input_precision="ieee")
# load Ai_44 (Ad block at row i_t * 64 + 48, col 0, 16 * 16)
offs_m = i_t * 64 + 48 + tl.arange(0, 16)
offs_n = tl.arange(0, 16)
mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
Ai_44 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)
    # load A_43 (A block at row i_t * 64 + 48, col 32, 16 * 16)
offs_n = 32 + tl.arange(0, 16)
mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
A_43 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
tmp = tl.dot(Ai_44, A_43, input_precision="ieee")
# load Ai_33 (Ad block at row i_t * 64 + 32, col 0, 16 * 16)
offs_m = i_t * 64 + 32 + tl.arange(0, 16)
offs_n = tl.arange(0, 16)
mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
Ai_33 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)
Ai_43 = -tl.dot(tmp, Ai_33, input_precision="ieee")
# build Ai_22_32 (32 * 32)
Ai_22_32 = tl.zeros((32, 32), tl.float32)
Ai_22_32 = insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
Ai_22_32 = insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
Ai_22_32 = insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))
# load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32)
offs_m = i_t * 64 + 32 + tl.arange(0, 32)
offs_n = tl.arange(0, 32)
mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
A_21_32 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
tmp = tl.dot(Ai_22_32, A_21_32, input_precision="ieee")
# build Ai_11_32 (32 * 32)
Ai_11_32 = tl.zeros((32, 32), tl.float32)
Ai_11_32 = insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
Ai_11_32 = insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
Ai_11_32 = insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))
Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee")
# store Ai_11_32 to (i_t * 64, 0)
offs_m = i_t * 64 + tl.arange(0, 32)
offs_n = tl.arange(0, 32)
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
tl.store(ptr_Ai, Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)
# store Ai_22_32 to (i_t * 64 + 32, 32)
offs_m = i_t * 64 + 32 + tl.arange(0, 32)
offs_n = 32 + tl.arange(0, 32)
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
tl.store(ptr_Ai, Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)
    # store Ai_21_32 to (i_t * 64 + 32, 0)
offs_n = tl.arange(0, 32)
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
tl.store(ptr_Ai, Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)
# zero out the upper-right 32 * 32 block (rows 0 ~ 31, cols 32 ~ 63)
offs_m = i_t * 64 + tl.arange(0, 32)
offs_n = 32 + tl.arange(0, 32)
mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < BT)
ptr_Ai = Ai + offs_m[:, None] * (H * BT) + offs_n[None, :]
zero_block = tl.zeros((32, 32), dtype=ptr_Ai.dtype.element_ty)
    tl.store(ptr_Ai, zero_block, mask=mask_store)


def solve_tril(
A: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
chunk_indices_large_block: torch.Tensor | None = None,
chunk_indices_bt: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
"""
    Compute the inverse of the matrix (I + A).
A should be strictly lower triangular, i.e., A.triu() == 0.
Args:
A (torch.Tensor):
[B, T, H, BT], where BT should only be 16, 32, or 64.
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor. Default: `None`.
        chunk_indices_large_block (torch.Tensor):
            Precomputed chunk indices for the `LARGE_BLOCK_T`-sized chunks used by
            the 16x16 solve kernel. Computed from `cu_seqlens` when `None`. Default: `None`.
        chunk_indices_bt (torch.Tensor):
            Precomputed chunk indices for the `BT`-sized chunks used by the merge
            kernel. Computed from `cu_seqlens` when `None`. Default: `None`.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`.
            If `None`, the output dtype will be the same as the input dtype.
Returns:
(I + A)^-1 with the same shape as A
"""
assert A.shape[-1] in [16, 32, 64]
B, T, H, BT = A.shape
Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype)
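    # Rows handled by each program instance of the 16x16 solve kernel; split
    # into NTASKS sub-ranges of N_BLOCKS 16x16 blocks inside the kernel.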
LARGE_BLOCK_T = 608 * 2
if cu_seqlens is not None and chunk_indices_large_block is None:
chunk_indices_large_block = prepare_chunk_indices(cu_seqlens, LARGE_BLOCK_T)
chunk_indices = chunk_indices_large_block
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, LARGE_BLOCK_T)
solve_tril_16x16_kernel[NT, B * H](
A=A,
Ad=Ad,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
BT=BT,
LARGE_BLOCK_T=LARGE_BLOCK_T,
num_warps=1,
num_stages=4,
)
if BT == 16:
return Ad
Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
if cu_seqlens is not None and chunk_indices_bt is None:
chunk_indices_bt = prepare_chunk_indices(cu_seqlens, BT)
chunk_indices = chunk_indices_bt
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
merge_fn[NT, B * H](
A=A,
Ad=Ad,
Ai=Ai,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
BT=BT,
num_warps=4,
num_stages=3,
)
return Ai
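

# Usage sketch (illustrative only; not part of the original module). It assumes
# an Ascend NPU device exposed to torch as "npu" and the vllm-ascend Triton
# runtime; the shapes below are arbitrary.
if __name__ == "__main__":
    B, T, H, BT = 1, 128, 4, 64
    A = torch.randn(B, T, H, BT, device="npu", dtype=torch.float32)
    # zero the diagonal and everything above it within each BT-sized chunk so
    # that A is strictly lower triangular, as solve_tril requires
    row_in_chunk = torch.arange(T, device=A.device) % BT
    col = torch.arange(BT, device=A.device)
    A = A * (col[None, :] < row_in_chunk[:, None]).view(1, T, 1, BT)
    Ai = solve_tril(A)  # (I + A)^-1 per chunk, same shape as A
    print(Ai.shape)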