xc-llm-ascend/vllm_ascend/ops/triton/fla/chunk_scaled_dot_kkt.py
cvSoldier 2db33868a4 [kernel] Recompilation optimization triggered by triton function parameter optimization (#7645)
### What this PR does / why we need it?
Some parameters of the Triton operators in this module are unnecessarily marked
with the `tl.constexpr` qualifier. Whenever the value of such a parameter
changes (for example, the sequence length `T` or batch size `B`), Triton
recompiles the kernel, which significantly hurts model performance. This PR
demotes those parameters to ordinary runtime arguments and disables value
specialization for them via `do_not_specialize`.

Main branch PR: https://github.com/vllm-project/vllm-ascend/pull/7483
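
To illustrate the pattern with a minimal sketch (toy kernels, not the actual code in this PR): a `tl.constexpr` argument is baked into the compiled binary, so each distinct value triggers a fresh compilation, whereas a plain argument listed in `do_not_specialize` is passed at runtime and reuses the cached kernel. Even without `tl.constexpr`, Triton specializes integer arguments on properties such as divisibility by 16, so `do_not_specialize` is needed to suppress those recompilations as well.

```python
import torch
from vllm.triton_utils import tl, triton


@triton.jit
def copy_specialized(x, y, T: tl.constexpr, BT: tl.constexpr):
    # T is constexpr: each distinct sequence length compiles a new kernel.
    o = tl.program_id(0) * BT + tl.arange(0, BT)
    tl.store(y + o, tl.load(x + o, mask=o < T), mask=o < T)


@triton.jit(do_not_specialize=["T"])
def copy_runtime(x, y, T, BT: tl.constexpr):
    # T is a plain runtime argument: one compiled kernel serves all lengths.
    o = tl.program_id(0) * BT + tl.arange(0, BT)
    tl.store(y + o, tl.load(x + o, mask=o < T), mask=o < T)
```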

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?

---------

Signed-off-by: cvSoldier <610496306@qq.com>
2026-03-26 16:31:34 +08:00


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
import torch
from vllm.triton_utils import tl, triton
from .utils import prepare_chunk_indices, safe_exp
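

# The heuristics below turn optional inputs into compile-time flags
# (IS_VARLEN, USE_G), while T and B stay plain runtime arguments with
# specialization disabled, so changing the sequence length or batch size
# does not trigger a kernel recompilation (see #7645).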
@triton.heuristics(
    {
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "USE_G": lambda args: args["g_cumsum"] is not None,
    }
)
@triton.jit(do_not_specialize=["T", "B"])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,  # [H, B, T]
    g_cumsum,  # [H, B, T]
    A,
    cu_seqlens,
    chunk_indices,
    T,
    B,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr,
):
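    # Launched on a (NT, 1) grid: program_id(0) selects the time chunk;
    # (batch, head) pairs are iterated serially inside the kernel.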
    bt_stride = B * T
    i_t_i, _ = tl.program_id(0), tl.program_id(1)
    for i_bh in range(B * H):
        i_b, i_h = i_bh // H, i_bh % H
        if IS_VARLEN:
            i_n, i_t = (
                tl.load(chunk_indices + i_t_i * 2).to(tl.int32),
                tl.load(chunk_indices + i_t_i * 2 + 1).to(tl.int32),
            )
            bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
            T = eos - bos
        else:
            bos, eos = i_b * T, i_b * T + T
            i_t = i_t_i
        o_t = tl.arange(0, BT)
        o_t_fp32 = o_t.to(tl.float32)
        p_beta = tl.make_block_ptr(beta + i_h * bt_stride + bos, (T,), (1,), (i_t * BT,), (BT,), (0,))
        b_beta = tl.load(p_beta, boundary_check=(0,))
        b_A = tl.zeros([BT, BT], dtype=tl.float32)
        for i_k in range(tl.cdiv(K, BK)):
            p_k = tl.make_block_ptr(
                k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
            )
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_A += tl.dot(b_k, tl.trans(b_k))
        if USE_G:
            p_g = tl.make_block_ptr(g_cumsum + i_h * bt_stride + bos, (T,), (1,), (i_t * BT,), (BT,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))
            b_g_diff = b_g[:, None] - b_g[None, :]
            b_A *= safe_exp(b_g_diff)
        b_A *= b_beta[:, None]
        # Keep only the strictly lower-triangular part within the chunk.
        b_A = tl.where(o_t_fp32[:, None] > o_t_fp32[None, :], b_A, 0)
        p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
        tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))


def chunk_scaled_dot_kkt_fwd(
    k: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: torch.Tensor | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_indices: torch.Tensor | None = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
r"""
Compute beta * K * K^T.
Args:
k (torch.Tensor):
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
g (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
gk (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H, K]` applied to the key tensor. Default: `None`.
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float32`
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
    B, T, Hg, K = k.shape
    H = beta.shape[-1]
    BT = chunk_size
    if cu_seqlens is not None and chunk_indices is None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
    chunk_scaled_dot_kkt_fwd_kernel[(NT, 1)](
        k=k,
        # Permute to the [H, B, T] layout expected by the kernel; guard the
        # optional gate input (permuting the default `None` would raise).
        beta=torch.permute(beta, (2, 0, 1)).contiguous(),
        g_cumsum=torch.permute(g_cumsum, (2, 0, 1)).contiguous() if g_cumsum is not None else None,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        B=B,
        H=H,
        Hg=Hg,
        K=K,
        BT=BT,
        BK=128,
        num_warps=8,
        num_stages=3,
        multibuffer=True,
    )
    return A
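

# A minimal usage sketch (illustrative only, not part of the upstream file):
# assumes a device where these Triton kernels can run (e.g. an Ascend NPU via
# torch_npu); shapes and dtypes are hypothetical.
if __name__ == "__main__":
    B, T, H, K = 1, 256, 4, 128
    k = torch.randn(B, T, H, K, device="npu", dtype=torch.float16)
    beta = torch.rand(B, T, H, device="npu", dtype=torch.float16)
    # g_cumsum defaults to None, exercising the USE_G=False path.
    A = chunk_scaled_dot_kkt_fwd(k, beta, chunk_size=64)
    assert A.shape == (B, T, H, 64)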