[Triton] Centralize Ascend extension op dispatch in triton_utils (#6937)

### What this PR does / why we need it?

This pull request refactors the dispatch mechanism for the
**triton-ascend-specific operators** `insert_slice`, `extract_slice`,
and `get_element` to ensure compatibility with both CANN 8.5 and 9.0.

A unified helper function, `_resolve_triton_ascend_op`, has been
introduced in `vllm_ascend/ops/triton/triton_utils.py`. This function
dynamically resolves these operators by first attempting to import them
from the `triton.language.extra.cann.extension` module, which is present
in newer CANN versions. If that fails, it falls back to the standard
`triton.language` module.

This approach centralizes operator dispatch logic, allowing individual
Triton kernels to use these functions without being aware of the
underlying Triton/CANN version. All call sites have been updated to use
these new unified functions.

### Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring of operator implementations and does
not introduce any user-facing changes.

### How was this patch tested?

CI is expected to pass with existing tests.

**Testing Context:**
- vLLM version: v0.16.0
- vLLM main: `15d76f74e2fdb12a95ea00f0ca283acf6219a2b7`

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2026-03-03 17:10:30 +08:00
committed by GitHub
parent cb893bcdb0
commit 700423156f
5 changed files with 78 additions and 40 deletions

View File

@@ -1,7 +1,7 @@
import torch
from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num
@triton.jit
@@ -40,8 +40,8 @@ def _swiglu_quant_kernel(
# swiglu
x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS)
cur_x = tl.load(x_ptr + x_offsets)
x1 = tl.extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
x2 = tl.extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,))
x1 = extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
x2 = extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,))
out = x1 * tl.sigmoid(x1) * x2
# quant
@@ -50,7 +50,7 @@ def _swiglu_quant_kernel(
# store scale
tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty))
for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE):
tmp_out = tl.extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,))
tmp_out = extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,))
tmp_out = (tmp_out.to(tl.float32) / scale).to(x_ptr.dtype.element_ty)
tmp_out = tmp_out.cast(tl.int8, overflow_mode="saturate")

View File

@@ -12,6 +12,8 @@
import torch
from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import extract_slice, insert_slice
from .utils import prepare_chunk_indices
@@ -80,7 +82,7 @@ def solve_tril_16x16_kernel(
# 4 Use mask to safely load data
b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(tl.float32)
b_A = tl.insert_slice(
b_A = insert_slice(
ful=b_A,
sub=b_A_subrec16[None, :, :], # (1, 16, 16)
offsets=[blkid, 0, 0],
@@ -100,7 +102,7 @@ def solve_tril_16x16_kernel(
# for loop to update N_BLOCKS row vector
for i in range(1, 16):
nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
nblks_vec16 = -extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))
dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
@@ -108,7 +110,7 @@ def solve_tril_16x16_kernel(
b_a = b_a + dot_product
b_a_new_expanded = b_a[:, None, :]
b_A = tl.insert_slice(
b_A = insert_slice(
ful=b_A, sub=b_a_new_expanded, offsets=[0, i, 0], sizes=[N_BLOCKS, 1, 16], strides=[1, 1, 1]
)
@@ -276,9 +278,9 @@ def merge_16x16_to_64x64_inverse_kernel(
# build Ai_22_32 (32 * 32)
Ai_22_32 = tl.zeros((32, 32), tl.float32)
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))
Ai_22_32 = insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
Ai_22_32 = insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
Ai_22_32 = insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))
# load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32)
offs_m = i_t * 64 + 32 + tl.arange(0, 32)
@@ -290,9 +292,9 @@ def merge_16x16_to_64x64_inverse_kernel(
# build Ai_11_32 (32 * 32)
Ai_11_32 = tl.zeros((32, 32), tl.float32)
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))
Ai_11_32 = insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
Ai_11_32 = insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
Ai_11_32 = insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))
Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee")

View File

@@ -20,7 +20,7 @@ import triton # type: ignore
import triton.language as tl # type: ignore
from vllm.utils.torch_utils import direct_register_custom_op
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num, insert_slice
@triton.jit
@@ -79,13 +79,13 @@ def split_qkv_rmsnorm_rope_kernel(
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
x1 = tl.extract_slice(
x1 = extract_slice(
normalized_values,
offsets=(0, 0),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
x2 = tl.extract_slice(
x2 = extract_slice(
normalized_values,
offsets=(0, HALF_HEAD_DIM),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
@@ -95,14 +95,14 @@ def split_qkv_rmsnorm_rope_kernel(
roped_q2 = x2 * cos + x1 * sin
roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
roped_q = tl.insert_slice(
roped_q = insert_slice(
roped_q,
roped_q1,
offsets=(0, 0),
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
roped_q = tl.insert_slice(
roped_q = insert_slice(
roped_q,
roped_q2,
offsets=(0, HALF_HEAD_DIM),
@@ -145,13 +145,13 @@ def split_qkv_rmsnorm_rope_kernel(
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
x1 = tl.extract_slice(
x1 = extract_slice(
normalized_values,
offsets=(0, 0),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
x2 = tl.extract_slice(
x2 = extract_slice(
normalized_values,
offsets=(0, HALF_HEAD_DIM),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
@@ -161,14 +161,14 @@ def split_qkv_rmsnorm_rope_kernel(
roped_k2 = x2 * cos + x1 * sin
roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
roped_k = tl.insert_slice(
roped_k = insert_slice(
roped_k,
roped_k1,
offsets=(0, 0),
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
strides=(1, 1),
)
roped_k = tl.insert_slice(
roped_k = insert_slice(
roped_k,
roped_k2,
offsets=(0, HALF_HEAD_DIM),

View File

@@ -17,7 +17,7 @@
from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
from vllm_ascend.ops.triton.triton_utils import get_element, get_vectorcore_num
def cal_grid_and_block_size(batch_size: int):
@@ -59,8 +59,8 @@ def rejection_greedy_sample_spec_len_1_triton(
tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
for pos in tl.range(0, BLOCK_SIZE):
draft_token_id1 = tl.get_element(draft_token_id, (pos,))
target_argmax1 = tl.get_element(target_argmax_id, (pos,))
draft_token_id1 = get_element(draft_token_id, (pos,))
target_argmax1 = get_element(target_argmax_id, (pos,))
position = block_idx * BLOCK_SIZE + pos
if draft_token_id1 == target_argmax1:
bonus_renew_1(
@@ -109,10 +109,10 @@ def rejection_greedy_sample_triton(
num_draft_tokens = end_idx - start_idx
for pos in tl.range(0, BLOCK_SIZE):
num_tokens1 = tl.get_element(num_draft_tokens, (pos,))
num_tokens1 = get_element(num_draft_tokens, (pos,))
rejected = False
start_idx1 = tl.get_element(start_idx, (pos,))
is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos,))
start_idx1 = get_element(start_idx, (pos,))
is_greedy_mask1 = get_element(is_greedy_mask, (pos,))
position = block_idx * BLOCK_SIZE + pos
for i in range(num_tokens1):
if not rejected:
@@ -162,12 +162,12 @@ def rejection_random_sample_kernel(
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
n_num_draft_tokens = end_idxs - start_idxs
for req_i in range(BLOCK_SIZE):
not_greedy = tl.get_element(not_greedy_mask, (req_i,))
not_greedy = get_element(not_greedy_mask, (req_i,))
if not_greedy:
rejected = False
start_idx = tl.get_element(start_idxs, (req_i,))
start_idx = get_element(start_idxs, (req_i,))
req_idx = block_idx * BLOCK_SIZE + req_i
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i,))
num_draft_tokens = get_element(n_num_draft_tokens, (req_i,))
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
@@ -219,9 +219,9 @@ def expand_kernel(
src_val = tl.where(src_val == replace_from, replace_to, src_val)
for i in tl.range(0, BLOCK_SIZE):
num_tokens1 = tl.get_element(num_tokens, (i,))
start_idx1 = tl.get_element(start_idx, (i,))
src_val1 = tl.get_element(src_val, (i,))
num_tokens1 = get_element(num_tokens, (i,))
start_idx1 = get_element(start_idx, (i,))
src_val1 = get_element(src_val, (i,))
offset1 = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx1 + offset1, src_val1, mask=offset1 < num_tokens1)
@@ -272,7 +272,7 @@ def sample_recovered_tokens_kernel(
)
new_p = prob / q
recovered_id = tl.argmax(new_p, axis=-1)
max_p = tl.get_element(new_p, (recovered_id,))
max_p = get_element(new_p, (recovered_id,))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
@@ -297,7 +297,7 @@ def sample_recovered_tokens_kernel(
)
new_p = prob / q
recovered_id = tl.argmax(new_p, axis=-1)
max_p = tl.get_element(new_p, (recovered_id,))
max_p = get_element(new_p, (recovered_id,))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
@@ -388,15 +388,15 @@ def rejection_random_sample_block_verify_kernel(
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
n_num_draft_tokens = end_idxs - start_idxs
for req_i in range(BLOCK_SIZE):
not_greedy = tl.get_element(not_greedy_mask, (req_i,))
not_greedy = get_element(not_greedy_mask, (req_i,))
if not_greedy:
rejected = False
pi = 1.0
uniform_prob = 1.0
last_accepted_token_pos = -1
start_idx = tl.get_element(start_idxs, (req_i,))
start_idx = get_element(start_idxs, (req_i,))
req_idx = block_idx * BLOCK_SIZE + req_i
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i,))
num_draft_tokens = get_element(n_num_draft_tokens, (req_i,))
for pos in range(num_draft_tokens):
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)

View File

@@ -1,10 +1,46 @@
from typing import Any
import torch
from vllm.triton_utils import HAS_TRITON, triton
from vllm.triton_utils import HAS_TRITON, tl, triton
_NUM_AICORE = -1
_NUM_VECTORCORE = -1
_extension_module = None
if HAS_TRITON:
try:
import triton.language.extra.cann.extension as _extension_module # type: ignore
except ImportError:
_extension_module = None
def _resolve_triton_ascend_op(op_name: str):
if not HAS_TRITON:
raise RuntimeError(f"Triton op '{op_name}' cannot be resolved because HAS_TRITON is False")
if _extension_module is not None:
extension_op = getattr(_extension_module, op_name, None)
if extension_op is not None:
return extension_op
tl_op = getattr(tl, op_name, None)
if tl_op is not None:
return tl_op
raise RuntimeError(
f"Failed to resolve Triton op '{op_name}': "
"neither triton.language.extra.cann.extension nor triton.language provides it."
)
if HAS_TRITON:
insert_slice = _resolve_triton_ascend_op("insert_slice")
extract_slice = _resolve_triton_ascend_op("extract_slice")
get_element = _resolve_triton_ascend_op("get_element")
else:
insert_slice = None
extract_slice = None
get_element = None
def init_device_properties_triton():