[Triton] Centralize Ascend extension op dispatch in triton_utils (#6937)

### What this PR does / why we need it?

This pull request refactors the dispatch mechanism for the
**triton-ascend-specific operators** `insert_slice`, `extract_slice`,
and `get_element` to ensure compatibility with both CANN 8.5 and 9.0.

A unified helper function, `_resolve_triton_ascend_op`, has been
introduced in `vllm_ascend/ops/triton/triton_utils.py`. This function
dynamically resolves these operators by first attempting to import them
from the `triton.language.extra.cann.extension` module, which is present
in newer CANN versions. If that fails, it falls back to the standard
`triton.language` module.

This approach centralizes operator dispatch logic, allowing individual
Triton kernels to use these functions without being aware of the
underlying Triton/CANN version. All call sites have been updated to use
these new unified functions.

### Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring of operator implementations and does
not introduce any user-facing changes.

### How was this patch tested?

CI is expected to pass with existing tests.

**Testing Context:**
- vLLM version: v0.16.0
- vLLM main: `15d76f74e2fdb12a95ea00f0ca283acf6219a2b7`

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2026-03-03 17:10:30 +08:00
committed by GitHub
parent cb893bcdb0
commit 700423156f
5 changed files with 78 additions and 40 deletions

View File

@@ -17,7 +17,7 @@
from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
from vllm_ascend.ops.triton.triton_utils import get_element, get_vectorcore_num
def cal_grid_and_block_size(batch_size: int):
@@ -59,8 +59,8 @@ def rejection_greedy_sample_spec_len_1_triton(
tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
for pos in tl.range(0, BLOCK_SIZE):
draft_token_id1 = tl.get_element(draft_token_id, (pos,))
target_argmax1 = tl.get_element(target_argmax_id, (pos,))
draft_token_id1 = get_element(draft_token_id, (pos,))
target_argmax1 = get_element(target_argmax_id, (pos,))
position = block_idx * BLOCK_SIZE + pos
if draft_token_id1 == target_argmax1:
bonus_renew_1(
@@ -109,10 +109,10 @@ def rejection_greedy_sample_triton(
num_draft_tokens = end_idx - start_idx
for pos in tl.range(0, BLOCK_SIZE):
num_tokens1 = tl.get_element(num_draft_tokens, (pos,))
num_tokens1 = get_element(num_draft_tokens, (pos,))
rejected = False
start_idx1 = tl.get_element(start_idx, (pos,))
is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos,))
start_idx1 = get_element(start_idx, (pos,))
is_greedy_mask1 = get_element(is_greedy_mask, (pos,))
position = block_idx * BLOCK_SIZE + pos
for i in range(num_tokens1):
if not rejected:
@@ -162,12 +162,12 @@ def rejection_random_sample_kernel(
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
n_num_draft_tokens = end_idxs - start_idxs
for req_i in range(BLOCK_SIZE):
not_greedy = tl.get_element(not_greedy_mask, (req_i,))
not_greedy = get_element(not_greedy_mask, (req_i,))
if not_greedy:
rejected = False
start_idx = tl.get_element(start_idxs, (req_i,))
start_idx = get_element(start_idxs, (req_i,))
req_idx = block_idx * BLOCK_SIZE + req_i
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i,))
num_draft_tokens = get_element(n_num_draft_tokens, (req_i,))
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
@@ -219,9 +219,9 @@ def expand_kernel(
src_val = tl.where(src_val == replace_from, replace_to, src_val)
for i in tl.range(0, BLOCK_SIZE):
num_tokens1 = tl.get_element(num_tokens, (i,))
start_idx1 = tl.get_element(start_idx, (i,))
src_val1 = tl.get_element(src_val, (i,))
num_tokens1 = get_element(num_tokens, (i,))
start_idx1 = get_element(start_idx, (i,))
src_val1 = get_element(src_val, (i,))
offset1 = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx1 + offset1, src_val1, mask=offset1 < num_tokens1)
@@ -272,7 +272,7 @@ def sample_recovered_tokens_kernel(
)
new_p = prob / q
recovered_id = tl.argmax(new_p, axis=-1)
max_p = tl.get_element(new_p, (recovered_id,))
max_p = get_element(new_p, (recovered_id,))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
@@ -297,7 +297,7 @@ def sample_recovered_tokens_kernel(
)
new_p = prob / q
recovered_id = tl.argmax(new_p, axis=-1)
max_p = tl.get_element(new_p, (recovered_id,))
max_p = get_element(new_p, (recovered_id,))
if max_p > global_max_p:
global_max_p = max_p
global_recovered_id = vocab_start + recovered_id
@@ -388,15 +388,15 @@ def rejection_random_sample_block_verify_kernel(
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
n_num_draft_tokens = end_idxs - start_idxs
for req_i in range(BLOCK_SIZE):
not_greedy = tl.get_element(not_greedy_mask, (req_i,))
not_greedy = get_element(not_greedy_mask, (req_i,))
if not_greedy:
rejected = False
pi = 1.0
uniform_prob = 1.0
last_accepted_token_pos = -1
start_idx = tl.get_element(start_idxs, (req_i,))
start_idx = get_element(start_idxs, (req_i,))
req_idx = block_idx * BLOCK_SIZE + req_i
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i,))
num_draft_tokens = get_element(n_num_draft_tokens, (req_i,))
for pos in range(num_draft_tokens):
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)