[BugFix][310p][Cherry-pick] Handle null quantization config in ShardedStateLoader310 & [Feature][310P] Support W8A8 dynamic linear method (#8296)

### What this PR does / why we need it?
This PR implements the `AscendW8A8DynamicLinearMethod310` quantization
scheme specifically for 310P hardware. It includes the logic for weight
retrieval, per-channel parameter generation, and the application of
dynamic quantization using NPU-specific kernels. Additionally, it
updates `ShardedStateLoader310` to handle quantization configurations
more robustly when generating parameter type maps.
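For context, here is a minimal sketch of the linear method's surface, reconstructed from the unit tests in this PR. The class name `W8A8DynamicLinearSketch` is illustrative, the signatures are inferred from how the tests call the method, and the snippet is only runnable on Ascend hardware with `torch_npu` installed; the actual `AscendW8A8DynamicLinearMethod310` may differ in structure:

```python
# Sketch only: shapes and kernel calls mirror the unit tests in this PR,
# not the actual vllm-ascend implementation.
import torch
import torch_npu


class W8A8DynamicLinearSketch:

    def get_weight(self, input_size: int, output_size: int) -> dict:
        # int8 weight laid out as (output_size, input_size)
        return {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}

    def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict:
        # one scale and one offset per output channel
        return {
            "weight_scale": torch.empty(output_size, 1, dtype=params_dtype),
            "weight_offset": torch.empty(output_size, 1, dtype=params_dtype),
        }

    def apply(self, layer, x: torch.Tensor, bias=None, tp_rank: int = 0) -> torch.Tensor:
        # per-token dynamic quantization of the activations
        x_int8, pertoken_scale = torch_npu.npu_dynamic_quant(x)
        # int8 matmul, dequantized with per-channel weight scales
        # and per-token activation scales
        return torch_npu.npu_quant_matmul(
            x_int8,
            layer.weight.data,
            layer.weight_scale,
            pertoken_scale=pertoken_scale,
            bias=bias,
            output_dtype=layer.params_dtype,
        )
```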

Feedback from the review identified two critical issues in the
implementation:
1. The tensor squeezing logic in the `apply` method incorrectly handles
2D inputs, which may lead to shape mismatches in subsequent layers.
2. The weight tensor in `process_weights_after_loading` is transposed
after being converted to the private NZ format; the transpose operation
should be performed on the ND tensor before conversion to ensure correct
physical layout.
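A minimal sketch of what the two fixes imply, under stated assumptions: the function names are hypothetical, and `29` is assumed to be the ACL enum value for the FRACTAL_NZ format (it is not confirmed by this PR):

```python
import torch
import torch_npu

# Assumption: 29 is the ACL format ID for the private FRACTAL_NZ layout.
ACL_FORMAT_FRACTAL_NZ = 29


def apply_with_shape_guard(layer, x: torch.Tensor) -> torch.Tensor:
    # Issue 1: only squeeze the output if apply() itself added a batch
    # dim; squeezing unconditionally turns a (1, n) result into (n,)
    # and causes shape mismatches for genuine 2D inputs downstream.
    expanded = x.dim() == 1
    if expanded:
        x = x.unsqueeze(0)
    x_int8, pertoken_scale = torch_npu.npu_dynamic_quant(x)
    y = torch_npu.npu_quant_matmul(
        x_int8,
        layer.weight.data,
        layer.weight_scale,
        pertoken_scale=pertoken_scale,
        bias=None,
        output_dtype=layer.params_dtype,
    )
    return y.squeeze(0) if expanded else y


def process_weights_after_loading_fixed(layer) -> None:
    # Issue 2: transpose while the weight is still a plain ND tensor,
    # then cast to NZ; transposing after the cast no longer matches
    # the physical layout the NZ format encodes.
    w_nd = layer.weight.data.transpose(0, 1).contiguous()
    layer.weight.data = torch_npu.npu_format_cast(w_nd, ACL_FORMAT_FRACTAL_NZ)
```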

Cherry-picked from: #7546, #7725

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
New unit tests were added in
`tests/ut/_310p/quantization/test_w8a8_dynamic_310.py` to verify the
quantization method, and
`tests/ut/_310p/test_sharded_state_loader_310p.py` was updated to test
the state loader changes.

---------

Signed-off-by: csoulnd <daidaicurry@foxmail.com>

Diff of `tests/ut/_310p/quantization/test_w8a8_dynamic_310.py`:

```diff
@@ -13,12 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import Mock, patch
+from unittest.mock import MagicMock, Mock, patch
 
 import torch
 
 from tests.ut.base import TestBase
-from vllm_ascend._310p.quantization.methods.w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod310
+from vllm_ascend._310p.quantization.methods.w8a8_dynamic import (
+    AscendW8A8DynamicFusedMoEMethod310,
+    AscendW8A8DynamicLinearMethod310,
+)
 
 
 class TestAscendW8A8FusedMoEMethod310(TestBase):
@@ -64,3 +67,78 @@ class TestAscendW8A8FusedMoEMethod310(TestBase):
         self.assertEqual(param_dict["w13_weight_scale"].shape, (self.num_experts, 2 * self.intermediate_size, 1))
         self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.float32)
         self.assertEqual(param_dict["w2_weight_scale"].shape, (self.num_experts, self.hidden_size, 1))
+
+
+class TestAscendW8A8DynamicLinearMethod310(TestBase):
+
+    def setUp(self):
+        self.method = AscendW8A8DynamicLinearMethod310()
+
+    def test_get_weight_310(self):
+        weight = self.method.get_weight(10, 20)
+        self.assertEqual(weight["weight"].dtype, torch.int8)
+        self.assertEqual(weight["weight"].shape, (20, 10))
+
+    def test_get_perchannel_param_310(self):
+        params = self.method.get_perchannel_param(10, torch.float32)
+        self.assertEqual(params["weight_scale"].dtype, torch.float32)
+        self.assertEqual(params["weight_offset"].dtype, torch.float32)
+        self.assertEqual(params["weight_scale"].shape, (10, 1))
+        self.assertEqual(params["weight_offset"].shape, (10, 1))
+
+    @patch("torch_npu.npu_dynamic_quant")
+    @patch("torch_npu.npu_quant_matmul")
+    def test_apply_310(self, mock_npu_quant_matmul, mock_npu_dynamic_quantize):
+        layer = MagicMock()
+        layer.weight = torch.randn(128, 256, dtype=torch.float16)
+        layer.weight_scale = torch.randn(128, dtype=torch.float32)
+        layer.params_dtype = torch.float16
+        x = torch.randn(32, 128, dtype=torch.float16)
+        expect_x_output = torch.randint(-128, 127, x.shape, dtype=torch.int8)
+        expect_pertoken_scale_output = torch.randn(x.shape[0], dtype=torch.float32)
+        mock_npu_dynamic_quantize.return_value = expect_x_output, expect_pertoken_scale_output
+        expected_y_output = torch.randn(32, 256)
+        mock_npu_quant_matmul.return_value = expected_y_output
+        output = self.method.apply(layer, x, tp_rank=0)
+        mock_npu_dynamic_quantize.assert_called_with(x)
+        mock_npu_quant_matmul.assert_called_once()
+        (args, kwargs) = mock_npu_quant_matmul.call_args
+        # positional args
+        self.assertTrue(torch.equal(args[0], expect_x_output))
+        self.assertTrue(torch.equal(args[1], layer.weight.data))
+        self.assertTrue(torch.equal(args[2], layer.weight_scale))
+        # kwargs
+        self.assertTrue(torch.equal(kwargs["pertoken_scale"], expect_pertoken_scale_output))
+        self.assertTrue(kwargs["bias"] is None)
+        self.assertEqual(kwargs["output_dtype"], layer.params_dtype)
+        self.assertTrue(torch.equal(output, expected_y_output))
+
+    @patch("torch_npu.npu_format_cast")
+    def test_process_weights_after_loading_calls_nz_format_cast_310p(self, mock_npu_format_cast):
+        mock_npu_format_cast.side_effect = lambda x, fmt: x
+        layer = MagicMock()
+        # Attributes used by process_weights_after_loading()
+        layer.weight = MagicMock()
+        layer.weight_scale = MagicMock()
+        layer.weight_offset = MagicMock()
+        layer.weight.data = torch.randint(-127, 128, (128, 256), dtype=torch.int8)
+        layer.weight_scale.data = torch.randn(128, 1, dtype=torch.bfloat16)
+        layer.weight_offset.data = torch.randn(128, 1, dtype=torch.bfloat16)
+        # w2_weight_offset is reshaped to (N, -1); any (N, 1) is fine
+        layer.w2_weight_offset.data = torch.randn(128, 1, dtype=torch.bfloat16)
+        self.method.process_weights_after_loading(layer)
+        mock_npu_format_cast.assert_called_once()
```
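Note that the format-cast test stubs `npu_format_cast` with `lambda x, fmt: x`, so it only verifies that the NZ cast is invoked exactly once on the 310P path; the actual layout conversion, and the transpose-before-cast ordering raised in review point 2, is not exercised on CPU.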