[Triton] Centralize Ascend extension op dispatch in triton_utils (#6937)
### What this PR does / why we need it?

This pull request refactors the dispatch mechanism for the **triton-ascend-specific operators** `insert_slice`, `extract_slice`, and `get_element` to ensure compatibility with both CANN 8.5 and 9.0. A unified helper function, `_resolve_triton_ascend_op`, has been introduced in `vllm_ascend/ops/triton/triton_utils.py`. It resolves these operators dynamically: it first attempts to import them from the `triton.language.extra.cann.extension` module, which is present in newer CANN versions, and falls back to the standard `triton.language` module if that import fails. Centralizing the dispatch logic this way lets individual Triton kernels use these functions without being aware of the underlying Triton/CANN version. All call sites have been updated to use the new unified functions.

### Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring of operator implementations and does not introduce any user-facing changes.

### How was this patch tested?

CI is expected to pass with the existing tests.

**Testing Context:**
- vLLM version: v0.16.0
- vLLM main: `15d76f74e2fdb12a95ea00f0ca283acf6219a2b7`

Signed-off-by: linfeng-yuan <1102311262@qq.com>
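The resolver itself is not part of the hunks below, so for context here is a minimal sketch of the pattern the description outlines. This is a hypothetical reconstruction, not the code merged into `triton_utils.py`; only the module paths and operator names come from the PR description.

```python
# Hypothetical sketch of the dispatch helper described above; the real
# _resolve_triton_ascend_op in vllm_ascend/ops/triton/triton_utils.py
# may differ in detail.
import importlib

import triton.language as tl


def _resolve_triton_ascend_op(op_name: str):
    """Resolve a triton-ascend extension op for the installed CANN version.

    Newer CANN releases (9.0) ship insert_slice/extract_slice/get_element in
    triton.language.extra.cann.extension; older ones (8.5) expose the same
    ops directly on triton.language, so fall back to that namespace.
    """
    try:
        ext = importlib.import_module("triton.language.extra.cann.extension")
        return getattr(ext, op_name)
    except (ImportError, AttributeError):
        return getattr(tl, op_name)


# Resolved once at import time; kernels import these plain names and stay
# agnostic to the underlying Triton/CANN version.
insert_slice = _resolve_triton_ascend_op("insert_slice")
extract_slice = _resolve_triton_ascend_op("extract_slice")
get_element = _resolve_triton_ascend_op("get_element")
```

Resolving at import time keeps the version branch out of the kernels themselves: each `@triton.jit` function just references `get_element` as an ordinary global, which is exactly the call-site change the diff below makes.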
```diff
@@ -17,7 +17,7 @@
 
 from vllm.triton_utils import tl, triton
 
-from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
+from vllm_ascend.ops.triton.triton_utils import get_element, get_vectorcore_num
 
 
 def cal_grid_and_block_size(batch_size: int):
@@ -59,8 +59,8 @@ def rejection_greedy_sample_spec_len_1_triton(
     tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
 
     for pos in tl.range(0, BLOCK_SIZE):
-        draft_token_id1 = tl.get_element(draft_token_id, (pos,))
-        target_argmax1 = tl.get_element(target_argmax_id, (pos,))
+        draft_token_id1 = get_element(draft_token_id, (pos,))
+        target_argmax1 = get_element(target_argmax_id, (pos,))
         position = block_idx * BLOCK_SIZE + pos
         if draft_token_id1 == target_argmax1:
             bonus_renew_1(
@@ -109,10 +109,10 @@ def rejection_greedy_sample_triton(
     num_draft_tokens = end_idx - start_idx
 
     for pos in tl.range(0, BLOCK_SIZE):
-        num_tokens1 = tl.get_element(num_draft_tokens, (pos,))
+        num_tokens1 = get_element(num_draft_tokens, (pos,))
         rejected = False
-        start_idx1 = tl.get_element(start_idx, (pos,))
-        is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos,))
+        start_idx1 = get_element(start_idx, (pos,))
+        is_greedy_mask1 = get_element(is_greedy_mask, (pos,))
         position = block_idx * BLOCK_SIZE + pos
         for i in range(num_tokens1):
             if not rejected:
@@ -162,12 +162,12 @@ def rejection_random_sample_kernel(
     end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
     n_num_draft_tokens = end_idxs - start_idxs
     for req_i in range(BLOCK_SIZE):
-        not_greedy = tl.get_element(not_greedy_mask, (req_i,))
+        not_greedy = get_element(not_greedy_mask, (req_i,))
         if not_greedy:
             rejected = False
-            start_idx = tl.get_element(start_idxs, (req_i,))
+            start_idx = get_element(start_idxs, (req_i,))
             req_idx = block_idx * BLOCK_SIZE + req_i
-            num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i,))
+            num_draft_tokens = get_element(n_num_draft_tokens, (req_i,))
             for pos in range(num_draft_tokens):
                 if not rejected:
                     draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
@@ -219,9 +219,9 @@ def expand_kernel(
     src_val = tl.where(src_val == replace_from, replace_to, src_val)
 
     for i in tl.range(0, BLOCK_SIZE):
-        num_tokens1 = tl.get_element(num_tokens, (i,))
-        start_idx1 = tl.get_element(start_idx, (i,))
-        src_val1 = tl.get_element(src_val, (i,))
+        num_tokens1 = get_element(num_tokens, (i,))
+        start_idx1 = get_element(start_idx, (i,))
+        src_val1 = get_element(src_val, (i,))
         offset1 = tl.arange(0, MAX_NUM_TOKENS)
         tl.store(output_ptr + start_idx1 + offset1, src_val1, mask=offset1 < num_tokens1)
 
@@ -272,7 +272,7 @@ def sample_recovered_tokens_kernel(
         )
         new_p = prob / q
         recovered_id = tl.argmax(new_p, axis=-1)
-        max_p = tl.get_element(new_p, (recovered_id,))
+        max_p = get_element(new_p, (recovered_id,))
         if max_p > global_max_p:
             global_max_p = max_p
             global_recovered_id = vocab_start + recovered_id
@@ -297,7 +297,7 @@ def sample_recovered_tokens_kernel(
         )
         new_p = prob / q
         recovered_id = tl.argmax(new_p, axis=-1)
-        max_p = tl.get_element(new_p, (recovered_id,))
+        max_p = get_element(new_p, (recovered_id,))
         if max_p > global_max_p:
             global_max_p = max_p
             global_recovered_id = vocab_start + recovered_id
@@ -388,15 +388,15 @@ def rejection_random_sample_block_verify_kernel(
     end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
     n_num_draft_tokens = end_idxs - start_idxs
     for req_i in range(BLOCK_SIZE):
-        not_greedy = tl.get_element(not_greedy_mask, (req_i,))
+        not_greedy = get_element(not_greedy_mask, (req_i,))
         if not_greedy:
             rejected = False
             pi = 1.0
             uniform_prob = 1.0
             last_accepted_token_pos = -1
-            start_idx = tl.get_element(start_idxs, (req_i,))
+            start_idx = get_element(start_idxs, (req_i,))
             req_idx = block_idx * BLOCK_SIZE + req_i
-            num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i,))
+            num_draft_tokens = get_element(n_num_draft_tokens, (req_i,))
 
             for pos in range(num_draft_tokens):
                 draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
```
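For completeness, the call-site pattern the diff converges on is shown below as a self-contained sketch; `demo_copy_first_kernel` is hypothetical and not part of the PR, while the import lines mirror the real file. Because Triton's JIT resolves free names from the kernel's module globals when it traces the function, the version branch never appears inside any kernel body.

```python
# Hypothetical call-site sketch; only the imports reflect the actual diff.
from vllm.triton_utils import tl, triton

from vllm_ascend.ops.triton.triton_utils import get_element


@triton.jit
def demo_copy_first_kernel(x_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
    # get_element is an ordinary module-level name here; whether it came from
    # triton.language.extra.cann.extension or triton.language was decided
    # once, when triton_utils was first imported.
    vals = tl.load(x_ptr + tl.arange(0, BLOCK_SIZE))
    first = get_element(vals, (0,))  # scalar read at index 0
    tl.store(out_ptr, first)
```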