implement batch invariant with ascendc (#6590)

### What this PR does / why we need it?
there are batch invariant ops implemented by triton and ascendc, this pr
aims to choose which kind of ops to be used to enable batch invariant.
#5487

### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd

---------

Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
Ronald
2026-02-10 14:15:26 +08:00
committed by GitHub
parent 66b60c9440
commit 77305df398
2 changed files with 200 additions and 11 deletions

View File

@@ -0,0 +1,163 @@
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
# type: ignore
import importlib
import os
import sys
from unittest.mock import MagicMock, patch
import pytest
import vllm_ascend.batch_invariant as batch_invariant
class TestBatchInvariant:
"""Complete test suite for batch_invariant.py"""
def test_override_envs_for_invariance(self):
"""Test environment variable override"""
# Clear environment variables
env_vars = ["VLLM_ASCEND_ENABLE_NZ", "HCCL_DETERMINISTIC", "LCCL_DETERMINISTIC"]
for var in env_vars:
if var in os.environ:
del os.environ[var]
# Call function
batch_invariant.override_envs_for_invariance()
# Verify environment variables
assert os.environ["VLLM_ASCEND_ENABLE_NZ"] == "0"
assert os.environ["HCCL_DETERMINISTIC"] == "strict"
assert os.environ["LCCL_DETERMINISTIC"] == "1"
@pytest.mark.parametrize("custom_ops_available, expected_value", [(True, True), (False, False)])
def test_has_ascendc_batch_invariant(self, custom_ops_available, expected_value):
"""Test HAS_ASCENDC_BATCH_INVARIANT detection"""
# Control custom_ops availability
if custom_ops_available:
sys.modules["batch_invariant_ops"] = MagicMock()
else:
sys.modules.pop("batch_invariant_ops", None)
# Reload module to re-evaluate the flag
importlib.reload(batch_invariant)
# Verify result
assert batch_invariant.HAS_ASCENDC_BATCH_INVARIANT == expected_value
@patch("vllm_ascend.batch_invariant.HAS_TRITON", False)
@patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", True)
def test_enable_batch_invariant_mode_ascendc_path(self):
"""Test enable_batch_invariant_mode with AscendC ops available"""
# Mock dependencies
mock_library = MagicMock()
batch_invariant.torch.library.Library = MagicMock(return_value=mock_library)
batch_invariant.torch.ops.batch_invariant_ops = MagicMock()
# Call function
batch_invariant.enable_batch_invariant_mode()
# Verify library created
batch_invariant.torch.library.Library.assert_called_once_with("aten", "IMPL")
# Verify operator registrations
assert mock_library.impl.call_count == 3
mock_library.impl.assert_any_call(
"aten::mm", batch_invariant.torch.ops.batch_invariant_ops.npu_mm_batch_invariant, "NPU"
)
mock_library.impl.assert_any_call(
"aten::matmul", batch_invariant.torch.ops.batch_invariant_ops.npu_matmul_batch_invariant, "NPU"
)
mock_library.impl.assert_any_call(
"aten::sum", batch_invariant.torch.ops.batch_invariant_ops.npu_reduce_sum_batch_invariant, "NPU"
)
# Verify torch_npu function patching
assert (
batch_invariant.torch_npu.npu_fused_infer_attention_score
== batch_invariant.torch.ops.batch_invariant_ops.npu_fused_infer_attention_score_batch_invariant
)
@patch("vllm_ascend.batch_invariant.HAS_TRITON", True)
@patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", False)
def test_enable_batch_invariant_mode_triton_path(self):
"""Test enable_batch_invariant_mode with only Triton available"""
# Mock dependencies
mock_library = MagicMock()
batch_invariant.torch.library.Library = MagicMock(return_value=mock_library)
# Mock triton imports
batch_invariant.addmm_batch_invariant = MagicMock()
batch_invariant.bmm_batch_invariant = MagicMock()
batch_invariant.mm_batch_invariant = MagicMock()
batch_invariant.matmul_batch_invariant = MagicMock()
batch_invariant.linear_batch_invariant = MagicMock()
# Call function
batch_invariant.enable_batch_invariant_mode()
# Verify operator registrations
assert mock_library.impl.call_count == 5
mock_library.impl.assert_any_call("aten::addmm", batch_invariant.addmm_batch_invariant, "NPU")
mock_library.impl.assert_any_call("aten::bmm", batch_invariant.bmm_batch_invariant, "NPU")
mock_library.impl.assert_any_call("aten::mm", batch_invariant.mm_batch_invariant, "NPU")
mock_library.impl.assert_any_call("aten::matmul", batch_invariant.matmul_batch_invariant, "NPU")
mock_library.impl.assert_any_call("aten::linear", batch_invariant.linear_batch_invariant, "NPU")
@patch("vllm_ascend.batch_invariant.HAS_TRITON", False)
@patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", False)
def test_enable_batch_invariant_mode_no_backend(self):
"""Test enable_batch_invariant_mode with no backends available"""
# Mock library
mock_library = MagicMock()
batch_invariant.torch.library.Library = MagicMock(return_value=mock_library)
# Call function
batch_invariant.enable_batch_invariant_mode()
# Verify no operators registered
mock_library.impl.assert_not_called()
@pytest.mark.parametrize(
"batch_invariant_enabled, has_backend, expected_logger_call",
[(True, True, "info"), (True, False, "warning"), (False, True, None), (False, False, None)],
)
def test_init_batch_invariance(self, batch_invariant_enabled, has_backend, expected_logger_call):
"""Test init_batch_invariance under different conditions"""
# Mock dependencies
batch_invariant.vllm_is_batch_invariant = MagicMock(return_value=batch_invariant_enabled)
batch_invariant.HAS_TRITON = has_backend
batch_invariant.HAS_ASCENDC_BATCH_INVARIANT = has_backend
batch_invariant.override_envs_for_invariance = MagicMock()
batch_invariant.enable_batch_invariant_mode = MagicMock()
# Call function
batch_invariant.init_batch_invariance()
# Verify function calls based on conditions
if batch_invariant_enabled and has_backend:
batch_invariant.override_envs_for_invariance.assert_called_once()
batch_invariant.enable_batch_invariant_mode.assert_called_once()
elif batch_invariant_enabled and not has_backend:
batch_invariant.override_envs_for_invariance.assert_not_called()
batch_invariant.enable_batch_invariant_mode.assert_not_called()
else:
batch_invariant.override_envs_for_invariance.assert_not_called()
batch_invariant.enable_batch_invariant_mode.assert_not_called()
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -19,6 +19,7 @@
import os
import torch
import torch_npu
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.triton_utils import HAS_TRITON
@@ -35,15 +36,21 @@ if HAS_TRITON:
)
def override_envs_for_invariance():
# TODO(Ronald) set attntion backend to deterministic mode
try:
import batch_invariant_ops # type: ignore[import-not-found] # noqa
HAS_ASCENDC_BATCH_INVARIANT = True
except ImportError:
HAS_ASCENDC_BATCH_INVARIANT = False
def override_envs_for_invariance():
# enabling NZ mode introduces NZ format input to the triton operator,
# resulting in accuracy anomalies.
os.environ["VLLM_ASCEND_ENABLE_NZ"] = "0"
# communication determinism settings
os.environ["HCCL_DETERMINISTIC"] = "true"
os.environ["HCCL_DETERMINISTIC"] = "strict"
os.environ["LCCL_DETERMINISTIC"] = "1"
@@ -52,14 +59,32 @@ _batch_invariant_LIB = None
def enable_batch_invariant_mode():
global _batch_invariant_LIB
_batch_invariant_LIB = torch.library.Library("aten", "IMPL")
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")
# Register operators only implemented in triton.
if HAS_TRITON:
_batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")
# Register operators implemented in Ascend batch-invariant ops in priority.
if HAS_ASCENDC_BATCH_INVARIANT:
_batch_invariant_LIB.impl("aten::mm", torch.ops.batch_invariant_ops.npu_mm_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::matmul", torch.ops.batch_invariant_ops.npu_matmul_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::sum", torch.ops.batch_invariant_ops.npu_reduce_sum_batch_invariant, "NPU")
# torch_npu.npu_fused_infer_attention_score is a function of torch_npu, not a torch.ops.Operator,
# so we need to patch it directly.
torch_npu.npu_fused_infer_attention_score = (
torch.ops.batch_invariant_ops.npu_fused_infer_attention_score_batch_invariant
)
# register triton implementations if ascendc is not available.
elif HAS_TRITON:
_batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, "NPU")
_batch_invariant_LIB.impl("aten::matmul", matmul_batch_invariant, "NPU")
# linear call matmul internally, so register linear only when ascendc
# is not available. it will get better performance with ascendc.
_batch_invariant_LIB.impl("aten::linear", linear_batch_invariant, "NPU")
def init_batch_invariance():
@@ -75,7 +100,7 @@ def init_batch_invariance():
environment variable to enable automatically.
"""
if vllm_is_batch_invariant():
if HAS_TRITON:
if HAS_TRITON or HAS_ASCENDC_BATCH_INVARIANT:
logger.info(
"Enabling batch-invariant mode for vLLM on Ascend NPU.",
)
@@ -83,5 +108,6 @@ def init_batch_invariance():
enable_batch_invariant_mode()
else:
logger.warning(
"Batch-invariant mode requested but Triton is not available.skipping batch-invariant initialization.",
"Batch-invariant mode requested but Triton or AscendC batch-invariant "
"ops is not available.skipping batch-invariant initialization."
)