[Triton] Centralize Ascend extension op dispatch in triton_utils (#6937)

### What this PR does / why we need it?

This pull request refactors the dispatch mechanism for the
**triton-ascend-specific operators** `insert_slice`, `extract_slice`,
and `get_element` to ensure compatibility with both CANN 8.5 and 9.0.

A unified helper function, `_resolve_triton_ascend_op`, has been
introduced in `vllm_ascend/ops/triton/triton_utils.py`. This function
dynamically resolves these operators by first attempting to import them
from the `triton.language.extra.cann.extension` module, which is present
in newer CANN versions. If that fails, it falls back to the standard
`triton.language` module.

This approach centralizes the operator dispatch logic so that individual
Triton kernels can call these functions without depending on the
underlying Triton/CANN version. All call sites have been updated to use
the new unified functions.

### Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring of operator implementations and does
not introduce any user-facing changes.

### How was this patch tested?

Verified via existing CI: since this is a behavior-preserving refactor of
internal operator dispatch, the existing Triton kernel tests exercise all
updated call sites, and no new tests are required.

**Testing Context:**
- vLLM version: v0.16.0
- vLLM main: `15d76f74e2fdb12a95ea00f0ca283acf6219a2b7`

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2026-03-03 17:10:30 +08:00
committed by GitHub
parent cb893bcdb0
commit 700423156f
5 changed files with 78 additions and 40 deletions

View File

@@ -12,6 +12,8 @@
import torch
from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import extract_slice, insert_slice
from .utils import prepare_chunk_indices
@@ -80,7 +82,7 @@ def solve_tril_16x16_kernel(
# 4 Use mask to safely load data
b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(tl.float32)
b_A = tl.insert_slice(
b_A = insert_slice(
ful=b_A,
sub=b_A_subrec16[None, :, :], # (1, 16, 16)
offsets=[blkid, 0, 0],
@@ -100,7 +102,7 @@ def solve_tril_16x16_kernel(
# for loop to update N_BLOCKS row vector
for i in range(1, 16):
nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
nblks_vec16 = -extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))
dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
@@ -108,7 +110,7 @@ def solve_tril_16x16_kernel(
b_a = b_a + dot_product
b_a_new_expanded = b_a[:, None, :]
b_A = tl.insert_slice(
b_A = insert_slice(
ful=b_A, sub=b_a_new_expanded, offsets=[0, i, 0], sizes=[N_BLOCKS, 1, 16], strides=[1, 1, 1]
)
@@ -276,9 +278,9 @@ def merge_16x16_to_64x64_inverse_kernel(
# build Ai_22_32 (32 * 32)
Ai_22_32 = tl.zeros((32, 32), tl.float32)
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))
Ai_22_32 = insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
Ai_22_32 = insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
Ai_22_32 = insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))
# load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32)
offs_m = i_t * 64 + 32 + tl.arange(0, 32)
@@ -290,9 +292,9 @@ def merge_16x16_to_64x64_inverse_kernel(
# build Ai_11_32 (32 * 32)
Ai_11_32 = tl.zeros((32, 32), tl.float32)
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))
Ai_11_32 = insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
Ai_11_32 = insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
Ai_11_32 = insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))
Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee")