xc-llm-ascend/vllm_ascend/_310p/sharded_state_loader_310p.py
csoulnd 8952fddc7e [BugFix][310p][Cherry-pick] Handle null quantization config in ShardedStateLoader310&[Feature][310P] Support W8A8 dynamic linear method (#8296)
### What this PR does / why we need it?
This PR implements the `AscendW8A8DynamicLinearMethod310` quantization
scheme specifically for 310P hardware. It includes the logic for weight
retrieval, per-channel parameter generation, and the application of
dynamic quantization using NPU-specific kernels. Additionally, it
updates `ShardedStateLoader310` to handle quantization configurations
more robustly when generating parameter type maps.
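For readers new to the scheme, the arithmetic behind a W8A8 dynamic linear layer can be modeled in pure Python. This is a reference sketch under stated assumptions: symmetric int8 values in [-127, 127], one activation scale per token computed at call time, and one weight scale per output channel computed once at load time. It is not the NPU kernel path used by `AscendW8A8DynamicLinearMethod310`, and `quantize_rows` and `w8a8_dynamic_linear` are hypothetical names.

```python
# Pure-Python reference model of a W8A8 dynamic linear layer (no NPU, no torch).
# Assumptions: symmetric int8 in [-127, 127], per-token activation scales computed
# at call time, per-output-channel weight scales computed once at load time.


def quantize_rows(mat):
    """Symmetric int8 quantization with one scale per row."""
    q_rows, scales = [], []
    for row in mat:
        amax = max(abs(v) for v in row)
        scale = amax / 127.0 if amax > 0 else 1.0  # avoid dividing by zero for all-zero rows
        q_rows.append([max(-127, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return q_rows, scales


def w8a8_dynamic_linear(x, qweight, w_scales):
    """x: [tokens][in_features] floats; qweight: [out_features][in_features] int8 values."""
    qx, x_scales = quantize_rows(x)  # dynamic: one scale per token, computed per call
    out = []
    for i, xrow in enumerate(qx):
        row = []
        for j, wrow in enumerate(qweight):
            acc = sum(a * b for a, b in zip(xrow, wrow))  # integer accumulation
            row.append(acc * x_scales[i] * w_scales[j])  # dequantize back to float
        out.append(row)
    return out


# "Load time": quantize the weight once, one scale per output channel (row).
weight = [[1.0, 0.5, -0.5], [0.0, 1.0, 1.0]]
qweight, w_scales = quantize_rows(weight)

x = [[0.5, -1.0, 0.25], [2.0, 0.0, -2.0]]
approx = w8a8_dynamic_linear(x, qweight, w_scales)
exact = [[sum(a * b for a, b in zip(xr, wr)) for wr in weight] for xr in x]
```

Keeping weight quantization out of the call path means only the activations pay a quantization cost per forward pass; `approx` stays close to the float result `exact` for these inputs.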

Feedback from the review identified two critical issues in the
implementation:
1. The tensor squeezing logic in the `apply` method incorrectly handles
2D inputs, which may lead to shape mismatches in subsequent layers.
2. The weight tensor in `process_weights_after_loading` is transposed
after being converted to the private NZ format; the transpose operation
should be performed on the ND tensor before conversion to ensure correct
physical layout.
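Both review findings can be illustrated without NPU hardware. The sketch below is hypothetical: `squeeze_all`, `restore_leading_dims`, `to_blocked`, and `read_blocked` are not part of vllm-ascend, and the 2x2 block layout only loosely mimics a fractal NZ format (the real block sizes and ordering are assumptions here). Part 1 shows how an unconditional squeeze can drop a dimension that a 2D caller still expects. Part 2 shows that a blocked buffer packed from `W` cannot be read back as the layout of its transpose, which is why the transpose has to happen on the ND tensor before conversion.

```python
# Hypothetical helpers illustrating the two review findings; none of these
# names exist in vllm-ascend.


# Part 1: restoring the output shape after a flattened 2D matmul.
def squeeze_all(shape):
    """Mimic an unconditional squeeze(): every size-1 dimension is dropped."""
    return tuple(d for d in shape if d != 1)


def restore_leading_dims(input_shape, out_features):
    """Keep all leading dims of the input and replace only the last one."""
    return tuple(input_shape[:-1]) + (out_features,)


# Part 2: a toy blocked layout loosely modeled on a fractal NZ format.
# The 2x2 block size and block ordering are illustrative assumptions.
def transpose(mat):
    return [list(col) for col in zip(*mat)]


def to_blocked(mat, r0=2, c0=2):
    """Pack a row-major matrix into r0 x c0 blocks: column blocks outermost,
    then row blocks, with each block stored row-major."""
    rows, cols = len(mat), len(mat[0])
    assert rows % r0 == 0 and cols % c0 == 0
    buf = []
    for cb in range(cols // c0):
        for rb in range(rows // r0):
            for r in range(r0):
                for c in range(c0):
                    buf.append(mat[rb * r0 + r][cb * c0 + c])
    return buf


def read_blocked(buf, rows, i, j, r0=2, c0=2):
    """Read logical element (i, j) of a matrix with `rows` rows packed by to_blocked."""
    cb, rb = j // c0, i // r0
    return buf[((cb * (rows // r0) + rb) * r0 + i % r0) * c0 + j % c0]


W = [[r * 6 + c for c in range(6)] for r in range(4)]  # 4 x 6 ND weight
good = to_blocked(transpose(W))  # transpose the ND tensor first, then pack
bad = to_blocked(W)              # pack first: the buffer holds W's layout, not W^T's


def holds_transpose(buf):
    """True if `buf`, read as a packed 6 x 4 matrix, equals W^T element by element."""
    return all(read_blocked(buf, 6, i, j) == W[j][i] for i in range(6) for j in range(4))
```

In Part 1, a single-token 2D input of shape (1, 4096) comes back from a 128-wide matmul as (1, 128); squeezing turns that into (128,), while restoring the leading dimensions keeps (1, 128). In Part 2, only the buffer packed from the already-transposed matrix reads back as W^T.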

Cherry-picked from #7546 and #7725.
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
New unit tests were added in
`tests/ut/_310p/quantization/test_w8a8_dynamic_310.py` to verify the
quantization method, and
`tests/ut/_310p/test_sharded_state_loader_310p.py` was updated to test
the state loader changes.

---------

Signed-off-by: csoulnd <daidaicurry@foxmail.com>
2026-04-16 16:53:39 +08:00


#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import json
import os
from pathlib import Path

import torch
from vllm.config.load import LoadConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader import ShardedStateLoader


class ShardedStateLoader310(ShardedStateLoader):
    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        # `pattern` and `max_size` match the parent signature but are not used here:
        # each TP rank writes its whole state dict to a single file named with
        # DEFAULT_PATTERN (part 0).
        from safetensors.torch import save_file
        from vllm.distributed import get_tensor_model_parallel_rank

        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        # Keep one tensor per shared storage region; safetensors does not save
        # tensors that share memory.
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        filename = ShardedStateLoader.DEFAULT_PATTERN.format(rank=rank, part=part_idx)
        save_file(
            state_dict,
            os.path.join(path, filename),
        )

    @staticmethod
    def generate_quant_description(
        model: torch.nn.Module,
        path: str,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        """Write parameters_type_map.json, mapping each parameter name to its quantization type.

        Integer-typed weights and biases are tagged with the model-level quant type;
        every other entry, and every entry of an unquantized model, is tagged "FLOAT".
        """
        quant_description = {}
        if quant_config is None:
            quantize_type = "FLOAT"
        else:
            try:
                quantize_type = quant_config.quant_description.get("model_quant_type", "FLOAT")
            except AttributeError:
                # `quant_description` is missing or not dict-like (e.g. None).
                quantize_type = "FLOAT"
        quant_description["model_quant_type"] = quantize_type
        quant_description["version"] = "1.0.0"

        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        for name, tensor in state_dict.items():
            # Only integer-typed ".weight"/".bias" tensors carry the quant type.
            is_param = name.endswith((".weight", ".bias"))
            is_int = tensor.dtype in (torch.int8, torch.int32, torch.int64)
            quant_description[name] = quantize_type if is_param and is_int else "FLOAT"

        json_path = Path(path) / "parameters_type_map.json"
        with json_path.open("w", encoding="utf-8") as f:
            json.dump(quant_description, f, indent=2)