<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? Some parameters of the Triton operators are unnecessarily marked with the "constexpr" modifier. When these parameters change, recompilation is triggered, which significantly affects model performance. Therefore, these parameters need to be rectified. Main branch PR: https://github.com/vllm-project/vllm-ascend/pull/7483 ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> --------- Signed-off-by: cvSoldier <610496306@qq.com>
401 lines
14 KiB
Python
401 lines
14 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
|
#
|
|
# This file contains code copied from the flash-linear-attention project.
|
|
# The original source code was licensed under the MIT license and included
|
|
# the following copyright notice:
|
|
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
|
# ruff: noqa: E501
|
|
# mypy: ignore-errors
|
|
|
|
import torch
|
|
from vllm.triton_utils import tl, triton
|
|
|
|
from vllm_ascend.ops.triton.triton_utils import extract_slice, insert_slice
|
|
|
|
from .utils import prepare_chunk_indices
|
|
|
|
|
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "H"])
def solve_tril_16x16_kernel(
    A,  # [*, H, BT] strictly lower-triangular input blocks, flattened over (B, T)
    Ad,  # [*, H, 16] output: per-16x16-diagonal-block inverses of (I + A)
    cu_seqlens,  # cumulative sequence lengths (varlen mode), or None
    chunk_indices,  # per-program (sequence, tile) index pairs (varlen mode)
    T,
    H,
    BT: tl.constexpr,  # width of an input block row (16/32/64)
    IS_VARLEN: tl.constexpr,
    LARGE_BLOCK_T: tl.constexpr,  # rows handled per program; must be a multiple of 16 * NTASKS
):
    """Invert (I + A) for every 16x16 diagonal sub-block of A.

    Each program covers LARGE_BLOCK_T rows, i.e. LARGE_BLOCK_T // 16
    independent 16x16 diagonal blocks, and computes each block inverse by
    forward substitution on the strictly lower-triangular part. Results
    are written to Ad.
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the flat tile id to (sequence index, tile-within-sequence)
        # and redefine T as this sequence's length.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Advance base pointers to this (sequence, head).
    A = A + (bos * H + i_h) * BT
    Ad = Ad + (bos * H + i_h) * 16

    base_t = i_t * LARGE_BLOCK_T

    # The tile is split into NTASKS sub-ranges processed sequentially.
    NTASKS: tl.constexpr = 2
    N_BLOCKS: tl.constexpr = LARGE_BLOCK_T // 16 // NTASKS

    for taskid in range(0, NTASKS):
        # taskid == 0 adds nothing, so base_t steps through the sub-ranges.
        base_t += taskid * (LARGE_BLOCK_T // NTASKS)

        # use make_block_ptr to reduce vector computation
        b_A = tl.zeros((N_BLOCKS, 16, 16), dtype=tl.float32)
        for blkid in range(0, N_BLOCKS):
            row_start_o = base_t + blkid * 16
            # Column of this diagonal block inside its BT-wide chunk.
            col_start_o = row_start_o % BT

            # 1 Create in-block offset
            offs_rows_in_block = tl.arange(0, 16)
            offs_cols_in_block = tl.arange(0, 16)

            # 2 Calculate the pointer of each element
            ptr_A_subrec16 = (
                A
                + row_start_o * H * BT
                + col_start_o
                + offs_rows_in_block[:, None] * H * BT
                + offs_cols_in_block[None, :]
            )

            # 3 Create a mask to prevent out-of-bounds access
            global_rows = row_start_o + offs_rows_in_block[:, None]
            global_cols = col_start_o + offs_cols_in_block[None, :]
            load_mask = (global_rows < T) & (global_cols < BT)

            # 4 Use mask to safely load data
            b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(tl.float32)
            b_A = insert_slice(
                ful=b_A,
                sub=b_A_subrec16[None, :, :],  # (1, 16, 16)
                offsets=[blkid, 0, 0],
                sizes=[1, 16, 16],
                strides=[1, 1, 1],
            )

        # Lay the N_BLOCKS blocks side by side: (16, 16 * N_BLOCKS).
        local_ori_A = tl.trans(b_A, (1, 0, 2))
        local_ori_A = tl.reshape(local_ori_A, (16, 16 * N_BLOCKS))

        # Convert mask into matrix multiplication to avoid for loops ub oom
        tmp = tl.arange(0, 16).to(tl.float32)
        rows = tmp[:, None]
        cols = tmp[None, :]
        is_lower = (rows > cols).to(b_A.dtype)
        # Seed the running inverse with -tril(A, -1) for every block.
        b_A = -b_A * is_lower

        # for loop to update N_BLOCKS row vector
        # Forward substitution: row i of each block inverse is corrected
        # using the already-finalized rows < i (batched over N_BLOCKS).
        for i in range(1, 16):
            nblks_vec16 = -extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
            b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))

            # Batched row-vector x matrix product across all N_BLOCKS blocks.
            dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
            dot_product = tl.sum(dot_tmp, 0)
            b_a = b_a + dot_product

            b_a_new_expanded = b_a[:, None, :]
            b_A = insert_slice(
                ful=b_A, sub=b_a_new_expanded, offsets=[0, i, 0], sizes=[N_BLOCKS, 1, 16], strides=[1, 1, 1]
            )

        # Add the identity to complete (I + A)^-1 per block.
        on_diagonal = rows == cols
        b_A = tl.where(on_diagonal, b_A + 1.0, b_A)

        b_A = tl.reshape(b_A, (N_BLOCKS * 16, 16))
        # NOTE(review): this block pointer is immediately overwritten by the
        # manual pointer arithmetic below and is effectively unused.
        p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (base_t, 0), (N_BLOCKS * 16, 16), (1, 0))

        # 1 Create in-block offset
        offs_rows_to_store = tl.arange(0, N_BLOCKS * 16)
        offs_cols_to_store = tl.arange(0, 16)

        # 2 Calculate the pointer of each element
        p_Ai = Ad + base_t * H * 16 + 0 + offs_rows_to_store[:, None] * H * 16 + offs_cols_to_store[None, :]
        # 3 Create a mask to prevent out-of-bounds access, only check rows
        global_store_rows = base_t + offs_rows_to_store[:, None]
        store_mask = global_store_rows < T
        # 4 use mask to save data safely
        tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=store_mask)
|
|
|
|
|
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "H"])
def merge_16x16_to_32x32_inverse_kernel(
    A,  # [*, H, 32] original strictly lower-triangular blocks
    Ad,  # [*, H, 16] inverses of the 16x16 diagonal blocks
    Ai,  # [*, H, 32] output: inverses of the 32x32 blocks
    cu_seqlens,  # cumulative sequence lengths (varlen mode), or None
    chunk_indices,  # per-program (sequence, tile) index pairs (varlen mode)
    T,
    H,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Combine 16x16 block inverses into 32x32 block inverses.

    For a block [[A11, 0], [A21, A22]] with known Ai11 = A11^-1 and
    Ai22 = A22^-1, the inverse is [[Ai11, 0], [-Ai22 @ A21 @ Ai11, Ai22]].

    NOTE(review): the upper-right 16x16 sub-block of each 32x32 tile of Ai
    is never written here even though Ai is allocated uninitialized by the
    caller — presumably downstream consumers only read the lower triangle;
    verify against callers (the 64x64 variant does zero its upper-right).
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the flat tile id to (sequence index, tile-within-sequence)
        # and redefine T as this sequence's length.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    # Advance base pointers to this (sequence, head).
    A += (bos * H + i_h) * 32
    Ad += (bos * H + i_h) * 16
    Ai += (bos * H + i_h) * 32

    # Block pointers for the sub-blocks of this 32x32 tile.
    p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
    p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0))
    p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))
    p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0))
    p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0))
    p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0))

    A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
    Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
    # Off-diagonal block of the merged inverse: -Ai22 @ A21 @ Ai11.
    Ai_21 = -tl.dot(
        tl.dot(Ai_22, A_21, input_precision="ieee"),
        Ai_11,
        input_precision="ieee",
    )
    tl.store(
        p_Ai_11,
        Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_22,
        Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
    tl.store(
        p_Ai_21,
        Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
        boundary_check=(0, 1),
    )
|
|
|
|
|
|
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "H"])
def merge_16x16_to_64x64_inverse_kernel(
    A,  # [*, H, 64] original strictly lower-triangular blocks
    Ad,  # [*, H, 16] inverses of the 16x16 diagonal blocks
    Ai,  # [*, H, 64] output: inverses of the 64x64 blocks
    cu_seqlens,  # cumulative sequence lengths (varlen mode), or None
    chunk_indices,  # per-program (sequence, tile) index pairs (varlen mode)
    T,
    H,
    BT: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    """Combine 16x16 block inverses into 64x64 block inverses.

    Two merge levels: first assemble the two 32x32 diagonal inverses
    (Ai_11_32 from 16x16 blocks 1/2, Ai_22_32 from blocks 3/4), then
    combine them into the 64x64 inverse. At each level the off-diagonal
    block is Ai_21 = -Ai_22 @ A_21 @ Ai_11.
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Map the flat tile id to (sequence index, tile-within-sequence)
        # and redefine T as this sequence's length.
        i_n, i_t_val = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        i_t = i_t_val
    else:
        bos, eos = i_b * T, i_b * T + T

    # Base pointers (already offset by batch and head)
    A += (bos * H + i_h) * 64
    Ad += (bos * H + i_h) * 16
    Ai += (bos * H + i_h) * 64

    # load Ai_22 (Ad block at row i_t * 64 + 16, col 0, 16 * 16)
    offs_m = i_t * 64 + 16 + tl.arange(0, 16)
    offs_n = tl.arange(0, 16)
    mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
    ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
    Ai_22 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)

    # load A_21 (A block at row i_t * 64 + 16, col 0, 16 * 16)
    mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
    ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
    A_21 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
    tmp = tl.dot(Ai_22, A_21, input_precision="ieee")

    # load Ai_11 (Ad block at row i_t * 64, col 0, 16 * 16)
    offs_m = i_t * 64 + tl.arange(0, 16)
    offs_n = tl.arange(0, 16)
    mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
    ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
    Ai_11 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)

    # First-level merge for the upper 32x32 block.
    Ai_21 = -tl.dot(tmp, Ai_11, input_precision="ieee")

    # load Ai_44 (Ad block at row i_t * 64 + 48, col 0, 16 * 16)
    offs_m = i_t * 64 + 48 + tl.arange(0, 16)
    offs_n = tl.arange(0, 16)
    mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
    ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
    Ai_44 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)

    # load A_43 (A block at row i_t * 64 + 48, col 32, 16 * 16)
    offs_n = 32 + tl.arange(0, 16)
    mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
    ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
    A_43 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
    tmp = tl.dot(Ai_44, A_43, input_precision="ieee")

    # load Ai_33 (Ad block at row i_t * 64 + 32, col 0, 16 * 16)
    offs_m = i_t * 64 + 32 + tl.arange(0, 16)
    offs_n = tl.arange(0, 16)
    mask_Ad = (offs_m[:, None] < T) & (offs_n[None, :] < 16)
    ptr_Ad = Ad + offs_m[:, None] * (H * 16) + offs_n[None, :]
    Ai_33 = tl.load(ptr_Ad, mask=mask_Ad, other=0.0).to(tl.float32)

    # First-level merge for the lower 32x32 block.
    Ai_43 = -tl.dot(tmp, Ai_33, input_precision="ieee")

    # build Ai_22_32 (32 * 32): lower-right 32x32 inverse; starts as zeros,
    # so its upper-right 16x16 is already correct.
    Ai_22_32 = tl.zeros((32, 32), tl.float32)
    Ai_22_32 = insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
    Ai_22_32 = insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
    Ai_22_32 = insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))

    # load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32)
    offs_m = i_t * 64 + 32 + tl.arange(0, 32)
    offs_n = tl.arange(0, 32)
    mask_A = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
    ptr_A = A + offs_m[:, None] * (H * 64) + offs_n[None, :]
    A_21_32 = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
    tmp = tl.dot(Ai_22_32, A_21_32, input_precision="ieee")

    # build Ai_11_32 (32 * 32): upper-left 32x32 inverse.
    Ai_11_32 = tl.zeros((32, 32), tl.float32)
    Ai_11_32 = insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
    Ai_11_32 = insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
    Ai_11_32 = insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))

    # Second-level merge: off-diagonal 32x32 block of the 64x64 inverse.
    Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee")

    # store Ai_11_32 to (i_t * 64, 0)
    offs_m = i_t * 64 + tl.arange(0, 32)
    offs_n = tl.arange(0, 32)
    mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
    ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
    tl.store(ptr_Ai, Ai_11_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)

    # store Ai_22_32 to (i_t * 64 + 32, 32)
    offs_m = i_t * 64 + 32 + tl.arange(0, 32)
    offs_n = 32 + tl.arange(0, 32)
    mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
    ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
    tl.store(ptr_Ai, Ai_22_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)

    # store Ai_21_32 to (i_t * 64 + 32, 0); offs_m still points at the
    # lower 32 rows from the previous store.
    offs_n = tl.arange(0, 32)
    mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < 64)
    ptr_Ai = Ai + offs_m[:, None] * (H * 64) + offs_n[None, :]
    tl.store(ptr_Ai, Ai_21_32.to(ptr_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), mask=mask_store)

    # zero out the upper-right 32 * 32 block (rows 0 ~ 31, cols 32 ~ 63)
    offs_m = i_t * 64 + tl.arange(0, 32)
    offs_n = 32 + tl.arange(0, 32)
    mask_store = (offs_m[:, None] < T) & (offs_n[None, :] < BT)
    ptr_Ai = Ai + offs_m[:, None] * (H * BT) + offs_n[None, :]
    zero_block = tl.zeros((32, 32), dtype=ptr_Ai.dtype.element_ty)
    tl.store(ptr_Ai, zero_block, mask=mask_store)
|
|
|
|
|
|
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    chunk_indices_large_block: torch.Tensor | None = None,
    chunk_indices_bt: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
    """
    Compute the inverse of the matrix I + A.

    A should be strictly lower triangular, i.e., A.triu() == 0.

    Args:
        A (torch.Tensor):
            [B, T, H, BT], where BT should only be 16, 32, or 64.
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor. Default: `None`.
        chunk_indices_large_block (torch.Tensor):
            Precomputed chunk indices for the 16x16 solve pass; derived from
            `cu_seqlens` when omitted. Default: `None`.
        chunk_indices_bt (torch.Tensor):
            Precomputed chunk indices for the merge pass; derived from
            `cu_seqlens` when omitted. Default: `None`.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`.

    Returns:
        (I + A)^-1 with the same shape as A
    """
    assert A.shape[-1] in [16, 32, 64]

    batch, seqlen, heads, block_size = A.shape
    varlen = cu_seqlens is not None

    # The 16x16 pass keeps float intermediates unless it already produces
    # the final result (block_size == 16).
    diag_dtype = output_dtype if block_size == 16 else torch.float
    Ad = torch.empty(batch, seqlen, heads, 16, device=A.device, dtype=diag_dtype)

    # Rows covered per program of the 16x16 solve kernel.
    large_block = 608 * 2

    if varlen and chunk_indices_large_block is None:
        chunk_indices_large_block = prepare_chunk_indices(cu_seqlens, large_block)
    num_tiles = len(chunk_indices_large_block) if varlen else triton.cdiv(seqlen, large_block)

    # Pass 1: invert every 16x16 diagonal block of I + A.
    solve_tril_16x16_kernel[num_tiles, batch * heads](
        A=A,
        Ad=Ad,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices_large_block,
        T=seqlen,
        H=heads,
        BT=block_size,
        LARGE_BLOCK_T=large_block,
        num_warps=1,
        num_stages=4,
    )

    if block_size == 16:
        # Nothing to merge: the 16x16 inverses are the answer.
        return Ad

    if varlen and chunk_indices_bt is None:
        chunk_indices_bt = prepare_chunk_indices(cu_seqlens, block_size)
    num_tiles = len(chunk_indices_bt) if varlen else triton.cdiv(seqlen, block_size)

    # Pass 2: merge 16x16 inverses up to the full block size.
    Ai = torch.empty(batch, seqlen, heads, block_size, device=A.device, dtype=output_dtype)
    if block_size == 32:
        merge_fn = merge_16x16_to_32x32_inverse_kernel
    else:
        merge_fn = merge_16x16_to_64x64_inverse_kernel
    merge_fn[num_tiles, batch * heads](
        A=A,
        Ad=Ad,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices_bt,
        T=seqlen,
        H=heads,
        BT=block_size,
        num_warps=4,
        num_stages=3,
    )
    return Ai
|