[Triton] Centralize Ascend extension op dispatch in triton_utils (#6937)
### What this PR does / why we need it?

This pull request refactors the dispatch mechanism for the **triton-ascend-specific operators** `insert_slice`, `extract_slice`, and `get_element` to ensure compatibility with both CANN 8.5 and 9.0.

A unified helper function, `_resolve_triton_ascend_op`, has been introduced in `vllm_ascend/ops/triton/triton_utils.py`. This function dynamically resolves these operators by first attempting to import them from the `triton.language.extra.cann.extension` module, which is present in newer CANN versions. If that fails, it falls back to the standard `triton.language` module.

This approach centralizes operator dispatch logic, allowing individual Triton kernels to use these functions without being aware of the underlying Triton/CANN version. All call sites have been updated to use these new unified functions.

### Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring of operator implementations and does not introduce any user-facing changes.

### How was this patch tested?

CI is expected to pass with existing tests.

**Testing Context:**

- vLLM version: v0.16.0
- vLLM main: `15d76f74e2fdb12a95ea00f0ca283acf6219a2b7`

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -40,8 +40,8 @@ def _swiglu_quant_kernel(
|
||||
# swiglu
|
||||
x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS)
|
||||
cur_x = tl.load(x_ptr + x_offsets)
|
||||
x1 = tl.extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
|
||||
x2 = tl.extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,))
|
||||
x1 = extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
|
||||
x2 = extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,))
|
||||
out = x1 * tl.sigmoid(x1) * x2
|
||||
|
||||
# quant
|
||||
@@ -50,7 +50,7 @@ def _swiglu_quant_kernel(
|
||||
# store scale
|
||||
tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty))
|
||||
for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE):
|
||||
tmp_out = tl.extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,))
|
||||
tmp_out = extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,))
|
||||
tmp_out = (tmp_out.to(tl.float32) / scale).to(x_ptr.dtype.element_ty)
|
||||
tmp_out = tmp_out.cast(tl.int8, overflow_mode="saturate")
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import extract_slice, insert_slice
|
||||
|
||||
from .utils import prepare_chunk_indices
|
||||
|
||||
|
||||
@@ -80,7 +82,7 @@ def solve_tril_16x16_kernel(
|
||||
|
||||
# 4 Use mask to safely load data
|
||||
b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(tl.float32)
|
||||
b_A = tl.insert_slice(
|
||||
b_A = insert_slice(
|
||||
ful=b_A,
|
||||
sub=b_A_subrec16[None, :, :], # (1, 16, 16)
|
||||
offsets=[blkid, 0, 0],
|
||||
@@ -100,7 +102,7 @@ def solve_tril_16x16_kernel(
|
||||
|
||||
# for loop to update N_BLOCKS row vector
|
||||
for i in range(1, 16):
|
||||
nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
|
||||
nblks_vec16 = -extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
|
||||
b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))
|
||||
|
||||
dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
|
||||
@@ -108,7 +110,7 @@ def solve_tril_16x16_kernel(
|
||||
b_a = b_a + dot_product
|
||||
|
||||
b_a_new_expanded = b_a[:, None, :]
|
||||
b_A = tl.insert_slice(
|
||||
b_A = insert_slice(
|
||||
ful=b_A, sub=b_a_new_expanded, offsets=[0, i, 0], sizes=[N_BLOCKS, 1, 16], strides=[1, 1, 1]
|
||||
)
|
||||
|
||||
@@ -276,9 +278,9 @@ def merge_16x16_to_64x64_inverse_kernel(
|
||||
|
||||
# build Ai_22_32 (32 * 32)
|
||||
Ai_22_32 = tl.zeros((32, 32), tl.float32)
|
||||
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
|
||||
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
|
||||
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))
|
||||
Ai_22_32 = insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
|
||||
Ai_22_32 = insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
|
||||
Ai_22_32 = insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))
|
||||
|
||||
# load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32)
|
||||
offs_m = i_t * 64 + 32 + tl.arange(0, 32)
|
||||
@@ -290,9 +292,9 @@ def merge_16x16_to_64x64_inverse_kernel(
|
||||
|
||||
# build Ai_11_32 (32 * 32)
|
||||
Ai_11_32 = tl.zeros((32, 32), tl.float32)
|
||||
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
|
||||
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
|
||||
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))
|
||||
Ai_11_32 = insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
|
||||
Ai_11_32 = insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
|
||||
Ai_11_32 = insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))
|
||||
|
||||
Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee")
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import triton # type: ignore
|
||||
import triton.language as tl # type: ignore
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num, insert_slice
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -79,13 +79,13 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
|
||||
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
x1 = tl.extract_slice(
|
||||
x1 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, 0),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
x2 = tl.extract_slice(
|
||||
x2 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
@@ -95,14 +95,14 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
roped_q2 = x2 * cos + x1 * sin
|
||||
|
||||
roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
roped_q = tl.insert_slice(
|
||||
roped_q = insert_slice(
|
||||
roped_q,
|
||||
roped_q1,
|
||||
offsets=(0, 0),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
roped_q = tl.insert_slice(
|
||||
roped_q = insert_slice(
|
||||
roped_q,
|
||||
roped_q2,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
@@ -145,13 +145,13 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
|
||||
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
x1 = tl.extract_slice(
|
||||
x1 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, 0),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
x2 = tl.extract_slice(
|
||||
x2 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
@@ -161,14 +161,14 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
roped_k2 = x2 * cos + x1 * sin
|
||||
|
||||
roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
roped_k = tl.insert_slice(
|
||||
roped_k = insert_slice(
|
||||
roped_k,
|
||||
roped_k1,
|
||||
offsets=(0, 0),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
roped_k = tl.insert_slice(
|
||||
roped_k = insert_slice(
|
||||
roped_k,
|
||||
roped_k2,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
from vllm_ascend.ops.triton.triton_utils import get_element, get_vectorcore_num
|
||||
|
||||
|
||||
def cal_grid_and_block_size(batch_size: int):
|
||||
@@ -59,8 +59,8 @@ def rejection_greedy_sample_spec_len_1_triton(
|
||||
tl.store(output_token_ids_ptr + offset * 2, target_argmax_id, mask)
|
||||
|
||||
for pos in tl.range(0, BLOCK_SIZE):
|
||||
draft_token_id1 = tl.get_element(draft_token_id, (pos,))
|
||||
target_argmax1 = tl.get_element(target_argmax_id, (pos,))
|
||||
draft_token_id1 = get_element(draft_token_id, (pos,))
|
||||
target_argmax1 = get_element(target_argmax_id, (pos,))
|
||||
position = block_idx * BLOCK_SIZE + pos
|
||||
if draft_token_id1 == target_argmax1:
|
||||
bonus_renew_1(
|
||||
@@ -109,10 +109,10 @@ def rejection_greedy_sample_triton(
|
||||
num_draft_tokens = end_idx - start_idx
|
||||
|
||||
for pos in tl.range(0, BLOCK_SIZE):
|
||||
num_tokens1 = tl.get_element(num_draft_tokens, (pos,))
|
||||
num_tokens1 = get_element(num_draft_tokens, (pos,))
|
||||
rejected = False
|
||||
start_idx1 = tl.get_element(start_idx, (pos,))
|
||||
is_greedy_mask1 = tl.get_element(is_greedy_mask, (pos,))
|
||||
start_idx1 = get_element(start_idx, (pos,))
|
||||
is_greedy_mask1 = get_element(is_greedy_mask, (pos,))
|
||||
position = block_idx * BLOCK_SIZE + pos
|
||||
for i in range(num_tokens1):
|
||||
if not rejected:
|
||||
@@ -162,12 +162,12 @@ def rejection_random_sample_kernel(
|
||||
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
|
||||
n_num_draft_tokens = end_idxs - start_idxs
|
||||
for req_i in range(BLOCK_SIZE):
|
||||
not_greedy = tl.get_element(not_greedy_mask, (req_i,))
|
||||
not_greedy = get_element(not_greedy_mask, (req_i,))
|
||||
if not_greedy:
|
||||
rejected = False
|
||||
start_idx = tl.get_element(start_idxs, (req_i,))
|
||||
start_idx = get_element(start_idxs, (req_i,))
|
||||
req_idx = block_idx * BLOCK_SIZE + req_i
|
||||
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i,))
|
||||
num_draft_tokens = get_element(n_num_draft_tokens, (req_i,))
|
||||
for pos in range(num_draft_tokens):
|
||||
if not rejected:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
@@ -219,9 +219,9 @@ def expand_kernel(
|
||||
src_val = tl.where(src_val == replace_from, replace_to, src_val)
|
||||
|
||||
for i in tl.range(0, BLOCK_SIZE):
|
||||
num_tokens1 = tl.get_element(num_tokens, (i,))
|
||||
start_idx1 = tl.get_element(start_idx, (i,))
|
||||
src_val1 = tl.get_element(src_val, (i,))
|
||||
num_tokens1 = get_element(num_tokens, (i,))
|
||||
start_idx1 = get_element(start_idx, (i,))
|
||||
src_val1 = get_element(src_val, (i,))
|
||||
offset1 = tl.arange(0, MAX_NUM_TOKENS)
|
||||
tl.store(output_ptr + start_idx1 + offset1, src_val1, mask=offset1 < num_tokens1)
|
||||
|
||||
@@ -272,7 +272,7 @@ def sample_recovered_tokens_kernel(
|
||||
)
|
||||
new_p = prob / q
|
||||
recovered_id = tl.argmax(new_p, axis=-1)
|
||||
max_p = tl.get_element(new_p, (recovered_id,))
|
||||
max_p = get_element(new_p, (recovered_id,))
|
||||
if max_p > global_max_p:
|
||||
global_max_p = max_p
|
||||
global_recovered_id = vocab_start + recovered_id
|
||||
@@ -297,7 +297,7 @@ def sample_recovered_tokens_kernel(
|
||||
)
|
||||
new_p = prob / q
|
||||
recovered_id = tl.argmax(new_p, axis=-1)
|
||||
max_p = tl.get_element(new_p, (recovered_id,))
|
||||
max_p = get_element(new_p, (recovered_id,))
|
||||
if max_p > global_max_p:
|
||||
global_max_p = max_p
|
||||
global_recovered_id = vocab_start + recovered_id
|
||||
@@ -388,15 +388,15 @@ def rejection_random_sample_block_verify_kernel(
|
||||
end_idxs = tl.load(cu_num_draft_tokens_ptr + offsets, not_greedy_mask)
|
||||
n_num_draft_tokens = end_idxs - start_idxs
|
||||
for req_i in range(BLOCK_SIZE):
|
||||
not_greedy = tl.get_element(not_greedy_mask, (req_i,))
|
||||
not_greedy = get_element(not_greedy_mask, (req_i,))
|
||||
if not_greedy:
|
||||
rejected = False
|
||||
pi = 1.0
|
||||
uniform_prob = 1.0
|
||||
last_accepted_token_pos = -1
|
||||
start_idx = tl.get_element(start_idxs, (req_i,))
|
||||
start_idx = get_element(start_idxs, (req_i,))
|
||||
req_idx = block_idx * BLOCK_SIZE + req_i
|
||||
num_draft_tokens = tl.get_element(n_num_draft_tokens, (req_i,))
|
||||
num_draft_tokens = get_element(n_num_draft_tokens, (req_i,))
|
||||
|
||||
for pos in range(num_draft_tokens):
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
|
||||
@@ -1,10 +1,46 @@
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from vllm.triton_utils import HAS_TRITON, triton
|
||||
from vllm.triton_utils import HAS_TRITON, tl, triton
|
||||
|
||||
_NUM_AICORE = -1
|
||||
_NUM_VECTORCORE = -1
|
||||
# Optional handle to the CANN extension namespace. Per the commit description
# this module only exists on newer CANN versions; when it cannot be imported
# (or Triton is unavailable altogether) the handle stays None and operator
# lookup has to rely on triton.language alone.
if HAS_TRITON:
    try:
        import triton.language.extra.cann.extension as _extension_module  # type: ignore
    except ImportError:
        _extension_module = None
else:
    _extension_module = None
|
||||
|
||||
|
||||
def _resolve_triton_ascend_op(op_name: str):
    """Look up a triton-ascend operator by name.

    The CANN extension module (when it was importable) is consulted first,
    then ``triton.language``. The first namespace exposing a non-None
    attribute called ``op_name`` wins.

    Args:
        op_name: Attribute name of the operator, e.g. ``"insert_slice"``.

    Returns:
        The attribute found under ``op_name`` in the first matching namespace.

    Raises:
        RuntimeError: If ``HAS_TRITON`` is False, or if neither namespace
            provides ``op_name``.
    """
    if not HAS_TRITON:
        raise RuntimeError(f"Triton op '{op_name}' cannot be resolved because HAS_TRITON is False")

    # Search order matters: the extension module takes precedence over the
    # plain triton.language namespace whenever it is available.
    search_order = (tl,) if _extension_module is None else (_extension_module, tl)
    for namespace in search_order:
        resolved = getattr(namespace, op_name, None)
        if resolved is not None:
            return resolved

    raise RuntimeError(
        f"Failed to resolve Triton op '{op_name}': "
        "neither triton.language.extra.cann.extension nor triton.language provides it."
    )
|
||||
|
||||
|
||||
# Public, version-agnostic handles for the Ascend-specific operators. Each is
# resolved once at import time; without Triton they are bound to None so that
# importing this module still succeeds.
insert_slice = _resolve_triton_ascend_op("insert_slice") if HAS_TRITON else None
extract_slice = _resolve_triton_ascend_op("extract_slice") if HAS_TRITON else None
get_element = _resolve_triton_ascend_op("get_element") if HAS_TRITON else None
|
||||
|
||||
|
||||
def init_device_properties_triton():
|
||||
|
||||
Reference in New Issue
Block a user