[300I][Bugfix] fix unquant model weight nd2nz error (#6851)

### What this PR does / why we need it?
- This PR fixes an issue with weight format conversion for unquantized
models running on Ascend 310P devices.

- The changes refactor the logic for converting weights to the
FRACTAL_NZ format. Previously, this was handled in a 310P-specific
linear layer implementation (`AscendUnquantizedLinearMethod310`). This
implementation has been removed, and the logic is now centralized in the
`maybe_trans_nz` utility function. This function now checks if the
device is a 310P and applies the NZ format cast accordingly for
`float16`/`bfloat16` weights.

- This refactoring simplifies the code by removing platform-specific
duplication and ensures correct weight handling for unquantized models
on 310P.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
ut and local test
- vLLM version: v0.15.0
- vLLM main:
83b47f67b1

---------

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
This commit is contained in:
Shaoxu Cheng
2026-03-03 15:57:26 +08:00
committed by GitHub
parent f19f7b1fe2
commit 2064afe380
8 changed files with 214 additions and 89 deletions

View File

@@ -21,7 +21,7 @@ from vllm.model_executor.layers.linear import LinearBase
from tests.ut.base import TestBase
from vllm_ascend._310p.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod310
from vllm_ascend._310p.ops.linear import AscendUnquantizedLinearMethod310
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend._310p.quantization.modelslim_config import AscendModelSlimConfig310
@@ -50,7 +50,7 @@ class TestAscendModelSlimConfig310(TestBase):
patch.object(self.ascend_config, "is_layer_skipped_ascend", return_value=True),
):
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
self.assertIsInstance(method, AscendUnquantizedLinearMethod310)
self.assertIsInstance(method, AscendUnquantizedLinearMethod)
# Test quantized layer
mock_scheme = MagicMock()

View File

@@ -249,3 +249,91 @@ class TestUtils(TestBase):
utils.register_ascend_customop()
self.assertEqual(mock_customop.register_oot.call_count,
len(REGISTERED_ASCEND_OPS))
@mock.patch("torch_npu.npu_format_cast")
def test_maybe_trans_nz(self, mock_npu_format_cast):
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
mock_npu_format_cast.side_effect = lambda weight, fmt: weight
def assert_nz_cast(weight):
mock_npu_format_cast.assert_called_once()
args, kwargs = mock_npu_format_cast.call_args
self.assertIs(args[0], weight)
self.assertEqual(args[1], ACL_FORMAT_FRACTAL_NZ)
self.assertEqual(kwargs, {})
# Test case 1: non-310P, NZ is disabled
with (
mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"}),
mock.patch("vllm_ascend.utils.is_310p", return_value=False),
):
weight = torch.randn(32, 64, dtype=torch.float16)
result = utils.maybe_trans_nz(weight)
self.assertIs(result, weight)
mock_npu_format_cast.assert_not_called()
# Test case 2: 310P always converts non-fp32 weights, even when NZ=0
mock_npu_format_cast.reset_mock()
with (
mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "0"}),
mock.patch("vllm_ascend.utils.is_310p", return_value=True),
):
weight = torch.randn(32, 64, dtype=torch.float16)
result = utils.maybe_trans_nz(weight)
self.assertIs(result, weight)
assert_nz_cast(weight)
# Test case 3: fp32 never converts, including on 310P
mock_npu_format_cast.reset_mock()
with (
mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"}),
mock.patch("vllm_ascend.utils.is_310p", return_value=True),
):
weight = torch.randn(32, 64, dtype=torch.float32)
result = utils.maybe_trans_nz(weight)
self.assertIs(result, weight)
mock_npu_format_cast.assert_not_called()
# Test case 4: non-310P fp16 converts only when NZ=2
mock_npu_format_cast.reset_mock()
with (
mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"}),
mock.patch("vllm_ascend.utils.is_310p", return_value=False),
):
weight = torch.randn(32, 64, dtype=torch.float16)
result = utils.maybe_trans_nz(weight)
self.assertIs(result, weight)
mock_npu_format_cast.assert_not_called()
# Test case 5: non-310P fp16 converts when NZ=2
mock_npu_format_cast.reset_mock()
with (
mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"}),
mock.patch("vllm_ascend.utils.is_310p", return_value=False),
):
weight = torch.randn(32, 64, dtype=torch.float16)
result = utils.maybe_trans_nz(weight)
self.assertIs(result, weight)
assert_nz_cast(weight)
# Test case 6: non-310P bf16 converts when NZ=2
mock_npu_format_cast.reset_mock()
with (
mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "2"}),
mock.patch("vllm_ascend.utils.is_310p", return_value=False),
):
weight = torch.randn(32, 64, dtype=torch.bfloat16)
result = utils.maybe_trans_nz(weight)
self.assertIs(result, weight)
assert_nz_cast(weight)
# Test case 7: non-310P quantized weights still convert by default
mock_npu_format_cast.reset_mock()
with (
mock.patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_NZ": "1"}),
mock.patch("vllm_ascend.utils.is_310p", return_value=False),
):
weight = torch.zeros(32, 64, dtype=torch.int8)
result = utils.maybe_trans_nz(weight)
self.assertIs(result, weight)
assert_nz_cast(weight)

View File

@@ -1,65 +0,0 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
import torch.nn as nn
import torch_npu
from vllm.model_executor.layers.linear import (
LinearBase,
QuantizeMethodBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
class AscendUnquantizedLinearMethod310(UnquantizedLinearMethod):
def process_weights_after_loading(self, layer: nn.Module) -> None:
super().process_weights_after_loading(layer)
if "conv1d" not in getattr(layer, "prefix", ""):
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
class AscendLinearBase310(LinearBase):
def __init__(
self,
input_size: int,
output_size: int,
skip_bias_add: bool = False,
params_dtype: object | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
nn.Module.__init__(self)
self.input_size = int(input_size)
self.output_size = int(output_size)
self.skip_bias_add = skip_bias_add
self.params_dtype = torch.float16
self.quant_config = quant_config
self.prefix = prefix
self.return_bias = return_bias
self.disable_tp = disable_tp
if quant_config is None:
self.quant_method: QuantizeMethodBase | None = AscendUnquantizedLinearMethod310()
else:
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)

View File

@@ -0,0 +1,82 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
import torch.nn.functional as F
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
UnquantizedEmbeddingMethod,
)
from vllm_ascend.ops.vocab_parallel_embedding import AscendParallelLMHead, AscendVocabParallelEmbedding
from vllm_ascend.utils import maybe_trans_nz
class AscendUnquantizedEmbeddingMethod310(UnquantizedEmbeddingMethod):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.weight_nz = maybe_trans_nz(layer.weight)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return F.linear(x, layer.weight_nz, bias)
class AscendVocabParallelEmbedding310(AscendVocabParallelEmbedding):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
params_dtype: torch.dtype | None = None,
org_num_embeddings: int | None = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__(
num_embeddings, embedding_dim, params_dtype, org_num_embeddings, padding_size, quant_config, prefix
)
if quant_config is None:
self.quant_method = AscendUnquantizedEmbeddingMethod310()
class AscendParallelLMHead310(AscendParallelLMHead):
"""
Register ParallelLMHead as a custom op for Atlas 310p.
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: torch.dtype | None = None,
org_num_embeddings: int | None = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__(
num_embeddings, embedding_dim, bias, params_dtype, org_num_embeddings, padding_size, quant_config, prefix
)
if quant_config is None:
self.quant_method = AscendUnquantizedEmbeddingMethod310()

View File

@@ -21,7 +21,7 @@ import torch
import torch_npu
from vllm_ascend.quantization.methods.base import AscendLinearScheme
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
from vllm_ascend.utils import maybe_trans_nz
from .registry import register_scheme
@@ -105,7 +105,7 @@ class AscendW8A8LinearMethod310(AscendLinearScheme):
).to(layer.aclnn_input_scale.dtype)
# ---- matmul stage tensor ----
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ).transpose(0, 1)
layer.weight.data = maybe_trans_nz(layer.weight.data).transpose(0, 1)
# ---- dequant stage tensors ----
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)

View File

@@ -21,7 +21,7 @@ import torch
import torch_npu
from vllm_ascend.quantization.methods.base import AscendLinearScheme
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
from vllm_ascend.utils import maybe_trans_nz
from .registry import register_scheme
@@ -84,4 +84,4 @@ class AscendW8A8SLinearMethod310(AscendLinearScheme):
layer.aclnn_input_scale = layer.input_scale.data.repeat(expanding_factor)
layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale.data
layer.aclnn_input_offset = layer.input_offset.data.repeat(expanding_factor).to(layer.aclnn_input_scale.dtype)
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
layer.weight.data = maybe_trans_nz(layer.weight.data)

View File

@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.layers.vocab_parallel_embedding import (
UnquantizedEmbeddingMethod,
VocabParallelEmbedding,
)
@@ -104,9 +103,9 @@ class AscendModelSlimConfig310(AscendModelSlimConfig):
if isinstance(layer, LinearBase):
packed = getattr(self, "packed_modules_mapping", {})
if self.is_layer_skipped_ascend(prefix, packed):
from vllm_ascend._310p.ops.linear import AscendUnquantizedLinearMethod310
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
return AscendUnquantizedLinearMethod310()
return AscendUnquantizedLinearMethod()
scheme = create_scheme_for_layer(
quant_description=self.quant_description,
@@ -125,6 +124,8 @@ class AscendModelSlimConfig310(AscendModelSlimConfig):
return AscendFusedMoEMethod(scheme, layer.moe_config)
elif isinstance(layer, VocabParallelEmbedding):
return UnquantizedEmbeddingMethod()
from vllm_ascend._310p.ops.vocab_parallel_embedding import AscendUnquantizedEmbeddingMethod310
return AscendUnquantizedEmbeddingMethod310()
return super().get_quant_method(layer, prefix)

View File

@@ -134,21 +134,34 @@ def _unregister_print_streams_on_exit():
atexit.register(_unregister_print_streams_on_exit)
def maybe_trans_nz(weight: torch.Tensor):
def _should_trans_nz(weight: torch.Tensor) -> bool:
# FP32 cannot use NZ.
if weight.dtype == torch.float32:
return False
# 310P always converts to NZ.
if is_310p():
return True
# NZ is disabled on non-310P.
if not envs_ascend.VLLM_ASCEND_ENABLE_NZ:
# NZ is not enabled
return False
# BF16/FP16 convert only when enable_nz == 2.
if weight.dtype in {torch.bfloat16, torch.float16}:
return envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2
# Quantized or other supported dtypes convert by default.
return True
# NZ conversion policy:
# - 310P: always convert supported weights to FRACTAL_NZ
# - non-310P: follow VLLM_ASCEND_ENABLE_NZ
# - FP32: never convert
def maybe_trans_nz(weight: torch.Tensor) -> torch.Tensor:
if not _should_trans_nz(weight):
return weight
if weight.dtype == torch.float:
# fp32 can not support NZ
return weight
elif weight.dtype in {torch.bfloat16, torch.float16}:
# bf16/fp16 will trans nz when VLLM_ASCEND_ENABLE_NZ is 2
if envs_ascend.VLLM_ASCEND_ENABLE_NZ == 2:
return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)
else:
return weight
else:
# quant weight will trans nz by default
return torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)
@@ -631,6 +644,10 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
from vllm_ascend._310p.ops.activation import AscendSiluAndMul310
from vllm_ascend._310p.ops.layernorm import AscendGemmaRMSNorm310, AscendRMSNorm310
from vllm_ascend._310p.ops.rotary_embedding import AscendRotaryEmbedding310
from vllm_ascend._310p.ops.vocab_parallel_embedding import (
AscendParallelLMHead310,
AscendVocabParallelEmbedding310,
)
REGISTERED_ASCEND_OPS.update(
{
@@ -640,6 +657,8 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
"GemmaRMSNorm": AscendGemmaRMSNorm310,
"FusedMoE": AscendFusedMoE310,
"SharedFusedMoE": AscendSharedFusedMoE310,
"ParallelLMHead": AscendParallelLMHead310,
"VocabParallelEmbedding": AscendVocabParallelEmbedding310,
}
)