From 77e009d9fc0aaf814cf9a7557d6d8f6f62384185 Mon Sep 17 00:00:00 2001
From: Ronald
Date: Thu, 5 Mar 2026 09:12:40 +0800
Subject: [PATCH] [Feature] Add batch invariance docs and patch some extra
 operators (#6910)

### What this PR does / why we need it?
This PR adds documentation for batch invariance and patches some extra
operators according to the validation results.
Please see https://github.com/vllm-project/vllm-ascend/issues/5487 to
track progress.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
  https://github.com/vllm-project/vllm/commit/15d76f74e2fdb12a95ea00f0ca283acf6219a2b7

---------

Signed-off-by: Ronald1995
---
 .../feature_guide/batch_invariance.md         | 136 ++++++++++++++++++
 docs/source/user_guide/feature_guide/index.md |   1 +
 .../batch_invariant/test_batch_invariant.py   |  97 ++++++++++---
 vllm_ascend/ascend_config.py                  |   9 +-
 vllm_ascend/batch_invariant.py                |  38 +++++
 vllm_ascend/sample/sampler.py                 |   5 +
 vllm_ascend/utils.py                          |   9 ++
 7 files changed, 276 insertions(+), 19 deletions(-)
 create mode 100644 docs/source/user_guide/feature_guide/batch_invariance.md

diff --git a/docs/source/user_guide/feature_guide/batch_invariance.md b/docs/source/user_guide/feature_guide/batch_invariance.md
new file mode 100644
index 00000000..9dd74295
--- /dev/null
+++ b/docs/source/user_guide/feature_guide/batch_invariance.md
@@ -0,0 +1,136 @@
+# Batch Invariance
+
+```{note}
+Batch invariance is currently in beta. Some features are still under active development.
+Track progress and planned improvements at https://github.com/vllm-project/vllm-ascend/issues/5487.
+```
+
+This document shows how to enable batch invariance in vLLM-Ascend. Batch invariance ensures that the output of a model is deterministic and independent of the batch size or the order of requests in a batch.
+
+## Motivation
+
+Batch invariance is crucial for several use cases:
+
+- **Framework debugging**: Deterministic outputs make it easier to debug issues in the inference framework, as the same input will always produce the same output regardless of batching.
+- **Model debugging**: Helps identify issues in model implementations by ensuring consistent behavior across different batch configurations.
+- **Reinforcement Learning (RL)**: RL training often requires deterministic rollouts for reproducibility and stable training.
+- **Large-scale inference systems**: Systems that use vLLM as a component benefit from deterministic behavior for testing, validation, and consistency guarantees.
+
+## Hardware Requirements
+
+Batch invariance currently requires 910B-series Ascend NPUs,
+because for now only the 910B supports batch invariance with HCCL communication.
+We will support other NPUs in the future.
+
+## Software Requirements
+
+Batch invariance requires a custom operator library for the 910B.
+This custom operator library will be released in a future version.
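+
+As a quick sanity check, the import probe below mirrors the detection logic this
+patch adds to `vllm_ascend/batch_invariant.py`; the package name
+`batch_invariant_ops` is taken from that code and may change once the library is
+released:
+
+```python
+# Probe for the AscendC batch-invariant kernel package.
+try:
+    import batch_invariant_ops  # noqa: F401
+    HAS_ASCENDC_BATCH_INVARIANT = True
+except ImportError:
+    # Without the library, vllm-ascend falls back to its Triton implementations.
+    HAS_ASCENDC_BATCH_INVARIANT = False
+
+print(f"AscendC batch-invariant ops available: {HAS_ASCENDC_BATCH_INVARIANT}")
+```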
+
+## Enabling Batch Invariance
+
+Batch invariance can be enabled by setting the `VLLM_BATCH_INVARIANT` environment variable to `1`:
+
+```bash
+export VLLM_BATCH_INVARIANT=1
+```
+
+### Online Inference (Server Mode)
+
+To start a vLLM server with batch invariance enabled:
+
+```bash
+VLLM_BATCH_INVARIANT=1 vllm serve Qwen/Qwen3-8B
+```
+
+Then use the OpenAI-compatible client:
+
+```python
+from openai import OpenAI
+
+client = OpenAI(
+    api_key="EMPTY",
+    base_url="http://localhost:8000/v1",
+)
+
+# These requests will produce deterministic outputs
+# regardless of batch size or order
+response = client.completions.create(
+    model="Qwen/Qwen3-8B",
+    prompt="The future of AI is",
+    max_tokens=100,
+    temperature=0.7,
+    seed=42,
+)
+
+print(response.choices[0].text)
+```
+
+### Offline Inference
+
+For offline batch inference with batch invariance:
+
+```python
+import os
+os.environ["VLLM_BATCH_INVARIANT"] = "1"
+
+from vllm import LLM, SamplingParams
+
+prompts = [
+    "The future of AI is",
+    "Machine learning enables",
+    "Deep learning models can",
+]
+
+sampling_params = SamplingParams(
+    temperature=0.7,
+    max_tokens=100,
+    seed=42,
+)
+
+llm = LLM(
+    model="Qwen/Qwen3-8B",
+    tensor_parallel_size=1,
+)
+
+# Outputs will be deterministic regardless of batch size
+outputs = llm.generate(prompts, sampling_params)
+
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"Prompt: {prompt!r}")
+    print(f"Generated: {generated_text!r}\n")
+```
+
+## Tested Models
+
+Batch invariance has been tested and verified on the following models:
+
+- **Qwen3 (Dense)**: `Qwen/Qwen3-1.7B`, `Qwen/Qwen3-8B`
+- **Qwen3 (MoE)**: `Qwen/Qwen3-30B-A3B`
+
+Other models may also work, but these have been explicitly validated. If you encounter issues with a specific model, please report them on the [GitHub issue tracker](https://github.com/vllm-project/vllm-ascend/issues/new/choose).
+
+## Implementation Details
+
+When batch invariance is enabled, vLLM:
+
+1. Uses deterministic kernel implementations for attention and other operations
+2. Ensures consistent numerical behavior across different batch sizes
+3. Disables certain optimizations that may introduce non-determinism
+
+```{note}
+Enabling batch invariance may impact performance compared to the default non-deterministic mode. This trade-off is intentional to guarantee reproducibility.
+```
+
+## Future Improvements
+
+The batch invariance feature is under active development. Planned improvements include:
+
+- Support for additional NPU series
+- Expanded model coverage
+- Performance optimizations
+- Additional testing and validation
+
+For the latest status and to contribute ideas, see the [tracking issue](https://github.com/vllm-project/vllm-ascend/issues/5487).
diff --git a/docs/source/user_guide/feature_guide/index.md b/docs/source/user_guide/feature_guide/index.md
index ef0d2555..ad7abdcc 100644
--- a/docs/source/user_guide/feature_guide/index.md
+++ b/docs/source/user_guide/feature_guide/index.md
@@ -25,4 +25,5 @@ context_parallel
 npugraph_ex
 weight_prefetch
 sequence_parallelism
+batch_invariance
 :::
diff --git a/tests/ut/batch_invariant/test_batch_invariant.py b/tests/ut/batch_invariant/test_batch_invariant.py
index 55fe02e9..6e1bfdda 100644
--- a/tests/ut/batch_invariant/test_batch_invariant.py
+++ b/tests/ut/batch_invariant/test_batch_invariant.py
@@ -14,13 +14,13 @@
 # This file is a part of the vllm-ascend project.
# # type: ignore -import importlib import os -import sys from unittest.mock import MagicMock, patch import pytest +import torch +# Now import the module under test import vllm_ascend.batch_invariant as batch_invariant @@ -43,21 +43,6 @@ class TestBatchInvariant: assert os.environ["HCCL_DETERMINISTIC"] == "strict" assert os.environ["LCCL_DETERMINISTIC"] == "1" - @pytest.mark.parametrize("custom_ops_available, expected_value", [(True, True), (False, False)]) - def test_has_ascendc_batch_invariant(self, custom_ops_available, expected_value): - """Test HAS_ASCENDC_BATCH_INVARIANT detection""" - # Control custom_ops availability - if custom_ops_available: - sys.modules["batch_invariant_ops"] = MagicMock() - else: - sys.modules.pop("batch_invariant_ops", None) - - # Reload module to re-evaluate the flag - importlib.reload(batch_invariant) - - # Verify result - assert batch_invariant.HAS_ASCENDC_BATCH_INVARIANT == expected_value - @patch("vllm_ascend.batch_invariant.HAS_TRITON", False) @patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", True) def test_enable_batch_invariant_mode_ascendc_path(self): @@ -105,17 +90,20 @@ class TestBatchInvariant: batch_invariant.mm_batch_invariant = MagicMock() batch_invariant.matmul_batch_invariant = MagicMock() batch_invariant.linear_batch_invariant = MagicMock() + batch_invariant.softmax_batch_invariant = MagicMock() # Call function batch_invariant.enable_batch_invariant_mode() # Verify operator registrations - assert mock_library.impl.call_count == 5 + assert mock_library.impl.call_count == 7 mock_library.impl.assert_any_call("aten::addmm", batch_invariant.addmm_batch_invariant, "NPU") mock_library.impl.assert_any_call("aten::bmm", batch_invariant.bmm_batch_invariant, "NPU") mock_library.impl.assert_any_call("aten::mm", batch_invariant.mm_batch_invariant, "NPU") mock_library.impl.assert_any_call("aten::matmul", batch_invariant.matmul_batch_invariant, "NPU") mock_library.impl.assert_any_call("aten::linear", batch_invariant.linear_batch_invariant, "NPU") + mock_library.impl.assert_any_call("aten::softmax", batch_invariant.softmax_batch_invariant, "NPU") + mock_library.impl.assert_any_call("aten::_softmax", batch_invariant.softmax_batch_invariant, "NPU") @patch("vllm_ascend.batch_invariant.HAS_TRITON", False) @patch("vllm_ascend.batch_invariant.HAS_ASCENDC_BATCH_INVARIANT", False) @@ -158,6 +146,79 @@ class TestBatchInvariant: batch_invariant.override_envs_for_invariance.assert_not_called() batch_invariant.enable_batch_invariant_mode.assert_not_called() + @patch("vllm_ascend.batch_invariant.torch_npu") + def test_add_rms_norm(self, mock_torch_npu): + """Test add_rms_norm function""" + # Mock dependencies + mock_torch = batch_invariant.torch + + # Create mock tensors + batch_size = 2 + hidden_size = 4 + x = MagicMock(spec=torch.Tensor) + residual = MagicMock(spec=torch.Tensor) + weight = MagicMock(spec=torch.Tensor) + eps = 1e-6 + + # Set up mock return value for addition + x_plus_residual = MagicMock(spec=torch.Tensor) + x.__add__.return_value = x_plus_residual + + # Set up expected outputs from npu_rms_norm + expected_output = MagicMock(spec=torch.Tensor) + expected_residual = MagicMock(spec=torch.Tensor) + mock_torch_npu.npu_rms_norm.return_value = (expected_output, expected_residual) + + # Call the function + result_x, result_placeholder, result_residual = batch_invariant.add_rms_norm(x, residual, weight, eps) + + # Verify the addition was called + x.__add__.assert_called_once_with(residual) + + # Verify the npu_rms_norm was called with the 
correct parameters
+        mock_torch_npu.npu_rms_norm.assert_called_once_with(x_plus_residual, weight, eps)
+
+        # Verify the results
+        assert result_x is expected_output
+        assert result_placeholder is None
+
+    @patch("vllm_ascend.batch_invariant.torch_npu")
+    def test_add_rms_norm_consistency(self, mock_torch_npu):
+        """Test that add_rms_norm produces the same output as torch_npu.npu_add_rms_norm"""
+        # Create mock tensors
+        batch_size = 2
+        hidden_size = 4
+        x = MagicMock(spec=torch.Tensor)
+        residual = MagicMock(spec=torch.Tensor)
+        weight = MagicMock(spec=torch.Tensor)
+        eps = 1e-6
+
+        # Set up mock values
+        x_plus_residual = MagicMock(spec=torch.Tensor)
+        x.__add__.return_value = x_plus_residual
+
+        # Define consistent mock results
+        expected_output = MagicMock(spec=torch.Tensor)
+        expected_residual = MagicMock(spec=torch.Tensor)
+
+        # Set up npu_rms_norm to return the same results as npu_add_rms_norm would
+        mock_torch_npu.npu_rms_norm.return_value = (expected_output, expected_residual)
+        mock_torch_npu.npu_add_rms_norm.return_value = (expected_output, None, expected_residual)
+
+        # Call add_rms_norm
+        add_rms_norm_result = batch_invariant.add_rms_norm(x, residual, weight, eps)
+
+        # Call npu_add_rms_norm directly
+        npu_add_rms_norm_result = mock_torch_npu.npu_add_rms_norm(x, residual, weight, eps)
+
+        # Verify both functions return the same results
+        assert add_rms_norm_result[0] == npu_add_rms_norm_result[0]
+
+        # Verify the function composition is correct
+        x.__add__.assert_called_once_with(residual)
+        mock_torch_npu.npu_rms_norm.assert_called_once_with(x_plus_residual, weight, eps)
+        mock_torch_npu.npu_add_rms_norm.assert_called_once_with(x, residual, weight, eps)
+
 
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index e524ad76..03a40763 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -124,7 +124,14 @@ class AscendConfig:
         # npu_fused_infer_attention_score performs better on all scenarios.
         self.pa_shape_list = additional_config.get("pa_shape_list", [])
 
-        self.enable_async_exponential = bool(additional_config.get("enable_async_exponential", False))
+        # When enable_async_exponential is True, AscendSampler behaves differently from vLLM's Sampler,
+        # which breaks batch_invariant mode,
+        # so we disable async exponential when batch_invariant mode is enabled.
+        from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+
+        self.enable_async_exponential = (
+            bool(additional_config.get("enable_async_exponential", False)) and not vllm_is_batch_invariant()
+        )
 
         self.enable_kv_nz = additional_config.get("enable_kv_nz", False)
         if self.enable_kv_nz:
diff --git a/vllm_ascend/batch_invariant.py b/vllm_ascend/batch_invariant.py
index 38b6ad7d..e7521a80 100644
--- a/vllm_ascend/batch_invariant.py
+++ b/vllm_ascend/batch_invariant.py
@@ -24,6 +24,9 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.triton_utils import HAS_TRITON
 
+# Save the original torch.sum to avoid a recursive call in reduce_sum.
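+# (enable_batch_invariant_mode() below assigns torch.sum = reduce_sum, whose CPU
+# fallback must call this saved reference rather than torch.sum itself.)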
+torch_sum = torch.sum
+
 logger = init_logger(__name__)
 
 if HAS_TRITON:
@@ -34,6 +37,7 @@ if HAS_TRITON:
         matmul_batch_invariant,
         mm_batch_invariant,
     )
+    from vllm_ascend.ops.triton.batch_invariant.softmax import softmax_batch_invariant
 
 
 try:
@@ -44,10 +48,38 @@ except ImportError:
     HAS_ASCENDC_BATCH_INVARIANT = False
 
 
+def add_rms_norm(
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float,
+):
+    """AclnnAddRmsNorm can't guarantee batch invariance,
+    so we split it into a separate add and rms_norm.
+    """
+    x_ = x + residual
+    residual_ = x_
+    x_, _ = torch_npu.npu_rms_norm(x_, weight, eps)
+    return x_, None, residual_
+
+
+def reduce_sum(x: torch.Tensor, dim: int | None = None, keepdim: bool = False) -> torch.Tensor:
+    """npu_reduce_sum_batch_invariant requires dim to be specified, but torch.sum
+    doesn't, so we default dim to -1 when dim is None and x.dim() == 1.
+    """
+    dim = -1 if dim is None and x.dim() == 1 else dim
+    if x.device.type == "npu" and dim is not None:
+        return torch.ops.batch_invariant_ops.npu_reduce_sum_batch_invariant(x, dim, keepdim)
+    # CPU tensors (and dim=None reductions over multi-dim tensors) can't use
+    # npu_reduce_sum_batch_invariant, so we fall back to the original torch.sum.
+    return torch_sum(x, dim, keepdim)
+
+
 def override_envs_for_invariance():
     # enabling NZ mode introduces NZ format input to the triton operator,
     # resulting in accuracy anomalies.
     os.environ["VLLM_ASCEND_ENABLE_NZ"] = "0"
+    # The fused matmul-allreduce operator can't guarantee batch invariance, so we disable it.
+    os.environ["VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE"] = "0"
 
     # communication determinism settings
     os.environ["HCCL_DETERMINISTIC"] = "strict"
@@ -65,6 +97,8 @@ def enable_batch_invariant_mode():
     if HAS_TRITON:
         _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, "NPU")
         _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::softmax", softmax_batch_invariant, "NPU")
+        _batch_invariant_LIB.impl("aten::_softmax", softmax_batch_invariant, "NPU")
 
     # Register operators implemented in Ascend batch-invariant ops in priority.
     if HAS_ASCENDC_BATCH_INVARIANT:
@@ -76,6 +110,10 @@ def enable_batch_invariant_mode():
         torch_npu.npu_fused_infer_attention_score = (
             torch.ops.batch_invariant_ops.npu_fused_infer_attention_score_batch_invariant
         )
+        # Patch npu_add_rms_norm to ensure batch invariance.
+        torch_npu.npu_add_rms_norm = add_rms_norm
+        # torch.sum can't be replaced by dispatch logic, so we patch it directly.
+        torch.sum = reduce_sum
 
     # register triton implementations if ascendc is not available.
     elif HAS_TRITON:
diff --git a/vllm_ascend/sample/sampler.py b/vllm_ascend/sample/sampler.py
index 40cc496d..5a73f951 100644
--- a/vllm_ascend/sample/sampler.py
+++ b/vllm_ascend/sample/sampler.py
@@ -1,4 +1,5 @@
 import torch
+from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
 from vllm.v1.sample.sampler import Sampler
 
@@ -73,6 +74,10 @@ class AscendTopKTopPSampler(TopKTopPSampler):
 
     def forward_native(self, logits, generators, k, p):
         """Override pytorch native implementation to torch_npu"""
+        # When batch_invariant mode is enabled, we should use vLLM's native implementation;
+        # otherwise batch invariance would be broken.
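+        # (The torch_npu-based top-k/top-p path below is not batch-invariant.)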
+        if vllm_is_batch_invariant():
+            return super().forward_native(logits, generators, k, p)
         logits = self.apply_top_k_top_p(logits, k, p)
         logits_to_return = None
         if self.logprobs_mode == "processed_logits":
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index a45834b6..424e4fc0 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -258,10 +258,19 @@ def enable_custom_op():
     Enable lazy init for vllm_ascend_C to avoid early initialization of CANN's RTS component.
     Ensure that ASCEND_RT_VISIBLE_DEVICES can be dynamically modified before torch.npu.set_device().
     """
+    from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+
     global _CUSTOM_OP_ENABLED
     if _CUSTOM_OP_ENABLED is not None:
         return _CUSTOM_OP_ENABLED
+
+    # Some custom operators in vllm-ascend are not implemented with batch
+    # invariance, so we need to disable them.
+    if vllm_is_batch_invariant():
+        _CUSTOM_OP_ENABLED = False
+        return _CUSTOM_OP_ENABLED
+
     try:
         # isort: off
         # register custom ops into torch_library here