[Triton] Centralize Ascend extension op dispatch in triton_utils (#6937)
### What this PR does / why we need it? This pull request refactors the dispatch mechanism for the **triton-ascend-specific operators** `insert_slice`, `extract_slice`, and `get_element` to ensure compatibility with both CANN 8.5 and 9.0. A unified helper function, `_resolve_triton_ascend_op`, has been introduced in `vllm_ascend/ops/triton/triton_utils.py`. This function dynamically resolves these operators by first attempting to import them from the `triton.language.extra.cann.extension` module, which is present in newer CANN versions. If that fails, it falls back to the standard `triton.language` module. This approach centralizes operator dispatch logic, allowing individual Triton kernels to use these functions without being aware of the underlying Triton/CANN version. All call sites have been updated to use these new unified functions. ### Does this PR introduce _any_ user-facing change? No. This is an internal refactoring of operator implementations and does not introduce any user-facing changes. ### How was this patch tested? CI is expected to pass with existing tests. **Testing Context:** - vLLM version: v0.16.0 - vLLM main: `15d76f74e2fdb12a95ea00f0ca283acf6219a2b7` Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -12,6 +12,8 @@
|
||||
import torch
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import extract_slice, insert_slice
|
||||
|
||||
from .utils import prepare_chunk_indices
|
||||
|
||||
|
||||
@@ -80,7 +82,7 @@ def solve_tril_16x16_kernel(
|
||||
|
||||
# 4 Use mask to safely load data
|
||||
b_A_subrec16 = tl.load(ptr_A_subrec16, mask=load_mask, other=0.0).to(tl.float32)
|
||||
b_A = tl.insert_slice(
|
||||
b_A = insert_slice(
|
||||
ful=b_A,
|
||||
sub=b_A_subrec16[None, :, :], # (1, 16, 16)
|
||||
offsets=[blkid, 0, 0],
|
||||
@@ -100,7 +102,7 @@ def solve_tril_16x16_kernel(
|
||||
|
||||
# for loop to update N_BLOCKS row vector
|
||||
for i in range(1, 16):
|
||||
nblks_vec16 = -tl.extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
|
||||
nblks_vec16 = -extract_slice(local_ori_A, (i, 0), (1, 16 * N_BLOCKS), (16 * N_BLOCKS, 1))
|
||||
b_a = tl.reshape(nblks_vec16, (N_BLOCKS, 16))
|
||||
|
||||
dot_tmp = tl.trans(b_a[:, :, None] * b_A, (1, 0, 2))
|
||||
@@ -108,7 +110,7 @@ def solve_tril_16x16_kernel(
|
||||
b_a = b_a + dot_product
|
||||
|
||||
b_a_new_expanded = b_a[:, None, :]
|
||||
b_A = tl.insert_slice(
|
||||
b_A = insert_slice(
|
||||
ful=b_A, sub=b_a_new_expanded, offsets=[0, i, 0], sizes=[N_BLOCKS, 1, 16], strides=[1, 1, 1]
|
||||
)
|
||||
|
||||
@@ -276,9 +278,9 @@ def merge_16x16_to_64x64_inverse_kernel(
|
||||
|
||||
# build Ai_22_32 (32 * 32)
|
||||
Ai_22_32 = tl.zeros((32, 32), tl.float32)
|
||||
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
|
||||
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
|
||||
Ai_22_32 = tl.insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))
|
||||
Ai_22_32 = insert_slice(Ai_22_32, Ai_33, (0, 0), (16, 16), (1, 1))
|
||||
Ai_22_32 = insert_slice(Ai_22_32, Ai_44, (16, 16), (16, 16), (1, 1))
|
||||
Ai_22_32 = insert_slice(Ai_22_32, Ai_43, (16, 0), (16, 16), (1, 1))
|
||||
|
||||
# load A_21_32 (A block at row i_t * 64 + 32, col 0, 32 * 32)
|
||||
offs_m = i_t * 64 + 32 + tl.arange(0, 32)
|
||||
@@ -290,9 +292,9 @@ def merge_16x16_to_64x64_inverse_kernel(
|
||||
|
||||
# build Ai_11_32 (32 * 32)
|
||||
Ai_11_32 = tl.zeros((32, 32), tl.float32)
|
||||
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
|
||||
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
|
||||
Ai_11_32 = tl.insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))
|
||||
Ai_11_32 = insert_slice(Ai_11_32, Ai_11, (0, 0), (16, 16), (1, 1))
|
||||
Ai_11_32 = insert_slice(Ai_11_32, Ai_22, (16, 16), (16, 16), (1, 1))
|
||||
Ai_11_32 = insert_slice(Ai_11_32, Ai_21, (16, 0), (16, 16), (1, 1))
|
||||
|
||||
Ai_21_32 = -tl.dot(tmp, Ai_11_32, input_precision="ieee")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user