[BugFix][310p][Cherry-pick] Handle null quantization config in ShardedStateLoader310 & [Feature][310P] Support W8A8 dynamic linear method (#8296)

### What this PR does / why we need it?
This PR implements the `AscendW8A8DynamicLinearMethod310` quantization
scheme for 310P hardware. It covers weight retrieval, per-channel
parameter generation, and dynamic quantization applied through the
NPU-specific kernels `torch_npu.npu_dynamic_quant` and
`torch_npu.npu_quant_matmul`. It also updates `ShardedStateLoader310`
to handle a missing (null) quantization config gracefully when
generating the parameter type map.
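
For orientation, here is a minimal CPU-only sketch of the arithmetic
that the scheme delegates to `torch_npu.npu_dynamic_quant` and
`torch_npu.npu_quant_matmul`: activations are quantized symmetrically
per token at runtime, an int8 matmul runs against the per-channel int8
weight, and both scales are applied to the accumulator. The function
name and the clamping details below are illustrative, not part of this
PR.

```python
import torch


def w8a8_dynamic_linear_reference(x: torch.Tensor,
                                  weight_int8: torch.Tensor,
                                  weight_scale: torch.Tensor) -> torch.Tensor:
    """CPU reference for a symmetric, per-token W8A8 dynamic matmul.

    x:            float activations, shape (tokens, in_features)
    weight_int8:  int8 weight, shape (out_features, in_features)
    weight_scale: per-output-channel scale, shape (out_features,)
    """
    x_f = x.float()
    # Dynamic (per-token) activation quantization: one scale per row of x.
    per_token_scale = x_f.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_int8 = torch.round(x_f / per_token_scale).clamp(-128, 127).to(torch.int8)

    # int8 x int8 matmul accumulated in int32, then dequantized with both scales.
    acc = x_int8.to(torch.int32) @ weight_int8.to(torch.int32).t()
    return acc.to(torch.float32) * per_token_scale * weight_scale.view(1, -1)
```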

Feedback from the review identified two critical issues in the
implementation:
1. The tensor squeezing logic in the `apply` method incorrectly handles
2D inputs, which may lead to shape mismatches in subsequent layers.
2. The weight tensor in `process_weights_after_loading` is transposed
after being converted to the private NZ format; the transpose should
instead be applied to the ND tensor before conversion so that the
physical layout is correct (see the sketch below).
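
A minimal sketch of the ordering point 2 calls for, shown as a
standalone function for brevity and assuming the project's
`maybe_trans_nz` helper from the diff below; this illustrates the
suggested fix, not the code that landed. For point 1, one natural
direction is to gate the squeeze/unsqueeze on the rank of the
activation tensor rather than on the scale's rank, but the exact fix
is not prescribed here.

```python
import torch

from vllm_ascend.utils import maybe_trans_nz  # helper used in the diff below


def process_weights_after_loading(layer: torch.nn.Module) -> None:
    # Sketch only: transpose while the weight is still in plain ND layout,
    # then convert the transposed result to the private NZ format.
    weight_nd = layer.weight.data.transpose(0, 1).contiguous()  # still ND, (in, out)
    layer.weight.data = maybe_trans_nz(weight_nd)  # NZ cast happens last
    layer.weight_scale.data = layer.weight_scale.data.flatten()
    layer.weight_offset.data = layer.weight_offset.data.flatten()
```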

Cherry-picked from: #7546, #7725.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
New unit tests were added in
`tests/ut/_310p/quantization/test_w8a8_dynamic_310.py` to verify the
quantization method, and
`tests/ut/_310p/test_sharded_state_loader_310p.py` was updated to test
the state loader changes.
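
For reference, a hedged example of what the generated
`parameters_type_map.json` is expected to contain in the new
no-quant-config test, shown as a Python dict; the exact key set depends
on the MockModel's parameters:

```python
# Illustrative only: contents the tests check for a float model when
# quant_config is None.
expected_description = {
    "model_quant_type": "FLOAT",  # falls back to FLOAT without a quant config
    "version": "1.0.0",
    "linear.weight": "FLOAT",
    "linear.bias": "FLOAT",       # asserted in the existing float-model test
}
```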

---------

Signed-off-by: csoulnd <daidaicurry@foxmail.com>
Author: csoulnd
Date: 2026-04-16 16:53:39 +08:00
Committed by: GitHub
Parent: 52f0f9b5e4
Commit: 8952fddc7e
6 changed files with 184 additions and 10 deletions

File: tests/ut/_310p/quantization/test_w8a8_dynamic_310.py

@@ -13,12 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend._310p.quantization.methods.w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod310
from vllm_ascend._310p.quantization.methods.w8a8_dynamic import (
    AscendW8A8DynamicFusedMoEMethod310,
    AscendW8A8DynamicLinearMethod310,
)
class TestAscendW8A8FusedMoEMethod310(TestBase):
@@ -64,3 +67,78 @@ class TestAscendW8A8FusedMoEMethod310(TestBase):
self.assertEqual(param_dict["w13_weight_scale"].shape, (self.num_experts, 2 * self.intermediate_size, 1))
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32)
self.assertEqual(param_dict["w2_weight_scale"].shape, (self.num_experts, self.hidden_size, 1))
class TestAscendW8A8DynamicLinearMethod310(TestBase):
def setUp(self):
self.method = AscendW8A8DynamicLinearMethod310()
def test_get_weight_310(self):
weight = self.method.get_weight(10, 20)
self.assertEqual(weight["weight"].dtype, torch.int8)
self.assertEqual(weight["weight"].shape, (20, 10))
def test_get_perchannel_param_310(self):
params = self.method.get_perchannel_param(10, torch.float32)
self.assertEqual(params["weight_scale"].dtype, torch.float32)
self.assertEqual(params["weight_offset"].dtype, torch.float32)
self.assertEqual(params["weight_scale"].shape, (10, 1))
self.assertEqual(params["weight_offset"].shape, (10, 1))
@patch("torch_npu.npu_dynamic_quant")
@patch("torch_npu.npu_quant_matmul")
def test_apply_310(self, mock_npu_quant_matmul, mock_npu_dynamic_quantize):
layer = MagicMock()
layer.weight = torch.randn(128, 256, dtype=torch.float16)
layer.weight_scale = torch.randn(128, dtype=torch.float32)
layer.params_dtype = torch.float16
x = torch.randn(32, 128, dtype=torch.float16)
expect_x_output = torch.randint(-128, 127, x.shape, dtype=torch.int8)
expect_pertoken_scale_output = torch.randn(x.shape[0], dtype=torch.float32)
mock_npu_dynamic_quantize.return_value = expect_x_output, expect_pertoken_scale_output
expected_y_output = torch.randn(32, 256)
mock_npu_quant_matmul.return_value = expected_y_output
output = self.method.apply(layer, x, tp_rank=0)
mock_npu_dynamic_quantize.assert_called_with(x)
mock_npu_quant_matmul.assert_called_once()
(args, kwargs) = mock_npu_quant_matmul.call_args
# positional args
self.assertTrue(torch.equal(args[0], expect_x_output))
self.assertTrue(torch.equal(args[1], layer.weight.data))
self.assertTrue(torch.equal(args[2], layer.weight_scale))
# kwargs
self.assertTrue(torch.equal(kwargs["pertoken_scale"], expect_pertoken_scale_output))
self.assertTrue(kwargs["bias"] is None)
self.assertEqual(kwargs["output_dtype"], layer.params_dtype)
self.assertTrue(torch.equal(output, expected_y_output))
@patch("torch_npu.npu_format_cast")
def test_process_weights_after_loading_calls_nz_format_cast_310p(self, mock_npu_format_cast):
mock_npu_format_cast.side_effect = lambda x, fmt: x
layer = MagicMock()
# Attributes used by process_weights_after_loading()
layer.weight = MagicMock()
layer.weight_scale = MagicMock()
layer.weight_offset = MagicMock()
layer.weight.data = torch.randint(-127, 128, (128, 256), dtype=torch.int8)
layer.weight_scale.data = torch.randn(128, 1, dtype=torch.bfloat16)
layer.weight_offset.data = torch.randn(128, 1, dtype=torch.bfloat16)
# w2_weight_offset is reshaped to (N, -1); any (N, 1) is fine
layer.w2_weight_offset.data = torch.randn(128, 1, dtype=torch.bfloat16)
self.method.process_weights_after_loading(layer)
mock_npu_format_cast.assert_called_once()

File: tests/ut/_310p/test_sharded_state_loader_310p.py

@@ -77,7 +77,7 @@ class TestShardedStateLoader310(TestBase):
        model = MockModel(quant_config=quant_config, with_int_weights=False)
        with tempfile.TemporaryDirectory() as tmpdir:
            ShardedStateLoader310.generate_quant_description(model, tmpdir)
            ShardedStateLoader310.generate_quant_description(model, tmpdir, quant_config)
            json_path = Path(tmpdir) / "parameters_type_map.json"
            self.assertTrue(json_path.exists())
@@ -92,6 +92,24 @@ class TestShardedStateLoader310(TestBase):
self.assertIn("linear.bias", quant_description)
self.assertEqual(quant_description["linear.bias"], "FLOAT")
@patch("vllm.model_executor.model_loader.ShardedStateLoader._filter_subtensors")
def test_generate_quant_description_no_quant_config_310(self, mock_filter):
"""When quant_config is None, treat model as FLOAT."""
mock_filter.side_effect = lambda x: x
model = MockModel(quant_config=None, with_int_weights=False)
with tempfile.TemporaryDirectory() as tmpdir:
ShardedStateLoader310.generate_quant_description(model, tmpdir, None)
json_path = Path(tmpdir) / "parameters_type_map.json"
self.assertTrue(json_path.exists())
with open(json_path, encoding="utf-8") as f:
quant_description = json.load(f)
self.assertEqual(quant_description["model_quant_type"], "FLOAT")
self.assertEqual(quant_description["linear.weight"], "FLOAT")
@patch("vllm.model_executor.model_loader.ShardedStateLoader._filter_subtensors")
def test_generate_quant_description_int_model_310(self, mock_filter):
"""Test generate_quant_description for int8 quantized model."""
@@ -100,7 +118,7 @@ class TestShardedStateLoader310(TestBase):
        model = MockModel(quant_config=quant_config, with_int_weights=True)
        with tempfile.TemporaryDirectory() as tmpdir:
            ShardedStateLoader310.generate_quant_description(model, tmpdir)
            ShardedStateLoader310.generate_quant_description(model, tmpdir, quant_config)
            json_path = Path(tmpdir) / "parameters_type_map.json"
            self.assertTrue(json_path.exists())

File: vllm_ascend/_310p/quantization/methods/w8a8_dynamic.py

@@ -19,6 +19,7 @@ from collections.abc import Callable
from typing import Any
import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import get_ep_group
@@ -26,7 +27,8 @@ from vllm_ascend._310p.fused_moe.experts_selector import select_experts
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.quantization.methods.base import AscendMoEScheme, QuantType
from vllm_ascend.quantization.methods.base import AscendLinearScheme, AscendMoEScheme, QuantType
from vllm_ascend.utils import maybe_trans_nz
from .registry import register_scheme
@@ -154,3 +156,66 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):
        layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(layer.w13_weight_offset.data.shape[0], -1)
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1)
        layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1)


@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod310(AscendLinearScheme):
    """310P-only W8A8 dynamic linear scheme.

    Notes:
        - This scheme is discovered via 310P local registry.
    """

    def get_weight(
        self,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.float16,
    ) -> dict[str, Any]:
        return {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}

    def get_perchannel_param(
        self,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> dict[str, Any]:
        params: dict[str, Any] = {}
        params["weight_scale"] = torch.empty(output_size, 1, dtype=torch.float32)
        params["weight_offset"] = torch.empty(output_size, 1, dtype=torch.float32)
        return params

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
        tp_rank: int | None = 0,
    ) -> torch.Tensor:
        # NOTE(310P):
        # - There is an accuracy issue currently, which is expected to be fixed in the next version.
        quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(x)
        need_unsqz = False
        if pertoken_scale.dim() == 2:
            need_unsqz = True
            quantized_x = quantized_x.squeeze(dim=1)
            pertoken_scale = pertoken_scale.squeeze(dim=1)
        # NOTE(310P):
        # - Currently, W8A8 dynamic quantization supports only symmetric quantization.
        output = torch_npu.npu_quant_matmul(
            quantized_x,
            layer.weight.data,
            layer.weight_scale,
            pertoken_scale=pertoken_scale,
            bias=bias,
            output_dtype=x.dtype,
        )
        if need_unsqz:
            output = output.unsqueeze(dim=1)
        return output

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # cast quantized weight tensors in NZ format for higher inference speed
        layer.weight.data = maybe_trans_nz(layer.weight.data).transpose(0, 1)
        layer.weight_scale.data = layer.weight_scale.data.flatten()
        layer.weight_offset.data = layer.weight_offset.data.flatten()


@@ -95,8 +95,6 @@ class AscendModelSlimConfig310(AscendModelSlimConfig):
        self.packed_modules_mapping = packed_modules_model_mapping[model_type]
        prefix = self.quant_prefix_mapper(model_type, prefix)
        if prefix.startswith("language_model"):
            prefix = prefix.split(".", 1)[-1]
        if isinstance(layer, LinearBase):
            packed = getattr(self, "packed_modules_mapping", {})


@@ -20,6 +20,7 @@ from pathlib import Path
import torch
from vllm.config.load import LoadConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader import ShardedStateLoader
@@ -48,10 +49,20 @@ class ShardedStateLoader310(ShardedStateLoader):
        )

    @staticmethod
    def generate_quant_description(model: torch.nn.Module, path: str):
    def generate_quant_description(
        model: torch.nn.Module,
        path: str,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        """Generate a mapping of parameter names to their corresponding quantization types."""
        quant_description = {}
        quantize_type = model.quant_config.quant_description.get("model_quant_type", "FLOAT")
        if quant_config is None:
            quantize_type = "FLOAT"
        else:
            try:
                quantize_type = quant_config.quant_description.get("model_quant_type", "FLOAT")
            except AttributeError:
                quantize_type = "FLOAT"
        quant_description["model_quant_type"] = quantize_type
        quant_description["version"] = "1.0.0"
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())


@@ -48,7 +48,11 @@ class NPUWorker310(NPUWorker):
            max_size=max_size,
        )
        ShardedStateLoader310.generate_quant_description(self.model_runner.model, path)
        ShardedStateLoader310.generate_quant_description(
            self.model_runner.model,
            path,
            self.vllm_config.quant_config,
        )

    @torch.inference_mode()
    def determine_available_memory(self) -> int: