[Feat][310p] 310P support w8a8s quantization and saving w8a8sc state (#6878)

### What this PR does / why we need it?
This pull request adds 310P device support for W8A8S sparse quantization
and for saving model state in the W8A8SC (sparse compressed) format. It
provides an example script for saving sharded and compressed model
states, implements the core W8A8S quantization method, and generates
metadata in the 310P worker that records the quantization type of each
saved parameter. Together, these changes improve the efficiency and
compatibility of quantized models on 310P hardware.
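
For reference, a minimal offline sketch of the save flow, modeled on vLLM's
save_sharded_state example. Paths are placeholders, and the exact engine
plumbing for dispatching worker methods varies across vLLM versions, so treat
the call path as an assumption rather than the shipped script:

```python
# Sketch of saving a W8A8S-quantized model's sharded state. The model and
# output paths are hypothetical placeholders.
from vllm import LLM

llm = LLM(
    model="/path/to/w8a8s-quantized-model",  # hypothetical local path
    quantization="ascend",
    enforce_eager=True,
)
# Dispatches NPUWorker310.save_sharded_state on each worker: every rank
# writes one safetensors shard via ShardedStateLoader310, plus the
# parameters_type_map.json metadata describing each parameter's quant type.
llm.collective_rpc(
    "save_sharded_state",
    kwargs=dict(path="/path/to/output", pattern=None, max_size=None),
)
```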
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Verified with a W8A8S accuracy test and a W8A8SC state save.
<img width="886" height="184" alt="image"
src="https://github.com/user-attachments/assets/e9bcac54-1f69-4d3a-a5b8-221a147ef99d"
/>

- vLLM version: v0.16.0
- vLLM main: 15d76f74e2

---------

Signed-off-by: pu-zhe <zpuaa@outlook.com>
Authored by pu-zhe, committed by GitHub, 2026-03-02 20:09:15 +08:00
parent 68d8d20ca2
commit 5899438a86
8 changed files with 668 additions and 1 deletion


@@ -18,4 +18,5 @@
from . import (
    w8a8_dynamic,  # noqa: F401
    w8a8_static,  # noqa: F401
    w8a8s,  # noqa: F401
)
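
Importing `w8a8s` here registers the scheme as an import side effect of its
`@register_scheme` decorator. A hypothetical sketch of a decorator registry of
that shape (the real implementation lives in the 310P-local `.registry` module
and may differ in detail):

```python
# Hypothetical registry sketch; not the actual `.registry` implementation.
_SCHEMES: dict[tuple[str, str], type] = {}

def register_scheme(quant_type: str, layer_kind: str):
    """Class decorator that files a scheme under (quant_type, layer_kind)."""
    def wrap(cls: type) -> type:
        _SCHEMES[(quant_type, layer_kind)] = cls
        return cls
    return wrap

def get_scheme(quant_type: str, layer_kind: str) -> type:
    """Look up the scheme class registered for this quant type and layer."""
    return _SCHEMES[(quant_type, layer_kind)]
```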


@@ -0,0 +1,87 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Any

import torch
import torch_npu

from vllm_ascend.quantization.methods.base import AscendLinearScheme
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ

from .registry import register_scheme


@register_scheme("W8A8S", "linear")
class AscendW8A8SLinearMethod310(AscendLinearScheme):
    """310P-only W8A8S sparse linear scheme.

    Notes:
        - This scheme is discovered via the 310P local registry.
    """

    def get_weight(
        self,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.float16,
    ) -> dict[str, Any]:
        # int8 quantized weight of shape [output_size, input_size].
        return {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}

    def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
        # Per-tensor activation quantization parameters.
        return {
            "input_scale": torch.empty(1, dtype=params_dtype),
            "input_offset": torch.empty(1, dtype=torch.int8),
        }

    def get_perchannel_param(
        self, output_size: int, params_dtype: torch.dtype
    ) -> dict[str, Any]:
        # Per-output-channel quantized bias and dequantization scale.
        return {
            "quant_bias": torch.empty(output_size, dtype=torch.int32),
            "deq_scale": torch.empty(output_size, dtype=torch.int64),
        }

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
        tp_rank: int | None = 0,
    ) -> torch.Tensor:
        # Quantize the activation to int8 on the fly if it is not already.
        if x.dtype != torch.int8:
            x = torch.ops.vllm.quantize(
                x,
                layer.aclnn_input_scale,
                layer.aclnn_input_scale_reciprocal,
                layer.aclnn_input_offset,
            )
        # Only rank 0 adds the quantized bias, so it is not double-counted
        # across tensor-parallel ranks.
        quant_bias = layer.quant_bias if tp_rank == 0 else None
        return torch_npu.npu_quant_matmul(
            x,
            layer.weight.data.transpose(0, 1),
            layer.deq_scale,
            bias=quant_bias,
            output_dtype=layer.params_dtype,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Broadcast the per-tensor scale/offset across the input dimension so
        # the aclnn quantize kernel can consume them elementwise.
        expanding_factor = layer.weight.data.shape[1]
        layer.aclnn_input_scale = layer.input_scale.data.repeat(expanding_factor)
        layer.aclnn_input_scale_reciprocal = 1.0 / layer.aclnn_input_scale.data
        layer.aclnn_input_offset = layer.input_offset.data.repeat(
            expanding_factor).to(layer.aclnn_input_scale.dtype)
        # Cast the weight into the NZ fractal format expected by NPU matmuls.
        layer.weight.data = torch_npu.npu_format_cast(
            layer.weight.data, ACL_FORMAT_FRACTAL_NZ)
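
To make the `apply` path concrete, here is a device-free reference of the
quantize, int8-matmul, dequantize arithmetic that `torch.ops.vllm.quantize`
and `npu_quant_matmul` perform on the NPU. This is a pure-torch sketch: the
real `deq_scale` is stored as int64 on the NPU and is simplified here to a
plain float tensor.

```python
import torch

def w8a8_linear_reference(x_fp16, weight_int8, input_scale, input_offset,
                          deq_scale, quant_bias):
    """CPU stand-in for the int8 quantized linear used by the scheme."""
    # Per-tensor activation quantization: fp16 -> int8.
    x_int8 = torch.clamp(torch.round(x_fp16 / input_scale) + input_offset,
                         -128, 127).to(torch.int8)
    # Integer matmul accumulated in a wide dtype, quantized bias folded in.
    acc = x_int8.to(torch.int64) @ weight_int8.t().to(torch.int64) + quant_bias
    # Per-output-channel dequantization back to fp16.
    return (acc.to(torch.float32) * deq_scale).to(torch.float16)

# Tiny smoke test with random data.
x = torch.randn(4, 8, dtype=torch.float16)
w = torch.randint(-128, 128, (16, 8), dtype=torch.int8)
out = w8a8_linear_reference(
    x, w,
    input_scale=torch.tensor(0.02, dtype=torch.float16),
    input_offset=torch.tensor(0, dtype=torch.int8),
    deq_scale=torch.full((16,), 1e-3),
    quant_bias=torch.zeros(16, dtype=torch.int32),
)
print(out.shape)  # torch.Size([4, 16])
```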


@@ -0,0 +1,69 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import json
import os
from pathlib import Path

import torch
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader import ShardedStateLoader


class ShardedStateLoader310(ShardedStateLoader):

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        from safetensors.torch import save_file
        from vllm.distributed import get_tensor_model_parallel_rank

        # Each tensor-parallel rank writes its shard as a single part;
        # pattern and max_size are accepted for interface compatibility
        # but unused here.
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        filename = ShardedStateLoader.DEFAULT_PATTERN.format(rank=rank,
                                                             part=part_idx)
        save_file(
            state_dict,
            os.path.join(path, filename),
        )

    @staticmethod
    def generate_quant_description(model: torch.nn.Module, path: str):
        """Generate a mapping of parameter names to their quantization types."""
        quant_description = {}
        quantize_type = model.quant_config.quant_description.get(
            "model_quant_type", "FLOAT")
        quant_description["model_quant_type"] = quantize_type
        quant_description["version"] = "1.0.0"
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        for name, tensor in state_dict.items():
            # Integer-typed weights and biases were produced by quantization;
            # everything else is recorded as FLOAT.
            if name.endswith(".weight") or name.endswith(".bias"):
                if tensor.dtype in [torch.int8, torch.int32, torch.int64]:
                    quant_description[name] = quantize_type
                else:
                    quant_description[name] = "FLOAT"
            else:
                quant_description[name] = "FLOAT"
        json_path = Path(path) / "parameters_type_map.json"
        with json_path.open("w", encoding="utf-8") as f:
            json.dump(quant_description, f, indent=2)
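
The resulting `parameters_type_map.json` maps `model_quant_type`, `version`,
and every saved tensor name to either the model's quantization type or
`FLOAT`. A small sketch of consuming it (the output path is a placeholder for
wherever `save_sharded_state` wrote its files):

```python
import json
from pathlib import Path

# Inspect the metadata written by generate_quant_description.
type_map = json.loads(
    Path("/path/to/output/parameters_type_map.json").read_text(encoding="utf-8"))
print(type_map["model_quant_type"], type_map["version"])
# Collect the parameters that were stored in a quantized dtype.
quantized = [name for name, qtype in type_map.items()
             if name not in ("model_quant_type", "version") and qtype != "FLOAT"]
print(f"{len(quantized)} quantized parameters, e.g. {quantized[:3]}")
```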


@@ -14,7 +14,6 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import torch_npu
from vllm.logger import logger
@@ -31,6 +30,23 @@ class NPUWorker310(NPUWorker):
        self.model_runner = NPUModelRunner310(self.vllm_config, self.device)

    def save_sharded_state(
        self,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        from vllm_ascend._310p.sharded_state_loader_310p import ShardedStateLoader310

        # Write this rank's shard, then emit the per-parameter quantization
        # metadata alongside it.
        ShardedStateLoader310.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )
        ShardedStateLoader310.generate_quant_description(self.model_runner.model, path)

    def _warm_up_atb(self):
        # The 310P device does not support the torch_npu._npu_matmul_add_fp32 ATB op.
        logger.info("Skip warm-up atb ops for 310P device.")