[Triton] Centralize Ascend extension op dispatch in triton_utils (#6937)
### What this PR does / why we need it?

This pull request refactors the dispatch mechanism for the **triton-ascend-specific operators** `insert_slice`, `extract_slice`, and `get_element` to ensure compatibility with both CANN 8.5 and 9.0.

A unified helper function, `_resolve_triton_ascend_op`, has been introduced in `vllm_ascend/ops/triton/triton_utils.py`. This function dynamically resolves these operators by first attempting to import them from the `triton.language.extra.cann.extension` module, which is present in newer CANN versions. If that fails, it falls back to the standard `triton.language` module.

This approach centralizes operator dispatch logic, allowing individual Triton kernels to use these functions without being aware of the underlying Triton/CANN version. All call sites have been updated to use these new unified functions.

### Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring of operator implementations and does not introduce any user-facing changes.

### How was this patch tested?

CI is expected to pass with existing tests.

**Testing Context:**

- vLLM version: v0.16.0
- vLLM main: `15d76f74e2fdb12a95ea00f0ca283acf6219a2b7`

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -20,7 +20,7 @@ import triton # type: ignore
|
||||
import triton.language as tl # type: ignore
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
|
||||
from vllm_ascend.ops.triton.triton_utils import extract_slice, get_vectorcore_num, insert_slice
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -79,13 +79,13 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
|
||||
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
x1 = tl.extract_slice(
|
||||
x1 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, 0),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
x2 = tl.extract_slice(
|
||||
x2 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
@@ -95,14 +95,14 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
roped_q2 = x2 * cos + x1 * sin
|
||||
|
||||
roped_q = tl.zeros((Q_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
roped_q = tl.insert_slice(
|
||||
roped_q = insert_slice(
|
||||
roped_q,
|
||||
roped_q1,
|
||||
offsets=(0, 0),
|
||||
sizes=(Q_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
roped_q = tl.insert_slice(
|
||||
roped_q = insert_slice(
|
||||
roped_q,
|
||||
roped_q2,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
@@ -145,13 +145,13 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
sin_offsets = pos_idx * HEAD_DIM + tl.arange(HALF_HEAD_DIM, HEAD_DIM)
|
||||
cos = (tl.load(cos_sin_ptr + cos_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
sin = (tl.load(cos_sin_ptr + sin_offsets)).reshape(1, HALF_HEAD_DIM)
|
||||
x1 = tl.extract_slice(
|
||||
x1 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, 0),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
x2 = tl.extract_slice(
|
||||
x2 = extract_slice(
|
||||
normalized_values,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
@@ -161,14 +161,14 @@ def split_qkv_rmsnorm_rope_kernel(
|
||||
roped_k2 = x2 * cos + x1 * sin
|
||||
|
||||
roped_k = tl.zeros((KV_BLOCK_SIZE // HEAD_DIM, HEAD_DIM), dtype=tl.bfloat16)
|
||||
roped_k = tl.insert_slice(
|
||||
roped_k = insert_slice(
|
||||
roped_k,
|
||||
roped_k1,
|
||||
offsets=(0, 0),
|
||||
sizes=(KV_BLOCK_SIZE // HEAD_DIM, HALF_HEAD_DIM),
|
||||
strides=(1, 1),
|
||||
)
|
||||
roped_k = tl.insert_slice(
|
||||
roped_k = insert_slice(
|
||||
roped_k,
|
||||
roped_k2,
|
||||
offsets=(0, HALF_HEAD_DIM),
|
||||
|
||||
Reference in New Issue
Block a user