xc-llm-ascend/vllm_ascend/ops/triton/fla/fused_qkvzba_split_reshape.py
meihanc 16b1bee804 [bugfix] fix test_camem failure with triton-ascend (#5492)
### What this PR does / why we need it?
This fixes a bug that occurred when running `test_camem.py` in the
triton-ascend environment: `NPU function error:
aclrtGetMemInfo(ACL_HBM_MEM, &device_free, &device_total)`

- vLLM version: v0.13.0
- vLLM main: 5326c89803

---------

Signed-off-by: Meihan-chen <jcccx.cmh@gmail.com>
2026-01-05 20:10:54 +08:00


# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
# mypy: ignore-errors
import torch
from vllm.triton_utils import tl, triton


@triton.jit
def fused_qkvzba_split_reshape_cat_kernel(
    mixed_qkv,
    z,
    b,
    a,
    mixed_qkvz,
    mixed_ba,
    NUM_HEADS_QK: tl.constexpr,
    NUM_HEADS_V: tl.constexpr,
    HEAD_QK: tl.constexpr,
    HEAD_V: tl.constexpr,
):
    # One program instance handles one (token, QK-head) pair.
    i_bs, i_qk = tl.program_id(0), tl.program_id(1)
    # Per-QK-head widths of the fused rows:
    #   mixed_qkvz row = [q | k | v | z], mixed_ba row = [b | a],
    #   with NUM_HEADS_V // NUM_HEADS_QK v/z (and b/a) heads per QK head.
    QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2
    BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2
    QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
    # Source pointers into this head's slice of the fused qkvz row.
    q_end: tl.constexpr = HEAD_QK
    blk_q_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
                 i_qk * QKVZ_DIM_T + tl.arange(0, q_end))
    k_end: tl.constexpr = q_end + HEAD_QK
    blk_k_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
                 i_qk * QKVZ_DIM_T + tl.arange(q_end, k_end))
    v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
    blk_v_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
                 i_qk * QKVZ_DIM_T + tl.arange(k_end, v_end))
    z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V
    blk_z_ptr = (mixed_qkvz + i_bs * NUM_HEADS_QK * QKVZ_DIM_T +
                 i_qk * QKVZ_DIM_T + tl.arange(v_end, z_end))
    # Destination pointers: mixed_qkv is laid out as [all q | all k | all v],
    # each segment head-major; z gets its own [*, NUM_HEADS_V, HEAD_V] tensor.
    blk_q_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
                    i_qk * HEAD_QK + tl.arange(0, HEAD_QK))
    blk_k_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
                    NUM_HEADS_QK * HEAD_QK + i_qk * HEAD_QK +
                    tl.arange(0, HEAD_QK))
    blk_v_st_ptr = (mixed_qkv + i_bs * NUM_HEADS_QK * QKV_DIM_T +
                    NUM_HEADS_QK * HEAD_QK * 2 +
                    i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK +
                    tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK))
    blk_z_st_ptr = (z + i_bs * NUM_HEADS_V * HEAD_V +
                    i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK +
                    tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK))
    tl.store(blk_q_st_ptr, tl.load(blk_q_ptr))
    tl.store(blk_k_st_ptr, tl.load(blk_k_ptr))
    tl.store(blk_v_st_ptr, tl.load(blk_v_ptr))
    tl.store(blk_z_st_ptr, tl.load(blk_z_ptr))
    # b and a sit interleaved per QK head in mixed_ba; scatter them out
    # element by element into the flat [*, NUM_HEADS_V] outputs.
    b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK
    a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK
    for i in tl.static_range(b_end):
        blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
        blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i
        tl.store(blk_b_st_ptr, tl.load(blk_b_ptr))
    for i in tl.static_range(b_end, a_end):
        blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i
        blk_a_st_ptr = (a + i_bs * NUM_HEADS_V +
                        i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end))
        tl.store(blk_a_st_ptr, tl.load(blk_a_ptr))
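

def _reference_split(mixed_qkvz, mixed_ba, num_heads_qk, num_heads_v, head_qk,
                     head_v):
    # Pure-PyTorch sketch of the same split, included here as an illustrative
    # correctness oracle (an assumption of this write-up, not part of the
    # upstream kernel). It mirrors the layouts documented in the kernel above.
    g = num_heads_v // num_heads_qk  # v/z (and b/a) heads per QK head
    qkvz = mixed_qkvz.view(-1, num_heads_qk, 2 * head_qk + 2 * g * head_v)
    q, k, v, z = torch.split(qkvz, [head_qk, head_qk, g * head_v, g * head_v],
                             dim=2)
    ba = mixed_ba.view(-1, num_heads_qk, 2 * g)
    b, a = torch.split(ba, [g, g], dim=2)
    # Re-concatenate head-major: [all q | all k | all v] per token.
    mixed_qkv = torch.cat([t.reshape(t.shape[0], -1) for t in (q, k, v)],
                          dim=1)
    return (mixed_qkv, z.reshape(-1, num_heads_v, head_v),
            b.reshape(-1, num_heads_v), a.reshape(-1, num_heads_v))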


def fused_qkvzba_split_reshape_cat(
    mixed_qkvz,
    mixed_ba,
    num_heads_qk,
    num_heads_v,
    head_qk,
    head_v,
):
    """Split fused qkvz / ba projections into mixed_qkv, z, b and a."""
    # seq_len is folded into the batch dim: one row per token.
    batch, seq_len = mixed_qkvz.shape[0], 1
    qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v
    mixed_qkv = torch.empty(
        [batch * seq_len, qkv_dim_t],
        dtype=mixed_qkvz.dtype,
        device=mixed_qkvz.device,
    )
    z = torch.empty(
        [batch * seq_len, num_heads_v, head_v],
        dtype=mixed_qkvz.dtype,
        device=mixed_qkvz.device,
    )
    b = torch.empty(
        [batch * seq_len, num_heads_v],
        dtype=mixed_ba.dtype,
        device=mixed_ba.device,
    )
    a = torch.empty_like(b)
    # One kernel program per (token, QK head).
    grid = (batch * seq_len, num_heads_qk)
    fused_qkvzba_split_reshape_cat_kernel[grid](
        mixed_qkv,
        z,
        b,
        a,
        mixed_qkvz,
        mixed_ba,
        num_heads_qk,
        num_heads_v,
        head_qk,
        head_v,
        num_warps=1,
        num_stages=3,
    )
    return mixed_qkv, z, b, a
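

if __name__ == "__main__":
    # Minimal smoke test with hypothetical shapes (a Qwen3-Next-style layout
    # of 16 QK heads, 32 V heads and head dim 128 is assumed); requires a
    # device that triton / triton-ascend can target.
    tokens, n_qk, n_v, d_qk, d_v = 8, 16, 32, 128, 128
    dev = "npu" if hasattr(torch, "npu") and torch.npu.is_available() else "cuda"
    qkvz = torch.randn(tokens, n_qk * d_qk * 2 + n_v * d_v * 2,
                       dtype=torch.float16, device=dev)
    ba = torch.randn(tokens, n_v * 2, dtype=torch.float16, device=dev)
    qkv, z, b, a = fused_qkvzba_split_reshape_cat(qkvz, ba, n_qk, n_v, d_qk, d_v)
    # Compare the Triton path against the pure-PyTorch reference above.
    for got, want in zip((qkv, z, b, a),
                         _reference_split(qkvz, ba, n_qk, n_v, d_qk, d_v)):
        torch.testing.assert_close(got, want)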