[BugFix][310p][Cherry-pick] Handle null quantization config in ShardedStateLoader310&[Feature][310P] Support W8A8 dynamic linear method (#8296)
### What this PR does / why we need it? This PR implements the `AscendW8A8DynamicLinearMethod310` quantization scheme specifically for 310P hardware. It includes the logic for weight retrieval, per-channel parameter generation, and the application of dynamic quantization using NPU-specific kernels. Additionally, it updates `ShardedStateLoader310` to handle quantization configurations more robustly when generating parameter type maps. Feedback from the review identified two critical issues in the implementation: 1. The tensor squeezing logic in the `apply` method incorrectly handles 2D inputs, which may lead to shape mismatches in subsequent layers. 2. The weight tensor in `process_weights_after_loading` is transposed after being converted to the private NZ format; the transpose operation should be performed on the ND tensor before conversion to ensure correct physical layout. cherry-pick from : #7546 #7725 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? New unit tests were added in `tests/ut/_310p/quantization/test_w8a8_dynamic_310.py` to verify the quantization method, and `tests/ut/_310p/test_sharded_state_loader_310p.py` was updated to test the state loader changes. --------- Signed-off-by: csoulnd <daidaicurry@foxmail.com>
This commit is contained in:
@@ -77,7 +77,7 @@ class TestShardedStateLoader310(TestBase):
|
||||
model = MockModel(quant_config=quant_config, with_int_weights=False)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ShardedStateLoader310.generate_quant_description(model, tmpdir)
|
||||
ShardedStateLoader310.generate_quant_description(model, tmpdir, quant_config)
|
||||
|
||||
json_path = Path(tmpdir) / "parameters_type_map.json"
|
||||
self.assertTrue(json_path.exists())
|
||||
@@ -92,6 +92,24 @@ class TestShardedStateLoader310(TestBase):
|
||||
self.assertIn("linear.bias", quant_description)
|
||||
self.assertEqual(quant_description["linear.bias"], "FLOAT")
|
||||
|
||||
@patch("vllm.model_executor.model_loader.ShardedStateLoader._filter_subtensors")
|
||||
def test_generate_quant_description_no_quant_config_310(self, mock_filter):
|
||||
"""When quant_config is None, treat model as FLOAT."""
|
||||
mock_filter.side_effect = lambda x: x
|
||||
model = MockModel(quant_config=None, with_int_weights=False)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ShardedStateLoader310.generate_quant_description(model, tmpdir, None)
|
||||
|
||||
json_path = Path(tmpdir) / "parameters_type_map.json"
|
||||
self.assertTrue(json_path.exists())
|
||||
|
||||
with open(json_path, encoding="utf-8") as f:
|
||||
quant_description = json.load(f)
|
||||
|
||||
self.assertEqual(quant_description["model_quant_type"], "FLOAT")
|
||||
self.assertEqual(quant_description["linear.weight"], "FLOAT")
|
||||
|
||||
@patch("vllm.model_executor.model_loader.ShardedStateLoader._filter_subtensors")
|
||||
def test_generate_quant_description_int_model_310(self, mock_filter):
|
||||
"""Test generate_quant_description for int8 quantized model."""
|
||||
@@ -100,7 +118,7 @@ class TestShardedStateLoader310(TestBase):
|
||||
model = MockModel(quant_config=quant_config, with_int_weights=True)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
ShardedStateLoader310.generate_quant_description(model, tmpdir)
|
||||
ShardedStateLoader310.generate_quant_description(model, tmpdir, quant_config)
|
||||
|
||||
json_path = Path(tmpdir) / "parameters_type_map.json"
|
||||
self.assertTrue(json_path.exists())
|
||||
|
||||
Reference in New Issue
Block a user