[Triton] Centralize Ascend extension op dispatch in triton_utils (#6937)
### What this PR does / why we need it?

This pull request refactors the dispatch mechanism for the **triton-ascend-specific operators** `insert_slice`, `extract_slice`, and `get_element` to ensure compatibility with both CANN 8.5 and 9.0.

A unified helper function, `_resolve_triton_ascend_op`, has been introduced in `vllm_ascend/ops/triton/triton_utils.py`. This function dynamically resolves these operators by first attempting to import them from the `triton.language.extra.cann.extension` module, which is present in newer CANN versions. If that fails, it falls back to the standard `triton.language` module.

This approach centralizes operator dispatch logic, allowing individual Triton kernels to use these functions without being aware of the underlying Triton/CANN version. All call sites have been updated to use these new unified functions.

### Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring of operator implementations and does not introduce any user-facing changes.

### How was this patch tested?

CI is expected to pass with existing tests.

**Testing Context:**

- vLLM version: v0.16.0
- vLLM main: `15d76f74e2fdb12a95ea00f0ca283acf6219a2b7`

Signed-off-by: linfeng-yuan <1102311262@qq.com>
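The try-then-fall-back lookup described in the commit message can be sketched as follows. This is a minimal illustration, not the code added by this PR: the function name `resolve_op_sketch`, its `candidates` parameter, and the error message are hypothetical, and the actual signature of `_resolve_triton_ascend_op` may differ. Only the two module paths are taken from the description above.

```python
import importlib

# Module search order taken from the PR description: the CANN extension
# module first, then the standard triton.language module.
DEFAULT_CANDIDATES = (
    "triton.language.extra.cann.extension",
    "triton.language",
)


def resolve_op_sketch(name, candidates=DEFAULT_CANDIDATES):
    """Return attribute `name` from the first candidate module that provides it.

    Hypothetical stand-in for `_resolve_triton_ascend_op`; the real helper
    may be structured differently.
    """
    for module_name in candidates:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # Module is not available in this environment; try the next one.
            continue
        op = getattr(module, name, None)
        if op is not None:
            return op
    raise ImportError(f"operator {name!r} not found in any of {candidates}")


# Intended use (requires a Triton installation, so it is left commented out):
# extract_slice = resolve_op_sketch("extract_slice")
# insert_slice = resolve_op_sketch("insert_slice")
# get_element = resolve_op_sketch("get_element")
```

Resolving each operator once at import time, as in the commented lines, would let kernels import `extract_slice` from a single module regardless of which Triton/CANN combination is installed, which matches the call-site change shown in the diff.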
@@ -1,7 +1,7 @@
 import torch
 from vllm.triton_utils import tl, triton

-from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
+from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num


 @triton.jit
@@ -40,8 +40,8 @@ def _swiglu_quant_kernel(
     # swiglu
     x_offsets = row_idx * TOTAL_COLS + tl.arange(0, TOTAL_COLS)
     cur_x = tl.load(x_ptr + x_offsets)
-    x1 = tl.extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
-    x2 = tl.extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,))
+    x1 = extract_slice(cur_x, offsets=(0,), sizes=(HALF_COLS,), strides=(1,))
+    x2 = extract_slice(cur_x, offsets=(HALF_COLS,), sizes=(HALF_COLS,), strides=(1,))
     out = x1 * tl.sigmoid(x1) * x2

     # quant
@@ -50,7 +50,7 @@ def _swiglu_quant_kernel(
     # store scale
     tl.store(scale_ptr + row_idx, scale.to(scale_ptr.dtype.element_ty))
     for col_blk_idx in range(0, HALF_COLS, COL_BLOCK_SIZE):
-        tmp_out = tl.extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,))
+        tmp_out = extract_slice(out, offsets=(col_blk_idx,), sizes=(COL_BLOCK_SIZE,), strides=(1,))
         tmp_out = (tmp_out.to(tl.float32) / scale).to(x_ptr.dtype.element_ty)
         tmp_out = tmp_out.cast(tl.int8, overflow_mode="saturate")

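For readers unfamiliar with the kernel, the `extract_slice` calls in the diff split each loaded row into two halves before applying `x1 * sigmoid(x1) * x2`. The following plain-Python reference may help when checking the kernel's output by hand. It is only a sketch: `slice_1d` and `swiglu_row` are hypothetical names, and the assumption that a 1-D `extract_slice` with `offsets`, `sizes`, and `strides` selects `size` elements starting at `offset` and spaced by `stride` is not confirmed by this PR.

```python
import math


def slice_1d(values, offset, size, stride=1):
    """Plain-Python stand-in for a 1-D extract_slice (assumed semantics)."""
    return [values[offset + i * stride] for i in range(size)]


def swiglu_row(row):
    """Compute x1 * sigmoid(x1) * x2, where x1 and x2 are the two halves of `row`."""
    half = len(row) // 2
    x1 = slice_1d(row, 0, half)
    x2 = slice_1d(row, half, half)
    # a * sigmoid(a) is written as a / (1 + exp(-a)).
    return [a / (1.0 + math.exp(-a)) * b for a, b in zip(x1, x2)]
```

The quantization step that follows in the kernel is not reproduced here because the scale computation is outside the hunks shown.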