[Feat.][310P]: weight NZ format support for quantized and unquantized linear layers (#6705)

- NZ Format Support for Linear Layers: implemented support for the NZ (`ACL_FORMAT_FRACTAL_NZ`) weight format for linear layers on Ascend 310P, improving performance for both quantized and unquantized layers (see the sketch below).
- Unquantized Linear Method for Ascend 310P: introduced `AscendUnquantizedLinearMethod310`, which applies the NZ format cast to unquantized linear-layer weights during weight loading.
- MRotaryEmbedding Integration: added `AscendMRotaryEmbedding310`, an Ascend 310P-specific implementation of `MRotaryEmbedding`.
- Quantization Method Updates: updated the `w8a8_static` quantization method to transpose weights and apply the NZ format cast directly, keeping it consistent with the new format.
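A minimal sketch of what the NZ cast boils down to for an unquantized weight; shapes are illustrative, and an Ascend 310P device with `torch_npu` installed is assumed:

```python
import torch
import torch_npu  # Ascend PyTorch adapter; registers the "npu" device
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ

# A linear weight as loaded from the checkpoint (ordinary ND layout);
# the (1024, 4096) shape is made up for illustration.
weight = torch.randn(1024, 4096, dtype=torch.float16).npu()

# One-time cast to the fractal NZ layout so the 310P matmul kernels can consume it directly.
weight = torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ)
```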
- vLLM version: v0.15.0
- vLLM main: 9562912cea

---------

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
Shaoxu Cheng authored on 2026-02-13 15:41:02 +08:00 · committed by GitHub
parent f40256b697 · commit b6bc3d2f9d
7 changed files with 144 additions and 17 deletions


@@ -21,8 +21,8 @@ from vllm.model_executor.layers.linear import LinearBase
 from tests.ut.base import TestBase
 from vllm_ascend._310p.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod310
+from vllm_ascend._310p.ops.linear import AscendUnquantizedLinearMethod310
 from vllm_ascend._310p.quantization.modelslim_config import AscendModelSlimConfig310
-from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
 class TestAscendModelSlimConfig310(TestBase):
@@ -50,7 +50,7 @@ class TestAscendModelSlimConfig310(TestBase):
             patch.object(self.ascend_config, "is_layer_skipped_ascend", return_value=True),
         ):
             method = self.ascend_config.get_quant_method(linear_layer, ".attn")
-            self.assertIsInstance(method, AscendUnquantizedLinearMethod)
+            self.assertIsInstance(method, AscendUnquantizedLinearMethod310)
         # Test quantized layer
         mock_scheme = MagicMock()


@@ -44,6 +44,7 @@ class TestAscendW8A8LinearMethod310(TestBase):
         self.assertEqual(params["deq_scale"].dtype, torch.int64)
         self.assertEqual(params["weight_scale"].dtype, torch.float16)
         self.assertEqual(params["weight_offset"].dtype, torch.float16)
         self.assertEqual(params["quant_bias"].shape, (10,))
         self.assertEqual(params["deq_scale"].shape, (10,))
         self.assertEqual(params["weight_scale"].shape, (10, 1))
@@ -71,12 +72,23 @@ class TestAscendW8A8LinearMethod310(TestBase):
         output = self.method.apply(layer, x, tp_rank=0)
         mock_quantize.assert_called_with(
-            x, layer.aclnn_input_scale, layer.aclnn_input_scale_reciprocal, layer.aclnn_input_offset
+            x,
+            layer.aclnn_input_scale,
+            layer.aclnn_input_scale_reciprocal,
+            layer.aclnn_input_offset,
         )
-        mock_npu_quant_matmul.assert_called_with(
-            expect_x_output, layer.weight, layer.deq_scale, bias=layer.quant_bias, output_dtype=layer.params_dtype
-        )
+        # The bias is added by the linear layer's forward pass, not the quant method.
+        mock_npu_quant_matmul.assert_called_once()
+        (args, kwargs) = mock_npu_quant_matmul.call_args
+        # positional args
+        self.assertTrue(torch.equal(args[0], expect_x_output))
+        self.assertTrue(torch.equal(args[1], layer.weight.data))
+        self.assertTrue(torch.equal(args[2], layer.deq_scale))
+        # kwargs
+        self.assertTrue(torch.equal(kwargs["bias"], layer.quant_bias))
+        self.assertEqual(kwargs["output_dtype"], layer.params_dtype)
         self.assertTrue(torch.equal(output, expected_y_output))
     @patch("torch.ops.vllm.quantize")
@@ -98,8 +110,41 @@ class TestAscendW8A8LinearMethod310(TestBase):
         output = self.method.apply(layer, x, tp_rank=0)
         mock_quantize.assert_not_called()
-        mock_npu_quant_matmul.assert_called_with(
-            x, layer.weight, layer.deq_scale, bias=layer.quant_bias, output_dtype=layer.params_dtype
-        )
+        # The bias is added by the linear layer's forward pass, not the quant method.
+        mock_npu_quant_matmul.assert_called_once()
+        (args, kwargs) = mock_npu_quant_matmul.call_args
+        self.assertTrue(torch.equal(args[0], x))
+        self.assertTrue(torch.equal(args[1], layer.weight.data))
+        self.assertTrue(torch.equal(args[2], layer.deq_scale))
+        self.assertTrue(torch.equal(kwargs["bias"], layer.quant_bias))
+        self.assertEqual(kwargs["output_dtype"], layer.params_dtype)
         self.assertTrue(torch.equal(output, expected_y_output))
+    @patch("torch_npu.npu_format_cast")
+    def test_process_weights_after_loading_calls_nz_format_cast_310p(self, mock_npu_format_cast):
+        mock_npu_format_cast.side_effect = lambda x, fmt: x
+        layer = MagicMock()
+        # Attributes used by process_weights_after_loading()
+        layer.weight = MagicMock()
+        layer.input_scale = MagicMock()
+        layer.input_offset = MagicMock()
+        layer.weight_scale = MagicMock()
+        layer.weight_offset = MagicMock()
+        layer.w2_weight_offset = MagicMock()
+        layer.weight.data = torch.randint(-127, 128, (128, 256), dtype=torch.int8)
+        layer.input_scale.data = torch.tensor([0.1], dtype=torch.float16)
+        layer.input_offset.data = torch.tensor([0], dtype=torch.int8)
+        layer.weight_scale.data = torch.randn(128, 1, dtype=torch.bfloat16)
+        layer.weight_offset.data = torch.randn(128, 1, dtype=torch.bfloat16)
+        # w2_weight_offset is reshaped to (N, -1); any (N, 1) is fine
+        layer.w2_weight_offset.data = torch.randn(128, 1, dtype=torch.bfloat16)
+        self.method.process_weights_after_loading(layer)
+        mock_npu_format_cast.assert_called_once()


@@ -17,12 +17,16 @@
 import torch
 import torch.nn.functional as F
+import torch_npu
 from vllm_ascend.ops.activation import AscendSiluAndMul
 class AscendSiluAndMul310(AscendSiluAndMul):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        h = x.shape[-1] // 2
-        out = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to(torch.float16)
+        if x.shape[-1] % 32 == 0:
+            out = torch_npu.npu_swiglu(x)
+        else:
+            h = x.shape[-1] // 2
+            out = F.silu(x[..., :h]) * x[..., h:]
         return out
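For reference, a small sketch of what the two branches compute; the shapes are made up, and the equivalence of `torch_npu.npu_swiglu` with the eager SwiGLU fallback is the assumption behind the dispatch:

```python
import torch
import torch.nn.functional as F

# Eager fallback: split the last dimension in half, SiLU the first half, multiply.
x = torch.randn(4, 64)
h = x.shape[-1] // 2
out = F.silu(x[..., :h]) * x[..., h:]

# On a 310P device the fused path would be out = torch_npu.npu_swiglu(x),
# expected to match the eager result when x.shape[-1] % 32 == 0.
```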


@@ -0,0 +1,65 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
import torch.nn as nn
import torch_npu
from vllm.model_executor.layers.linear import (
    LinearBase,
    QuantizeMethodBase,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ


class AscendUnquantizedLinearMethod310(UnquantizedLinearMethod):
    def process_weights_after_loading(self, layer: nn.Module) -> None:
        super().process_weights_after_loading(layer)
        if "conv1d" not in getattr(layer, "prefix", ""):
            layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ)


class AscendLinearBase310(LinearBase):
    def __init__(
        self,
        input_size: int,
        output_size: int,
        skip_bias_add: bool = False,
        params_dtype: object | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        *,
        return_bias: bool = True,
        disable_tp: bool = False,
    ):
        nn.Module.__init__(self)
        self.input_size = int(input_size)
        self.output_size = int(output_size)
        self.skip_bias_add = skip_bias_add
        self.params_dtype = torch.float16
        self.quant_config = quant_config
        self.prefix = prefix
        self.return_bias = return_bias
        self.disable_tp = disable_tp
        if quant_config is None:
            self.quant_method: QuantizeMethodBase | None = AscendUnquantizedLinearMethod310()
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)


@@ -21,6 +21,7 @@ import torch
 import torch_npu
 from vllm_ascend.quantization.methods.base import AscendLinearScheme
+from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
 from .registry import register_scheme
@@ -72,9 +73,15 @@ class AscendW8A8LinearMethod310(AscendLinearScheme):
         quant_bias = layer.quant_bias if tp_rank == 0 else None
+        # NOTE(310P):
+        # - Current torch_npu.npu_quant_matmul on Ascend 310P expects the weight layout in a transposed form
+        #   for correct/efficient execution, so we pass `layer.weight.T` here.
+        # - This is a temporary workaround. The planned replacement quant-matmul op will accept the
+        #   canonical (non-transposed) weight layout directly, so this explicit transpose will be removed
+        #   once that op is enabled on 310P.
         return torch_npu.npu_quant_matmul(
             x,
-            layer.weight,
+            layer.weight.data,
             layer.deq_scale,
             bias=quant_bias,
             output_dtype=layer.params_dtype,
@@ -82,6 +89,8 @@ class AscendW8A8LinearMethod310(AscendLinearScheme):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         expanding_factor = layer.weight.data.shape[1]
+        # ---- quant stage tensors ----
         layer.aclnn_input_scale = torch.nn.Parameter(
             layer.input_scale.data.repeat(expanding_factor),
             requires_grad=False,
@@ -95,7 +104,9 @@ class AscendW8A8LinearMethod310(AscendLinearScheme):
             requires_grad=False,
         ).to(layer.aclnn_input_scale.dtype)
         layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        # ---- matmul stage tensor ----
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ).transpose(0, 1)
+        # ---- dequant stage tensors ----
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
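Taken together, the quantized 310P weight preparation now looks roughly like this; a sketch only, with the (128, 256) shape borrowed from the unit test and an Ascend device with `torch_npu` assumed:

```python
import torch
import torch_npu
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ

# int8 weight as loaded: (out_features, in_features); shape borrowed from the unit test.
weight = torch.randint(-127, 128, (128, 256), dtype=torch.int8).npu()

# Transpose to the layout npu_quant_matmul expects on 310P, cast to the NZ layout,
# then keep a transposed view so apply() can pass layer.weight.data unchanged.
weight = weight.transpose(0, 1).contiguous()
weight = torch_npu.npu_format_cast(weight, ACL_FORMAT_FRACTAL_NZ).transpose(0, 1)
```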


@@ -104,9 +104,9 @@ class AscendModelSlimConfig310(AscendModelSlimConfig):
         if isinstance(layer, LinearBase):
             packed = getattr(self, "packed_modules_mapping", {})
             if self.is_layer_skipped_ascend(prefix, packed):
-                from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
+                from vllm_ascend._310p.ops.linear import AscendUnquantizedLinearMethod310
-                return AscendUnquantizedLinearMethod()
+                return AscendUnquantizedLinearMethod310()
             scheme = create_scheme_for_layer(
                 quant_description=self.quant_description,


@@ -641,6 +641,8 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
         }
     )
+    REGISTERED_ASCEND_OPS.pop("MRotaryEmbedding", None)
     for name, op_cls in REGISTERED_ASCEND_OPS.items():
         CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)