implement batch invariant with ascendc (#6590)

### What this PR does / why we need it?
Batch-invariant ops have both Triton and AscendC implementations. This PR selects which implementation to use when enabling batch-invariant mode: the AscendC ops take priority when the `batch_invariant_ops` extension is available, with the Triton ops as the fallback.

See #5487.

### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: d7e17aaacd

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>

```diff
@@ -19,6 +19,7 @@
 import os
 
 import torch
+import torch_npu
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.triton_utils import HAS_TRITON
```
```diff
@@ -35,15 +36,21 @@ if HAS_TRITON:
     )
 
-def override_envs_for_invariance():
-    # TODO(Ronald) set attntion backend to deterministic mode
+try:
+    import batch_invariant_ops  # type: ignore[import-not-found] # noqa
+    HAS_ASCENDC_BATCH_INVARIANT = True
+except ImportError:
+    HAS_ASCENDC_BATCH_INVARIANT = False
+
+
+def override_envs_for_invariance():
+    # Enabling NZ mode feeds NZ-format inputs to the Triton operators,
+    # which causes accuracy anomalies.
     os.environ["VLLM_ASCEND_ENABLE_NZ"] = "0"
-    os.environ["HCCL_DETERMINISTIC"] = "true"
+    # Communication determinism settings.
+    os.environ["HCCL_DETERMINISTIC"] = "strict"
     os.environ["LCCL_DETERMINISTIC"] = "1"
```
```diff
@@ -52,14 +59,32 @@ _batch_invariant_LIB = None
 
 def enable_batch_invariant_mode():
     global _batch_invariant_LIB
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
-    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
-    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
-    _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")
-    _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")
-    _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")
+    # Register operators that only have Triton implementations.
+    if HAS_TRITON:
+        _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")
+    # Prefer the AscendC batch-invariant implementations when available.
+    if HAS_ASCENDC_BATCH_INVARIANT:
+        _batch_invariant_LIB.impl("aten::mm", torch.ops.batch_invariant_ops.npu_mm_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::matmul", torch.ops.batch_invariant_ops.npu_matmul_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::sum", torch.ops.batch_invariant_ops.npu_reduce_sum_batch_invariant, "NPU")
+        # torch_npu.npu_fused_infer_attention_score is a plain torch_npu
+        # function, not a torch.ops Operator, so patch it directly.
+        torch_npu.npu_fused_infer_attention_score = (
+            torch.ops.batch_invariant_ops.npu_fused_infer_attention_score_batch_invariant
+        )
+    # Fall back to the Triton implementations when AscendC ops are unavailable.
+    elif HAS_TRITON:
+        _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")
+        # linear calls matmul internally, so only register it when AscendC is
+        # unavailable; with AscendC, the patched matmul already gives better
+        # performance.
+        _batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")
 
 
 def init_batch_invariance():
```
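
For context on the mechanism: `torch.library.Library("aten", "IMPL")` opens the `aten` namespace for kernel registration, and `impl(op, fn, key)` reroutes an existing operator to `fn` for one dispatch key, so every call site that goes through the dispatcher picks up the override. A minimal sketch of the same pattern on the CPU key (the override body is illustrative, not one of the PR's kernels):

```python
import torch

# Open the aten namespace for kernel (IMPL) registration, exactly as
# enable_batch_invariant_mode() does before registering NPU kernels.
_lib = torch.library.Library("aten", "IMPL")


def mm_override(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Compute mm via broadcast-multiply-reduce instead of calling
    # torch.mm/torch.matmul, which would re-enter this override.
    return (a.unsqueeze(-1) * b.unsqueeze(0)).sum(dim=1)


# Reroute aten::mm on the CPU dispatch key; the PR does the same with
# "NPU" and the AscendC/Triton batch-invariant kernels.
_lib.impl("aten::mm", mm_override, "CPU")

x, y = torch.randn(4, 8), torch.randn(8, 3)
ref = (x.unsqueeze(-1) * y.unsqueeze(0)).sum(dim=1)
assert torch.allclose(torch.mm(x, y), ref)  # dispatches to mm_override
```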
```diff
@@ -75,7 +100,7 @@ def init_batch_invariance():
     environment variable to enable automatically.
     """
     if vllm_is_batch_invariant():
-        if HAS_TRITON:
+        if HAS_TRITON or HAS_ASCENDC_BATCH_INVARIANT:
             logger.info(
                 "Enabling batch-invariant mode for vLLM on Ascend NPU.",
             )
```
```diff
@@ -83,5 +108,6 @@ def init_batch_invariance():
             enable_batch_invariant_mode()
         else:
             logger.warning(
-                "Batch-invariant mode requested but Triton is not available.skipping batch-invariant initialization.",
+                "Batch-invariant mode requested but neither Triton nor AscendC "
+                "batch-invariant ops are available. Skipping batch-invariant initialization."
             )
```
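
End-to-end, nothing here is user-facing: the mode still hinges on the upstream toggle that `vllm_is_batch_invariant()` reads. A hedged usage sketch, assuming the `VLLM_BATCH_INVARIANT=1` environment variable from upstream vLLM (confirm against your version) and any NPU-served model:

```python
import os

# Upstream vLLM toggle read by vllm_is_batch_invariant(); with this PR,
# init_batch_invariance() then picks AscendC ops first, Triton second.
os.environ["VLLM_BATCH_INVARIANT"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct")  # any NPU-served model
greedy = SamplingParams(temperature=0.0, max_tokens=64)

# Batch invariance means a prompt's output is bitwise-identical whether
# it runs alone or alongside other requests in the same batch.
solo = llm.generate(["Explain batch invariance."], greedy)
batched = llm.generate(["Explain batch invariance.", "Unrelated filler."], greedy)
assert solo[0].outputs[0].text == batched[0].outputs[0].text
```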