Files
xc-llm-ascend/vllm_ascend/ops/triton/fla/wy_fast.py
Mr.WXS dba34d4915 [v0.18.0][Triton][Qwen3.5] delete expr for kernels args (#7646)
### What this PR does / why we need it?
Some parameters of the Triton operators are unnecessarily marked with the
`constexpr` qualifier. Whenever the value of such a parameter changes, the
kernel is recompiled, which significantly hurts model performance. This change
removes the unnecessary `constexpr` markers from those parameters.
backport: https://github.com/vllm-project/vllm-ascend/pull/7482


Signed-off-by: w30012745 <wangxiaoshuai2@h-partners.com>
Co-authored-by: w30012745 <wangxiaoshuai2@h-partners.com>
2026-03-25 23:31:27 +08:00

144 lines
4.2 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
import torch
from vllm.triton_utils import tl, triton
from .utils import prepare_chunk_indices
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T", "H", "Hg", "K", "V"])
def recompute_w_u_fwd_kernel(
    k,              # keys; indexed as (bos*Hg + key_head)*K + t*(Hg*K) + k_off
    v,              # values; indexed as (bos*H + head)*V + t*(H*V) + v_off
    beta,           # per-token gate; head-major layout with stride T (max) per head
    w,              # output: A @ (k * beta * exp(g)); H heads, K head dim
    u,              # output: A @ (v * beta); same layout as v
    A,              # BT x BT block matrix per (chunk, head); head-major layout
    g,              # log-gates (g_cumsum); same head-major layout as beta
    cu_seqlens,     # varlen cumulative sequence lengths, or None (dense mode)
    chunk_indices,  # varlen (seq_idx, chunk_idx) pair per launched program
    T,              # max sequence length (runtime arg, deliberately not specialized)
    H,              # number of value heads
    Hg,             # number of key head groups (H // Hg value heads share a key head)
    K,              # key head dimension
    V,              # value head dimension
    BT: tl.constexpr,         # chunk (time-block) size
    BK: tl.constexpr,         # tile size over K
    BV: tl.constexpr,         # tile size over V
    IS_VARLEN: tl.constexpr,  # True iff cu_seqlens was provided (set by heuristic)
):
    """Recompute w and u for one time-chunk, looping over all H heads.

    Per head: u = A @ (v * beta) and w = A @ (k * beta * exp(g)),
    computed tile-by-tile over the V and K dimensions respectively.
    """
    # Keep the host-side max T: beta/g use a per-head stride of T_max, while T
    # itself is rebound to the per-sequence length in the varlen branch below.
    T_max = T
    i_t_o = tl.program_id(0)  # which chunk this program handles
    # NOTE(review): the grid is launched as (NT, B) but tl.program_id(1) is
    # never read; each program serially covers all H heads instead.
    for i_bh in range(H):
        # NOTE(review): i_bh // H is always 0 for i_bh in range(H), so i_b == 0,
        # and i_t is never assigned on the non-varlen path below. The dense
        # (cu_seqlens is None) path therefore appears broken/unused — verify
        # that callers always pass cu_seqlens.
        i_b, i_h = i_bh // H, i_bh % H
        if IS_VARLEN:
            # chunk_indices stores (sequence index, chunk-within-sequence) pairs.
            i_n, i_t = (
                tl.load(chunk_indices + i_t_o * 2).to(tl.int32),
                tl.load(chunk_indices + i_t_o * 2 + 1).to(tl.int32),
            )
            bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
            T = eos - bos  # effective length of this sequence
        else:
            bos, eos = i_b * T, i_b * T + T
        offs_t = tl.arange(0, BT)
        global_offs_t = i_t * BT + offs_t  # token positions within the sequence
        mask_t = global_offs_t < T         # masks the tail of a partial chunk
        offs_t_2d = global_offs_t[:, None]
        offs_bt = tl.arange(0, BT)[None, :]
        # Load the BT x BT block matrix A for this (chunk, head) in fp32.
        ptr_A = A + (bos * H + i_h) * BT + offs_t_2d * (H * BT) + offs_bt * 1
        mask_A = mask_t[:, None]
        b_A = tl.load(ptr_A, mask=mask_A, other=0.0).to(tl.float32)
        # g and beta are stored head-major with stride T_max per head
        # (the host wrapper transposes them to that layout before launch).
        ptr_g = g + bos + i_h * T_max + global_offs_t
        b_g = tl.exp(tl.load(ptr_g, mask=mask_t, other=0.0)).to(tl.float32)
        ptr_beta = beta + bos + i_h * T_max + global_offs_t
        b_beta = tl.load(ptr_beta, mask=mask_t, other=0.0).to(tl.float32)
        # u = A @ (v * beta), tiled over the V dimension.
        for i_v in range(tl.cdiv(V, BV)):
            offs_v = i_v * BV + tl.arange(0, BV)[None, :]
            mask_v = (mask_t[:, None]) & (offs_v < V)
            ptr_v = v + (bos * H + i_h) * V + offs_t_2d * (H * V) + offs_v * 1
            b_v = tl.load(ptr_v, mask=mask_v, other=0.0).to(tl.float32)
            b_vb = b_v * b_beta[:, None]
            b_u = tl.dot(b_A, b_vb, allow_tf32=False)
            ptr_u = u + (bos * H + i_h) * V + offs_t_2d * (H * V) + offs_v * 1
            tl.store(ptr_u, b_u.to(ptr_u.dtype.element_ty), mask=mask_v)
        # w = A @ (k * beta * exp(g)), tiled over the K dimension.
        # Keys are grouped (GQA): value head i_h maps to key head i_h // (H // Hg).
        for i_k in range(tl.cdiv(K, BK)):
            offs_k = i_k * BK + tl.arange(0, BK)[None, :]
            mask_k = (mask_t[:, None]) & (offs_k < K)
            ptr_k = k + (bos * Hg + i_h // (H // Hg)) * K + offs_t_2d * (Hg * K) + offs_k * 1
            b_k = tl.load(ptr_k, mask=mask_k, other=0.0).to(tl.float32)
            b_kb = b_k * b_beta[:, None] * b_g[:, None]
            b_w = tl.dot(b_A, b_kb)
            ptr_w = w + (bos * H + i_h) * K + offs_t_2d * (H * K) + offs_k * 1
            tl.store(ptr_w, b_w.to(ptr_w.dtype.element_ty), mask=mask_k)
def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Recompute w and u by launching ``recompute_w_u_fwd_kernel``.

    Per head the kernel computes u = A @ (v * beta) and
    w = A @ (k * beta * exp(g_cumsum)).

    Args:
        k: key tensor of shape [B, T, Hg, K].
        v: value tensor of shape [B, T, H, V].
        beta: per-token gate; presumably [B, T, H] — it is transposed to
            head-major [B, H, T] before launch.
        g_cumsum: cumulative log-gates, same shape handling as ``beta``.
        A: block matrix whose trailing dimension defines the chunk size BT.
        cu_seqlens: optional cumulative sequence lengths for varlen batches.
        chunk_indices: optional precomputed (seq, chunk) pairs; derived from
            ``cu_seqlens`` when omitted.

    Returns:
        Tuple ``(w, u)`` where ``w`` has shape [B, T, H, K] and ``u`` matches
        ``v``'s shape and dtype.
    """
    B, T, Hg, K = k.shape
    V = v.shape[-1]
    H = v.shape[-2]
    # The chunk size is dictated by A's trailing dimension.
    BT = A.shape[-1]
    if cu_seqlens is not None and chunk_indices is None:
        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
    # One program per chunk: varlen batches enumerate precomputed chunks,
    # dense batches simply tile T.
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
    BK = 64
    BV = 64
    u = torch.empty_like(v)
    w = k.new_empty(B, T, H, K)
    # The kernel indexes beta/g head-major with stride T per head, so move the
    # head axis forward and make the result contiguous.
    beta = beta.transpose(1, 2).contiguous()
    g_cumsum = g_cumsum.transpose(1, 2).contiguous()
    recompute_w_u_fwd_kernel[(NT, B)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g_cumsum,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
        num_warps=4,
        num_stages=3,
    )
    return w, u