[Triton] Centralize Ascend extension op dispatch in triton_utils (#6937)

### What this PR does / why we need it?

This pull request refactors the dispatch mechanism for the
**triton-ascend-specific operators** `insert_slice`, `extract_slice`,
and `get_element` to ensure compatibility with both CANN 8.5 and 9.0.

A unified helper function, `_resolve_triton_ascend_op`, has been
introduced in `vllm_ascend/ops/triton/triton_utils.py`. This function
dynamically resolves these operators by first attempting to import them
from the `triton.language.extra.cann.extension` module, which is present
in newer CANN versions. If that fails, it falls back to the standard
`triton.language` module.

This approach centralizes operator dispatch logic, allowing individual
Triton kernels to use these functions without being aware of the
underlying Triton/CANN version. All call sites have been updated to use
these new unified functions.

### Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring of operator implementations and does
not introduce any user-facing changes.

### How was this patch tested?

CI is expected to pass with existing tests.

**Testing Context:**
- vLLM version: v0.16.0
- vLLM main: `15d76f74e2fdb12a95ea00f0ca283acf6219a2b7`

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2026-03-03 17:10:30 +08:00
committed by GitHub
parent cb893bcdb0
commit 700423156f
5 changed files with 78 additions and 40 deletions

View File

@@ -1,7 +1,7 @@
import torch
from vllm.triton_utils import tl, triton
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num
@triton.jit
@@ -40,8 +40,8 @@ def _swiglu_quant_kernel(
# swiglu
x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS)
cur_x = tl.load(x_ptr + x_offsets)
x1 = tl.extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
x2 = tl.extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,))
x1 = extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
x2 = extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,))
out = x1 * tl.sigmoid(x1) * x2
# quant
@@ -50,7 +50,7 @@ def _swiglu_quant_kernel(
# store scale
tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty))
for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE):
tmp_out = tl.extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,))
tmp_out = extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,))
tmp_out = (tmp_out.to(tl.float32) / scale).to(x_ptr.dtype.element_ty)
tmp_out = tmp_out.cast(tl.int8, overflow_mode="saturate")