From 2064afe38054c5441b3c5778314fc810bd1eafa6 Mon Sep 17 00:00:00 2001 From: Shaoxu Cheng <2906339855@qq.com> Date: Tue, 3 Mar 2026 15:57:26 +0800 Subject: [PATCH] [300I][Bugfix] fix unquant model weight nd2nz error (#6851) ### What this PR does / why we need it? - This PR fixes an issue with weight format conversion for unquantized models running on Ascend 310P devices. - The changes refactor the logic for converting weights to the FRACTAL_NZ format. Previously, this was handled in a 310P-specific linear layer implementation (`AscendUnquantizedLinearMethod310`). This implementation has been removed, and the logic is now centralized in the `maybe_trans_nz` utility function. This function now checks if the device is a 310P and applies the NZ format cast accordingly for `float16`/`bfloat16` weights. - This refactoring simplifies the code by removing platform-specific duplication and ensures correct weight handling for unquantized models on 310P. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? ut and local test - vLLM version: v0.15.0 - vLLM main: https://github.com/vllm-project/vllm/commit/83b47f67b1dfad505606070ae4d9f83e50ad4ebd --------- Signed-off-by: Tflowers-0129 <2906339855@qq.com> --- .../quantization/test_modelslim_config_310.py | 4 +- tests/ut/test_utils.py | 88 +++++++++++++++++++ vllm_ascend/_310p/ops/linear.py | 65 -------------- .../_310p/ops/vocab_parallel_embedding.py | 82 +++++++++++++++++ .../_310p/quantization/methods/w8a8_static.py | 4 +- .../_310p/quantization/methods/w8a8s.py | 4 +- .../_310p/quantization/modelslim_config.py | 9 +- vllm_ascend/utils.py | 47 +++++++--- 8 files changed, 214 insertions(+), 89 deletions(-) delete mode 100644 vllm_ascend/_310p/ops/linear.py create mode 100644 vllm_ascend/_310p/ops/vocab_parallel_embedding.py diff --git a/tests/ut/_310p/quantization/test_modelslim_config_310.py b/tests/ut/_310p/quantization/test_modelslim_config_310.py index 6f38b145..7e614e21 100644 --- a/tests/ut/_310p/quantization/test_modelslim_config_310.py +++ b/tests/ut/_310p/quantization/test_modelslim_config_310.py @@ -21,7 +21,7 @@ from vllm.model_executor.layers.linear import LinearBase from tests.ut.base import TestBase from vllm_ascend._310p.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod310 -from vllm_ascend._310p.ops.linear import AscendUnquantizedLinearMethod310 +from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod from vllm_ascend._310p.quantization.modelslim_config import AscendModelSlimConfig310 @@ -50,7 +50,7 @@ class TestAscendModelSlimConfig310(TestBase): patch.object(self.ascend_config, "is_layer_skipped_ascend", return_value=True), ): method = self.ascend_config.get_quant_method(linear_layer, ".attn") - self.assertIsInstance(method, AscendUnquantizedLinearMethod310) + self.assertIsInstance(method, AscendUnquantizedLinearMethod) # Test quantized layer mock_scheme = MagicMock() diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index f90502c3..6f4c2500 100644 --- a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -249,3 +249,91 @@ class TestUtils(TestBase): utils.register_ascend_customop() self.assertEqual(mock_customop.register_oot.call_count, len(REGISTERED_ASCEND_OPS)) + + @mock.patch("torch_npu.npu_format_cast") + def test_maybe_trans_nz(self, mock_npu_format_cast): + from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ + mock_npu_format_cast.side_effect = lambda weight, fmt: weight + + def assert_nz_cast(weight): + mock_npu_format_cast.assert_called_once() + args, kwargs = 
mock_npu_format_cast.call_args + self.assertIs(args[0], weight) + self.assertEqual(args[1], ACL_FORMAT_FRACTAL_NZ) + self.assertEqual(kwargs, {}) + + # Test case 1: non-310P, NZ is disabled + with ( + mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"}), + mock.patch("vllm_ascend.utils.is_310p", return_value=False), + ): + weight = torch.randn(32, 64, dtype=torch.float16) + result = utils.maybe_trans_nz(weight) + self.assertIs(result, weight) + mock_npu_format_cast.assert_not_called() + + # Test case 2: 310P always converts non-fp32 weights, even when NZ=0 + mock_npu_format_cast.reset_mock() + with ( + mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"}), + mock.patch("vllm_ascend.utils.is_310p", return_value=True), + ): + weight = torch.randn(32, 64, dtype=torch.float16) + result = utils.maybe_trans_nz(weight) + self.assertIs(result, weight) + assert_nz_cast(weight) + + # Test case 3: fp32 never converts, including on 310P + mock_npu_format_cast.reset_mock() + with ( + mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"}), + mock.patch("vllm_ascend.utils.is_310p", return_value=True), + ): + weight = torch.randn(32, 64, dtype=torch.float32) + result = utils.maybe_trans_nz(weight) + self.assertIs(result, weight) + mock_npu_format_cast.assert_not_called() + + # Test case 4: non-310P fp16 converts only when NZ=2 + mock_npu_format_cast.reset_mock() + with ( + mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"}), + mock.patch("vllm_ascend.utils.is_310p", return_value=False), + ): + weight = torch.randn(32, 64, dtype=torch.float16) + result = utils.maybe_trans_nz(weight) + self.assertIs(result, weight) + mock_npu_format_cast.assert_not_called() + + # Test case 5: non-310P fp16 converts when NZ=2 + mock_npu_format_cast.reset_mock() + with ( + mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"}), + mock.patch("vllm_ascend.utils.is_310p", return_value=False), + ): + weight = torch.randn(32, 64, dtype=torch.float16) + result = utils.maybe_trans_nz(weight) + self.assertIs(result, weight) + assert_nz_cast(weight) + + # Test case 6: non-310P bf16 converts when NZ=2 + mock_npu_format_cast.reset_mock() + with ( + mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"}), + mock.patch("vllm_ascend.utils.is_310p", return_value=False), + ): + weight = torch.randn(32, 64, dtype=torch.bfloat16) + result = utils.maybe_trans_nz(weight) + self.assertIs(result, weight) + assert_nz_cast(weight) + + # Test case 7: non-310P quantized weights still convert by default + mock_npu_format_cast.reset_mock() + with ( + mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"}), + mock.patch("vllm_ascend.utils.is_310p", return_value=False), + ): + weight = torch.zeros(32, 64, dtype=torch.int8) + result = utils.maybe_trans_nz(weight) + self.assertIs(result, weight) + assert_nz_cast(weight) diff --git a/vllm_ascend/_310p/ops/linear.py b/vllm_ascend/_310p/ops/linear.py deleted file mode 100644 index e3043cec..00000000 --- a/vllm_ascend/_310p/ops/linear.py +++ /dev/null @@ -1,65 +0,0 @@ -# -# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import torch -import torch.nn as nn -import torch_npu -from vllm.model_executor.layers.linear import ( - LinearBase, - QuantizeMethodBase, - UnquantizedLinearMethod, -) -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig - -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ - - -class AscendUnquantizedLinearMethod310(UnquantizedLinearMethod): - def process_weights_after_loading(self, layer: nn.Module) -> None: - super().process_weights_after_loading(layer) - if "conv1d" not in getattr(layer, "prefix", ""): - layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ) - - -class AscendLinearBase310(LinearBase): - def __init__( - self, - input_size: int, - output_size: int, - skip_bias_add: bool = False, - params_dtype: object | None = None, - quant_config: QuantizationConfig | None = None, - prefix: str = "", - *, - return_bias: bool = True, - disable_tp: bool = False, - ): - nn.Module.__init__(self) - - self.input_size = int(input_size) - self.output_size = int(output_size) - self.skip_bias_add = skip_bias_add - self.params_dtype = torch.float16 - self.quant_config = quant_config - self.prefix = prefix - self.return_bias = return_bias - self.disable_tp = disable_tp - - if quant_config is None: - self.quant_method: QuantizeMethodBase | None = AscendUnquantizedLinearMethod310() - else: - self.quant_method = quant_config.get_quant_method(self, prefix=prefix) diff --git a/vllm_ascend/_310p/ops/vocab_parallel_embedding.py b/vllm_ascend/_310p/ops/vocab_parallel_embedding.py new file mode 100644 index 00000000..97438a57 --- /dev/null +++ b/vllm_ascend/_310p/ops/vocab_parallel_embedding.py @@ -0,0 +1,82 @@ +# +# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import torch +import torch.nn.functional as F +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, + UnquantizedEmbeddingMethod, +) + +from vllm_ascend.ops.vocab_parallel_embedding import AscendParallelLMHead, AscendVocabParallelEmbedding +from vllm_ascend.utils import maybe_trans_nz + + +class AscendUnquantizedEmbeddingMethod310(UnquantizedEmbeddingMethod): + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight_nz = maybe_trans_nz(layer.weight) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return F.linear(x, layer.weight_nz, bias) + + +class AscendVocabParallelEmbedding310(AscendVocabParallelEmbedding): + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + params_dtype: torch.dtype | None = None, + org_num_embeddings: int | None = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__( + num_embeddings, embedding_dim, params_dtype, org_num_embeddings, padding_size, quant_config, prefix + ) + if quant_config is None: + self.quant_method = AscendUnquantizedEmbeddingMethod310() + + +class AscendParallelLMHead310(AscendParallelLMHead): + """ + Register ParallelLMHead as a custom op for Atlas 310p. + """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype: torch.dtype | None = None, + org_num_embeddings: int | None = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__( + num_embeddings, embedding_dim, bias, params_dtype, org_num_embeddings, padding_size, quant_config, prefix + ) + + if quant_config is None: + self.quant_method = AscendUnquantizedEmbeddingMethod310() diff --git a/vllm_ascend/_310p/quantization/methods/w8a8_static.py b/vllm_ascend/_310p/quantization/methods/w8a8_static.py index 403e5df7..6cd65cdc 100644 --- a/vllm_ascend/_310p/quantization/methods/w8a8_static.py +++ b/vllm_ascend/_310p/quantization/methods/w8a8_static.py @@ -21,7 +21,7 @@ import torch import torch_npu from vllm_ascend.quantization.methods.base import AscendLinearScheme -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ +from vllm_ascend.utils import maybe_trans_nz from .registry import register_scheme @@ -105,7 +105,7 @@ class AscendW8A8LinearMethod310(AscendLinearScheme): ).to(layer.aclnn_input_scale.dtype) # ---- matmul stage tensor ---- - layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ).transpose(0, 1) + layer.weight.data = maybe_trans_nz(layer.weight.data).transpose(0, 1) # ---- dequant stage tensors ---- layer.weight_scale.data = torch.flatten(layer.weight_scale.data) diff --git a/vllm_ascend/_310p/quantization/methods/w8a8s.py b/vllm_ascend/_310p/quantization/methods/w8a8s.py index 80e09472..b102f30f 100644 --- a/vllm_ascend/_310p/quantization/methods/w8a8s.py +++ b/vllm_ascend/_310p/quantization/methods/w8a8s.py @@ -21,7 +21,7 @@ import torch import torch_npu from vllm_ascend.quantization.methods.base import AscendLinearScheme -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ +from vllm_ascend.utils import maybe_trans_nz from .registry import register_scheme @@ -84,4 +84,4 @@ class 
AscendW8A8SLinearMethod310(AscendLinearScheme): layer.aclnn_input_scale = layer.input_scale.data.repeat(expanding_factor) layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale.data layer.aclnn_input_offset = layer.input_offset.data.repeat(expanding_factor).to(layer.aclnn_input_scale.dtype) - layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.weight.data = maybe_trans_nz(layer.weight.data) diff --git a/vllm_ascend/_310p/quantization/modelslim_config.py b/vllm_ascend/_310p/quantization/modelslim_config.py index 91c98dea..5e9b0abf 100644 --- a/vllm_ascend/_310p/quantization/modelslim_config.py +++ b/vllm_ascend/_310p/quantization/modelslim_config.py @@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization import register_quantization_config from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase from vllm.model_executor.layers.vocab_parallel_embedding import ( - UnquantizedEmbeddingMethod, VocabParallelEmbedding, ) @@ -104,9 +103,9 @@ class AscendModelSlimConfig310(AscendModelSlimConfig): if isinstance(layer, LinearBase): packed = getattr(self, "packed_modules_mapping", {}) if self.is_layer_skipped_ascend(prefix, packed): - from vllm_ascend._310p.ops.linear import AscendUnquantizedLinearMethod310 + from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod - return AscendUnquantizedLinearMethod310() + return AscendUnquantizedLinearMethod() scheme = create_scheme_for_layer( quant_description=self.quant_description, @@ -125,6 +124,8 @@ class AscendModelSlimConfig310(AscendModelSlimConfig): return AscendFusedMoEMethod(scheme, layer.moe_config) elif isinstance(layer, VocabParallelEmbedding): - return UnquantizedEmbeddingMethod() + from vllm_ascend._310p.ops.vocab_parallel_embedding import AscendUnquantizedEmbeddingMethod310 + + return AscendUnquantizedEmbeddingMethod310() return super().get_quant_method(layer, prefix) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index f6139559..de0e5d12 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -134,22 +134,35 @@ def _unregister_print_streams_on_exit(): atexit.register(_unregister_print_streams_on_exit) -def maybe_trans_nz(weight: torch.Tensor): +def _should_trans_nz(weight: torch.Tensor) -> bool: + # FP32 cannot use NZ. + if weight.dtype == torch.float32: + return False + + # 310P always converts to NZ. + if is_310p(): + return True + + # NZ is disabled on non-310P. if not envs_ascend.VLLM_ASCEND_ENABLE_NZ: - # NZ is not enabled + return False + + # BF16/FP16 convert only when enable_nz == 2. + if weight.dtype in {torch.bfloat16, torch.float16}: + return envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2 + + # Quantized or other supported dtypes convert by default. 
+ return True + + +# NZ conversion policy: +# - 310P: always convert supported weights to FRACTAL_NZ +# - non-310P: follow VLLM_ASCEND_ENABLE_NZ +# - FP32: never convert +def maybe_trans_nz(weight: torch.Tensor) -> torch.Tensor: + if not _should_trans_nz(weight): return weight - if weight.dtype == torch.float: - # fp32 can not support NZ - return weight - elif weight.dtype in {torch.bfloat16, torch.float16}: - # bf16/fp16 will trans nz when VLLM_ASCEND_ENABLE_NZ is 2 - if envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2: - return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ) - else: - return weight - else: - # quant weight will trans nz by default - return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ) + return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ) def _round_up(x: int, align: int): @@ -631,6 +644,10 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None): from vllm_ascend._310p.ops.activation import AscendSiluAndMul310 from vllm_ascend._310p.ops.layernorm import AscendGemmaRMSNorm310, AscendRMSNorm310 from vllm_ascend._310p.ops.rotary_embedding import AscendRotaryEmbedding310 + from vllm_ascend._310p.ops.vocab_parallel_embedding import ( + AscendParallelLMHead310, + AscendVocabParallelEmbedding310, + ) REGISTERED_ASCEND_OPS.update( { @@ -640,6 +657,8 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None): "GemmaRMSNorm": AscendGemmaRMSNorm310, "FusedMoE": AscendFusedMoE310, "SharedFusedMoE": AscendSharedFusedMoE310, + "ParallelLMHead": AscendParallelLMHead310, + "VocabParallelEmbedding": AscendVocabParallelEmbedding310, } )
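
For reviewers, here is a minimal sketch of the NZ-conversion policy this patch centralizes in `maybe_trans_nz`. The helper name `should_trans_nz` and its explicit `on_310p` / `enable_nz` parameters are hypothetical, introduced only for illustration; the actual code in `vllm_ascend/utils.py` calls `is_310p()` and reads `envs_ascend.VLLM_ASCEND_ENABLE_NZ` instead.

```python
# Illustrative sketch only: mirrors the decision order of _should_trans_nz
# added in vllm_ascend/utils.py. Function name and parameters are hypothetical;
# the real code calls is_310p() and reads envs_ascend.VLLM_ASCEND_ENABLE_NZ.
import torch


def should_trans_nz(dtype: torch.dtype, on_310p: bool, enable_nz: int) -> bool:
    if dtype == torch.float32:
        return False           # FP32 cannot use FRACTAL_NZ
    if on_310p:
        return True            # 310P always converts supported weights
    if not enable_nz:
        return False           # NZ disabled on non-310P devices
    if dtype in (torch.bfloat16, torch.float16):
        return enable_nz == 2  # BF16/FP16 convert only at level 2
    return True                # quantized weights convert by default


# These cases correspond to the unit tests added in tests/ut/test_utils.py.
assert should_trans_nz(torch.float16, on_310p=True, enable_nz=0)       # case 2
assert not should_trans_nz(torch.float32, on_310p=True, enable_nz=1)   # case 3
assert not should_trans_nz(torch.float16, on_310p=False, enable_nz=1)  # case 4
assert should_trans_nz(torch.float16, on_310p=False, enable_nz=2)      # case 5
assert should_trans_nz(torch.int8, on_310p=False, enable_nz=1)         # case 7
```

When the check passes, `maybe_trans_nz` performs a single `torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)`, so callers such as the 310P W8A8 schemes and the new `AscendUnquantizedEmbeddingMethod310` no longer carry their own format-cast branches.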