[BugFix][310p][Cherry-pick] Handle null quantization config in ShardedStateLoader310&[Feature][310P] Support W8A8 dynamic linear method (#8296)

### What this PR does / why we need it?
This PR implements the `AscendW8A8DynamicLinearMethod310` quantization
scheme specifically for 310P hardware. It includes the logic for weight
retrieval, per-channel parameter generation, and the application of
dynamic quantization using NPU-specific kernels. Additionally, it
updates `ShardedStateLoader310` to handle quantization configurations
more robustly when generating parameter type maps.

Feedback from the review identified two critical issues in the
implementation:
1. The tensor squeezing logic in the `apply` method incorrectly handles
2D inputs, which may lead to shape mismatches in subsequent layers.
2. The weight tensor in `process_weights_after_loading` is transposed
after being converted to the private NZ format; the transpose operation
should be performed on the ND tensor before conversion to ensure correct
physical layout.

cherry-pick from : #7546 #7725
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
New unit tests were added in
`tests/ut/_310p/quantization/test_w8a8_dynamic_310.py` to verify the
quantization method, and
`tests/ut/_310p/test_sharded_state_loader_310p.py` was updated to test
the state loader changes.

---------

Signed-off-by: csoulnd <daidaicurry@foxmail.com>
This commit is contained in:
csoulnd
2026-04-16 16:53:39 +08:00
committed by GitHub
parent 52f0f9b5e4
commit 8952fddc7e
6 changed files with 184 additions and 10 deletions

View File

@@ -20,6 +20,7 @@ from pathlib import Path
import torch
from vllm.config.load import LoadConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader import ShardedStateLoader
@@ -48,10 +49,20 @@ class ShardedStateLoader310(ShardedStateLoader):
)
@staticmethod
def generate_quant_description(model: torch.nn.Module, path: str):
def generate_quant_description(
model: torch.nn.Module,
path: str,
quant_config: QuantizationConfig | None = None,
) -> None:
"""Generate a mapping of parameter names to their corresponding quantization types."""
quant_description = {}
quantize_type = model.quant_config.quant_description.get("model_quant_type", "FLOAT")
if quant_config is None:
quantize_type = "FLOAT"
else:
try:
quantize_type = quant_config.quant_description.get("model_quant_type", "FLOAT")
except AttributeError:
quantize_type = "FLOAT"
quant_description["model_quant_type"] = quantize_type
quant_description["version"] = "1.0.0"
state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())