xc-llm-ascend/vllm_ascend/ops/triton/triton_utils.py
linfeng-yuan 700423156f [Triton] Centralize Ascend extension op dispatch in triton_utils (#6937)
### What this PR does / why we need it?

This pull request refactors the dispatch mechanism for the
**triton-ascend-specific operators** `insert_slice`, `extract_slice`,
and `get_element` to ensure compatibility with both CANN 8.5 and 9.0.

A unified helper function, `_resolve_triton_ascend_op`, has been
introduced in `vllm_ascend/ops/triton/triton_utils.py`. It resolves each
operator by name. It first looks the operator up in the
`triton.language.extra.cann.extension` module, which is present in newer
CANN versions. If that module cannot be imported or does not define the
operator, it falls back to the standard `triton.language` module. If
neither module provides the operator, it raises a `RuntimeError`.
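
The lookup order described above can be sketched in plain Python. This is a minimal model, not the real helper. The `SimpleNamespace` objects `extension` and `base` are hypothetical stand-ins for `triton.language.extra.cann.extension` and `triton.language`, and `resolve_op` is an illustrative name. Because of these stand-ins, the sketch runs without Triton or an NPU.

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional

# Hypothetical stand-ins: `extension` plays the role of
# triton.language.extra.cann.extension, `base` the role of triton.language.
extension = SimpleNamespace(extract_slice=lambda: "extension.extract_slice")
base = SimpleNamespace(
    extract_slice=lambda: "tl.extract_slice",
    get_element=lambda: "tl.get_element",
)


def resolve_op(op_name: str, ext_module: Optional[Any], base_module: Any) -> Callable[..., Any]:
    """Return `op_name` from ext_module if it defines it, else from base_module."""
    # 1. Prefer the extension module when it was importable and has the op.
    if ext_module is not None:
        op = getattr(ext_module, op_name, None)
        if op is not None:
            return op
    # 2. Fall back to the base namespace.
    op = getattr(base_module, op_name, None)
    if op is not None:
        return op
    # 3. Neither namespace defines the op.
    raise RuntimeError(f"Failed to resolve op '{op_name}' in either namespace")


print(resolve_op("extract_slice", extension, base)())  # extension.extract_slice
print(resolve_op("get_element", extension, base)())    # tl.get_element
print(resolve_op("extract_slice", None, base)())       # tl.extract_slice
```

In this sketch, looking up `insert_slice` raises `RuntimeError` because neither stand-in defines it. That mirrors the final branch of `_resolve_triton_ascend_op`.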

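This approach centralizes the operator dispatch logic. The three operators
are resolved once, when the module is imported, and exported as
`insert_slice`, `extract_slice`, and `get_element`. They are bound to `None`
when Triton is unavailable. As a result, individual Triton kernels can use
them without being aware of the underlying Triton/CANN version. All call
sites have been updated to use these unified functions.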
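
The import-time binding can be modeled the same way. In this sketch, `build_utils_module`, `OP_NAMES`, and the `base` namespace are hypothetical. The sketch builds a throwaway module that binds each operator name once, or binds it to `None` when Triton is unavailable. This is what lets a call site read a single name without checking the CANN version.

```python
import types
from types import SimpleNamespace
from typing import Any, Optional

OP_NAMES = ("insert_slice", "extract_slice", "get_element")


def build_utils_module(has_triton: bool, ext: Optional[Any], base: Any) -> types.ModuleType:
    """Hypothetical model of triton_utils.py: resolve each op once, at import time."""
    mod = types.ModuleType("fake_triton_utils")
    for name in OP_NAMES:
        if not has_triton:
            # Mirrors the else-branch: the names exist but are None without Triton.
            setattr(mod, name, None)
            continue
        # Extension namespace first, then the base namespace.
        op = getattr(ext, name, None) if ext is not None else None
        if op is None:
            op = getattr(base, name, None)
        if op is None:
            raise RuntimeError(f"Failed to resolve Triton op '{name}'")
        setattr(mod, name, op)
    return mod


# Hypothetical base namespace that defines all three ops; no extension module.
base = SimpleNamespace(**{n: (lambda n=n: f"tl.{n}") for n in OP_NAMES})

utils = build_utils_module(True, None, base)
print(utils.extract_slice())  # tl.extract_slice

no_triton = build_utils_module(False, None, base)
print(no_triton.extract_slice is None)  # True
```

Resolving at import time means a missing operator fails loudly when the module loads, not in the middle of a kernel launch.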

### Does this PR introduce _any_ user-facing change?

No. This is an internal refactoring of operator implementations and does
not introduce any user-facing changes.

### How was this patch tested?

CI is expected to pass with existing tests.

**Testing Context:**
- vLLM version: v0.16.0
- vLLM main: `15d76f74e2fdb12a95ea00f0ca283acf6219a2b7`

Signed-off-by: linfeng-yuan <1102311262@qq.com>
2026-03-03 17:10:30 +08:00

from typing import Any

import torch
from vllm.triton_utils import HAS_TRITON, tl, triton

# Cached device properties; -1 means "not initialized yet".
_NUM_AICORE = -1
_NUM_VECTORCORE = -1

# Ascend extension namespace, present only in newer CANN versions.
_extension_module = None
if HAS_TRITON:
    try:
        import triton.language.extra.cann.extension as _extension_module  # type: ignore
    except ImportError:
        _extension_module = None


def _resolve_triton_ascend_op(op_name: str):
    """Return `op_name` from the CANN extension module if available, else from triton.language."""
    if not HAS_TRITON:
        raise RuntimeError(f"Triton op '{op_name}' cannot be resolved because HAS_TRITON is False")

    if _extension_module is not None:
        extension_op = getattr(_extension_module, op_name, None)
        if extension_op is not None:
            return extension_op

    tl_op = getattr(tl, op_name, None)
    if tl_op is not None:
        return tl_op

    raise RuntimeError(
        f"Failed to resolve Triton op '{op_name}': "
        "neither triton.language.extra.cann.extension nor triton.language provides it."
    )


# Resolve once at import time; without Triton the names are bound to None.
if HAS_TRITON:
    insert_slice = _resolve_triton_ascend_op("insert_slice")
    extract_slice = _resolve_triton_ascend_op("extract_slice")
    get_element = _resolve_triton_ascend_op("get_element")
else:
    insert_slice = None
    extract_slice = None
    get_element = None


def init_device_properties_triton():
    global _NUM_AICORE, _NUM_VECTORCORE
    # Query the active Triton driver once and cache the core counts.
    if _NUM_AICORE == -1 and HAS_TRITON:
        device_properties: dict[str, Any] = triton.runtime.driver.active.utils.get_device_properties(
            torch.npu.current_device()
        )
        _NUM_AICORE = device_properties.get("num_aicore", -1)
        _NUM_VECTORCORE = device_properties.get("num_vectorcore", -1)
        assert _NUM_AICORE > 0 and _NUM_VECTORCORE > 0, "Failed to detect device properties."


def get_aicore_num():
    global _NUM_AICORE
    assert _NUM_AICORE > 0, "Device properties not initialized. Please call init_device_properties_triton() first."
    return _NUM_AICORE


def get_vectorcore_num():
    global _NUM_VECTORCORE
    assert _NUM_VECTORCORE > 0, "Device properties not initialized. Please call init_device_properties_triton() first."
    return _NUM_VECTORCORE
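
The device-property helpers at the end of the file follow a lazy, initialize-once pattern. They cache two integers in module globals, use `-1` as the "not initialized" sentinel, and query the driver only on the first call. The sketch below reproduces that pattern without an NPU. `init_device_properties`, `fake_query`, and the returned core counts are hypothetical; `fake_query` stands in for `triton.runtime.driver.active.utils.get_device_properties(...)`.

```python
from typing import Any, Callable, Dict

# Module-level cache; -1 means "not initialized yet", as in triton_utils.py.
_NUM_AICORE = -1
_NUM_VECTORCORE = -1


def init_device_properties(query: Callable[[], Dict[str, Any]]) -> None:
    """Populate the cache once; `query` is a hypothetical driver stand-in."""
    global _NUM_AICORE, _NUM_VECTORCORE
    if _NUM_AICORE == -1:
        props = query()
        _NUM_AICORE = props.get("num_aicore", -1)
        _NUM_VECTORCORE = props.get("num_vectorcore", -1)
        assert _NUM_AICORE > 0 and _NUM_VECTORCORE > 0, "Failed to detect device properties."


def get_aicore_num() -> int:
    # Reading a global needs no `global` statement.
    assert _NUM_AICORE > 0, "Call init_device_properties() first."
    return _NUM_AICORE


def get_vectorcore_num() -> int:
    assert _NUM_VECTORCORE > 0, "Call init_device_properties() first."
    return _NUM_VECTORCORE


calls = []


def fake_query() -> Dict[str, Any]:
    # Hypothetical values for illustration only, not real hardware numbers.
    calls.append(1)
    return {"num_aicore": 4, "num_vectorcore": 8}


init_device_properties(fake_query)
init_device_properties(fake_query)  # second call hits the cache
print(get_aicore_num(), get_vectorcore_num(), len(calls))  # 4 8 1
```

The getters assert instead of initializing lazily themselves. Callers must therefore run the init function first, typically once the NPU device is set.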