[Feat.][310P]: weightNZ feature with quant or unquant. (#6705)
NZ Format Support for Linear Layers: Implemented support for the NZ
(N-dimensional Z-order) format for linear layer weights on Ascend 310P,
enhancing performance for both quantized and unquantized layers.
Unquantized Linear Method for Ascend 310P: Introduced
AscendUnquantizedLinearMethod310 to specifically handle and apply NZ
format casting to unquantized linear layer weights during the loading
process.
MRotaryEmbedding Integration: Extended Rotary Embedding support by
adding AscendMRotaryEmbedding310 to provide an Ascend-specific
implementation for MRotaryEmbedding.
Quantization Method Updates: Updated the w8a8_static quantization method
to directly transpose weights and apply NZ format casting, ensuring
consistency with the new format.
- vLLM version: v0.15.0
- vLLM main:
9562912cea
---------
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
This commit is contained in:
@@ -21,8 +21,8 @@ from vllm.model_executor.layers.linear import LinearBase
|
|||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
from vllm_ascend._310p.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod310
|
from vllm_ascend._310p.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod310
|
||||||
|
from vllm_ascend._310p.ops.linear import AscendUnquantizedLinearMethod310
|
||||||
from vllm_ascend._310p.quantization.modelslim_config import AscendModelSlimConfig310
|
from vllm_ascend._310p.quantization.modelslim_config import AscendModelSlimConfig310
|
||||||
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
|
|
||||||
|
|
||||||
|
|
||||||
class TestAscendModelSlimConfig310(TestBase):
|
class TestAscendModelSlimConfig310(TestBase):
|
||||||
@@ -50,7 +50,7 @@ class TestAscendModelSlimConfig310(TestBase):
|
|||||||
patch.object(self.ascend_config, "is_layer_skipped_ascend", return_value=True),
|
patch.object(self.ascend_config, "is_layer_skipped_ascend", return_value=True),
|
||||||
):
|
):
|
||||||
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
|
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
|
||||||
self.assertIsInstance(method, AscendUnquantizedLinearMethod)
|
self.assertIsInstance(method, AscendUnquantizedLinearMethod310)
|
||||||
|
|
||||||
# Test quantized layer
|
# Test quantized layer
|
||||||
mock_scheme = MagicMock()
|
mock_scheme = MagicMock()
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ class TestAscendW8A8LinearMethod310(TestBase):
|
|||||||
self.assertEqual(params["deq_scale"].dtype, torch.int64)
|
self.assertEqual(params["deq_scale"].dtype, torch.int64)
|
||||||
self.assertEqual(params["weight_scale"].dtype, torch.float16)
|
self.assertEqual(params["weight_scale"].dtype, torch.float16)
|
||||||
self.assertEqual(params["weight_offset"].dtype, torch.float16)
|
self.assertEqual(params["weight_offset"].dtype, torch.float16)
|
||||||
|
|
||||||
self.assertEqual(params["quant_bias"].shape, (10,))
|
self.assertEqual(params["quant_bias"].shape, (10,))
|
||||||
self.assertEqual(params["deq_scale"].shape, (10,))
|
self.assertEqual(params["deq_scale"].shape, (10,))
|
||||||
self.assertEqual(params["weight_scale"].shape, (10, 1))
|
self.assertEqual(params["weight_scale"].shape, (10, 1))
|
||||||
@@ -71,12 +72,23 @@ class TestAscendW8A8LinearMethod310(TestBase):
|
|||||||
output = self.method.apply(layer, x, tp_rank=0)
|
output = self.method.apply(layer, x, tp_rank=0)
|
||||||
|
|
||||||
mock_quantize.assert_called_with(
|
mock_quantize.assert_called_with(
|
||||||
x, layer.aclnn_input_scale, layer.aclnn_input_scale_reciprocal, layer.aclnn_input_offset
|
x,
|
||||||
|
layer.aclnn_input_scale,
|
||||||
|
layer.aclnn_input_scale_reciprocal,
|
||||||
|
layer.aclnn_input_offset,
|
||||||
)
|
)
|
||||||
mock_npu_quant_matmul.assert_called_with(
|
mock_npu_quant_matmul.assert_called_once()
|
||||||
expect_x_output, layer.weight, layer.deq_scale, bias=layer.quant_bias, output_dtype=layer.params_dtype
|
(args, kwargs) = mock_npu_quant_matmul.call_args
|
||||||
)
|
|
||||||
# The bias is added by the linear layer's forward pass, not the quant method.
|
# positional args
|
||||||
|
self.assertTrue(torch.equal(args[0], expect_x_output))
|
||||||
|
self.assertTrue(torch.equal(args[1], layer.weight.data))
|
||||||
|
self.assertTrue(torch.equal(args[2], layer.deq_scale))
|
||||||
|
|
||||||
|
# kwargs
|
||||||
|
self.assertTrue(torch.equal(kwargs["bias"], layer.quant_bias))
|
||||||
|
self.assertEqual(kwargs["output_dtype"], layer.params_dtype)
|
||||||
|
|
||||||
self.assertTrue(torch.equal(output, expected_y_output))
|
self.assertTrue(torch.equal(output, expected_y_output))
|
||||||
|
|
||||||
@patch("torch.ops.vllm.quantize")
|
@patch("torch.ops.vllm.quantize")
|
||||||
@@ -98,8 +110,41 @@ class TestAscendW8A8LinearMethod310(TestBase):
|
|||||||
output = self.method.apply(layer, x, tp_rank=0)
|
output = self.method.apply(layer, x, tp_rank=0)
|
||||||
|
|
||||||
mock_quantize.assert_not_called()
|
mock_quantize.assert_not_called()
|
||||||
mock_npu_quant_matmul.assert_called_with(
|
mock_npu_quant_matmul.assert_called_once()
|
||||||
x, layer.weight, layer.deq_scale, bias=layer.quant_bias, output_dtype=layer.params_dtype
|
(args, kwargs) = mock_npu_quant_matmul.call_args
|
||||||
)
|
|
||||||
# The bias is added by the linear layer's forward pass, not the quant method.
|
self.assertTrue(torch.equal(args[0], x))
|
||||||
|
self.assertTrue(torch.equal(args[1], layer.weight.data))
|
||||||
|
self.assertTrue(torch.equal(args[2], layer.deq_scale))
|
||||||
|
|
||||||
|
self.assertTrue(torch.equal(kwargs["bias"], layer.quant_bias))
|
||||||
|
self.assertEqual(kwargs["output_dtype"], layer.params_dtype)
|
||||||
|
|
||||||
self.assertTrue(torch.equal(output, expected_y_output))
|
self.assertTrue(torch.equal(output, expected_y_output))
|
||||||
|
|
||||||
|
@patch("torch_npu.npu_format_cast")
|
||||||
|
def test_process_weights_after_loading_calls_nz_format_cast_310p(self, mock_npu_format_cast):
|
||||||
|
mock_npu_format_cast.side_effect = lambda x, fmt: x
|
||||||
|
|
||||||
|
layer = MagicMock()
|
||||||
|
|
||||||
|
# Attributes used by process_weights_after_loading()
|
||||||
|
layer.weight = MagicMock()
|
||||||
|
layer.input_scale = MagicMock()
|
||||||
|
layer.input_offset = MagicMock()
|
||||||
|
layer.weight_scale = MagicMock()
|
||||||
|
layer.weight_offset = MagicMock()
|
||||||
|
layer.w2_weight_offset = MagicMock()
|
||||||
|
|
||||||
|
layer.weight.data = torch.randint(-127, 128, (128, 256), dtype=torch.int8)
|
||||||
|
layer.input_scale.data = torch.tensor([0.1], dtype=torch.float16)
|
||||||
|
layer.input_offset.data = torch.tensor([0], dtype=torch.int8)
|
||||||
|
|
||||||
|
layer.weight_scale.data = torch.randn(128, 1, dtype=torch.bfloat16)
|
||||||
|
layer.weight_offset.data = torch.randn(128, 1, dtype=torch.bfloat16)
|
||||||
|
# w2_weight_offset is reshaped to (N, -1); any (N, 1) is fine
|
||||||
|
layer.w2_weight_offset.data = torch.randn(128, 1, dtype=torch.bfloat16)
|
||||||
|
|
||||||
|
self.method.process_weights_after_loading(layer)
|
||||||
|
|
||||||
|
mock_npu_format_cast.assert_called_once()
|
||||||
|
|||||||
@@ -17,12 +17,16 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import torch_npu
|
||||||
|
|
||||||
from vllm_ascend.ops.activation import AscendSiluAndMul
|
from vllm_ascend.ops.activation import AscendSiluAndMul
|
||||||
|
|
||||||
|
|
||||||
class AscendSiluAndMul310(AscendSiluAndMul):
|
class AscendSiluAndMul310(AscendSiluAndMul):
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if x.shape[-1] % 32 == 0:
|
||||||
|
out = torch_npu.npu_swiglu(x)
|
||||||
|
else:
|
||||||
h = x.shape[-1] // 2
|
h = x.shape[-1] // 2
|
||||||
out = (F.silu(x[..., :h].to(torch.float32)) * x[..., h:].to(torch.float32)).to(torch.float16)
|
out = F.silu(x[..., :h]) * x[..., h:]
|
||||||
return out
|
return out
|
||||||
|
|||||||
65
vllm_ascend/_310p/ops/linear.py
Normal file
65
vllm_ascend/_310p/ops/linear.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch_npu
|
||||||
|
from vllm.model_executor.layers.linear import (
|
||||||
|
LinearBase,
|
||||||
|
QuantizeMethodBase,
|
||||||
|
UnquantizedLinearMethod,
|
||||||
|
)
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||||
|
|
||||||
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
|
||||||
|
|
||||||
|
|
||||||
|
class AscendUnquantizedLinearMethod310(UnquantizedLinearMethod):
|
||||||
|
def process_weights_after_loading(self, layer: nn.Module) -> None:
|
||||||
|
super().process_weights_after_loading(layer)
|
||||||
|
if "conv1d" not in getattr(layer, "prefix", ""):
|
||||||
|
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||||
|
|
||||||
|
|
||||||
|
class AscendLinearBase310(LinearBase):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_size: int,
|
||||||
|
output_size: int,
|
||||||
|
skip_bias_add: bool = False,
|
||||||
|
params_dtype: object | None = None,
|
||||||
|
quant_config: QuantizationConfig | None = None,
|
||||||
|
prefix: str = "",
|
||||||
|
*,
|
||||||
|
return_bias: bool = True,
|
||||||
|
disable_tp: bool = False,
|
||||||
|
):
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
|
||||||
|
self.input_size = int(input_size)
|
||||||
|
self.output_size = int(output_size)
|
||||||
|
self.skip_bias_add = skip_bias_add
|
||||||
|
self.params_dtype = torch.float16
|
||||||
|
self.quant_config = quant_config
|
||||||
|
self.prefix = prefix
|
||||||
|
self.return_bias = return_bias
|
||||||
|
self.disable_tp = disable_tp
|
||||||
|
|
||||||
|
if quant_config is None:
|
||||||
|
self.quant_method: QuantizeMethodBase | None = AscendUnquantizedLinearMethod310()
|
||||||
|
else:
|
||||||
|
self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
|
||||||
@@ -21,6 +21,7 @@ import torch
|
|||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
from vllm_ascend.quantization.methods.base import AscendLinearScheme
|
from vllm_ascend.quantization.methods.base import AscendLinearScheme
|
||||||
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ
|
||||||
|
|
||||||
from .registry import register_scheme
|
from .registry import register_scheme
|
||||||
|
|
||||||
@@ -72,9 +73,15 @@ class AscendW8A8LinearMethod310(AscendLinearScheme):
|
|||||||
|
|
||||||
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
||||||
|
|
||||||
|
# NOTE(310P):
|
||||||
|
# - Current torch_npu.npu_quant_matmul on Ascend 310P expects the weight layout in a transposed form
|
||||||
|
# for correct/efficient execution, so we pass `layer.weight.T` here.
|
||||||
|
# - This is a temporary workaround. The planned replacement quant-matmul op will accept the
|
||||||
|
# canonical (non-transposed) weight layout directly, so this explicit transpose will be removed
|
||||||
|
# once that op is enabled on 310P.
|
||||||
return torch_npu.npu_quant_matmul(
|
return torch_npu.npu_quant_matmul(
|
||||||
x,
|
x,
|
||||||
layer.weight,
|
layer.weight.data,
|
||||||
layer.deq_scale,
|
layer.deq_scale,
|
||||||
bias=quant_bias,
|
bias=quant_bias,
|
||||||
output_dtype=layer.params_dtype,
|
output_dtype=layer.params_dtype,
|
||||||
@@ -82,6 +89,8 @@ class AscendW8A8LinearMethod310(AscendLinearScheme):
|
|||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
expanding_factor = layer.weight.data.shape[1]
|
expanding_factor = layer.weight.data.shape[1]
|
||||||
|
|
||||||
|
# ---- quant stage tensors ----
|
||||||
layer.aclnn_input_scale = torch.nn.Parameter(
|
layer.aclnn_input_scale = torch.nn.Parameter(
|
||||||
layer.input_scale.data.repeat(expanding_factor),
|
layer.input_scale.data.repeat(expanding_factor),
|
||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
@@ -95,7 +104,9 @@ class AscendW8A8LinearMethod310(AscendLinearScheme):
|
|||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
).to(layer.aclnn_input_scale.dtype)
|
).to(layer.aclnn_input_scale.dtype)
|
||||||
|
|
||||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
# ---- matmul stage tensor ----
|
||||||
|
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, ACL_FORMAT_FRACTAL_NZ).transpose(0, 1)
|
||||||
|
|
||||||
|
# ---- dequant stage tensors ----
|
||||||
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
||||||
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
||||||
|
|||||||
@@ -104,9 +104,9 @@ class AscendModelSlimConfig310(AscendModelSlimConfig):
|
|||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
packed = getattr(self, "packed_modules_mapping", {})
|
packed = getattr(self, "packed_modules_mapping", {})
|
||||||
if self.is_layer_skipped_ascend(prefix, packed):
|
if self.is_layer_skipped_ascend(prefix, packed):
|
||||||
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
|
from vllm_ascend._310p.ops.linear import AscendUnquantizedLinearMethod310
|
||||||
|
|
||||||
return AscendUnquantizedLinearMethod()
|
return AscendUnquantizedLinearMethod310()
|
||||||
|
|
||||||
scheme = create_scheme_for_layer(
|
scheme = create_scheme_for_layer(
|
||||||
quant_description=self.quant_description,
|
quant_description=self.quant_description,
|
||||||
|
|||||||
@@ -641,6 +641,8 @@ def register_ascend_customop(vllm_config: VllmConfig | None = None):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
REGISTERED_ASCEND_OPS.pop("MRotaryEmbedding", None)
|
||||||
|
|
||||||
for name, op_cls in REGISTERED_ASCEND_OPS.items():
|
for name, op_cls in REGISTERED_ASCEND_OPS.items():
|
||||||
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)
|
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user