Files
xc-llm-ascend/vllm_ascend/_310p/quantization/methods/w8a8sc.py
pu-zhe 5df450bca4 [Feat] [310p] Support w8a8sc quantization method (#7075)
### What this PR does / why we need it?
New Quantization Method: Introduced support for the W8A8SC static linear
quantization scheme specifically for 310P hardware, enabling more
efficient model compression.
Refactored save_sharded_state_310.py to avoid a multi-process issue.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
W8A8SC quant E2E test.

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

---------

Signed-off-by: pu-zhe <zpuaa@outlook.com>
2026-03-10 16:13:20 +08:00

117 lines
4.5 KiB
Python

#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import math
from typing import Any
import torch
import torch_npu
from vllm.distributed import get_tensor_model_parallel_rank
from vllm_ascend.ops.linear import AscendRowParallelLinear
from vllm_ascend.quantization.methods.base import AscendLinearScheme
from .registry import register_scheme
@register_scheme("W8A8SC", "linear")
class AscendW8A8SCLinearMethod310(AscendLinearScheme):
    """W8A8SC static linear quantization scheme, available on 310P only.

    The scheme is registered in the 310P-local scheme registry under
    ``("W8A8SC", "linear")``. Weights are loaded already compressed (a flat
    int8 payload plus an ``index`` and an ``info`` tensor) and are consumed
    directly by ``torch_npu.npu_matmul_compress_dequant``. Activations are
    quantized to int8 with a static, per-tensor scale/offset.
    """

    def get_weight(
        self,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.float16,
    ) -> dict[str, Any]:
        """Allocate the weight-side tensors for a W8A8SC linear layer.

        Args:
            input_size: Input dimension ``k`` of the layer.
            output_size: Output dimension ``n`` of the layer.
            params_dtype: Accepted for interface compatibility. It is not
                used, because every tensor returned here has a fixed integer
                dtype.

        Returns:
            A dict of uninitialised tensors to be filled by the weight loader:

            - ``"weight"``: flat int8 buffer for the compressed weight. It is
              sized ``k * n``, the upper bound of the compressed payload
              length ``c`` (``0 < c <= k * n``).
            - ``"index"``: int8 compression index produced together with the
              compressed weight. Its length is ``k_index * n_index * 8``,
              where:
              - ``k_index = ceil(k1 / tilingK)`` and
                ``n_index = ceil(n1 / tilingN)``;
              - ``k1 = k / 32`` and ``n1 = n / 16``.
              The code evaluates this as ``ceil(k / 256) * ceil(n / 128) * 8``,
              which assumes ``tilingK = tilingN = 8``.
            - ``"info"``: five int64 values describing the compression. They
              hold the block tiling (tilingN, tilingK), the original shape of
              the pre-compression x2 matrix, and the block traversal
              direction.

        Side effects:
            Stores ``input_size`` on ``self``. ``process_weights_after_loading``
            uses it to expand the per-tensor activation scale and offset.
        """
        # Needed later by process_weights_after_loading (see docstring).
        self.input_size = input_size

        k_tiles = math.ceil(input_size / 256)
        n_tiles = math.ceil(output_size / 128)
        index_numel = k_tiles * n_tiles * 8

        compressed_weight = torch.empty(input_size * output_size, dtype=torch.int8)
        compress_index = torch.empty(index_numel, dtype=torch.int8)
        compress_info = torch.empty(5, dtype=torch.int64)
        return dict(
            weight=compressed_weight,
            index=compress_index,
            info=compress_info,
        )

    def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
        """Allocate the per-tensor static activation-quantization parameters.

        Returns one-element ``input_scale`` (``params_dtype``) and
        ``input_offset`` (int8) tensors.
        """
        scale = torch.empty(1, dtype=params_dtype)
        offset = torch.empty(1, dtype=torch.int8)
        return {"input_scale": scale, "input_offset": offset}

    def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
        """Allocate per-output-channel parameters for the dequantizing matmul.

        Returns ``quant_bias`` (int32) and ``deq_scale`` (int64), each with
        ``output_size`` elements. ``params_dtype`` is not used.
        """
        bias_buf = torch.empty(output_size, dtype=torch.int32)
        scale_buf = torch.empty(output_size, dtype=torch.int64)
        return {"quant_bias": bias_buf, "deq_scale": scale_buf}

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
        tp_rank: int | None = 0,
    ) -> torch.Tensor:
        """Run the compressed int8 matmul with fused dequantization.

        Args:
            layer: Layer whose parameters were prepared by
                ``process_weights_after_loading``.
            x: Input activations. They are quantized to int8 first unless
                they already are int8.
            bias: Not used here. NOTE(review): presumably any floating-point
                bias is already folded into ``layer.quant_bias`` — confirm.
            tp_rank: Not used here. Rank-dependent bias handling is done once
                in ``process_weights_after_loading``.

        Returns:
            The result of ``torch_npu.npu_matmul_compress_dequant``.
        """
        if x.dtype == torch.int8:
            x_q = x
        else:
            x_q = torch.ops.vllm.quantize(
                x,
                layer.aclnn_input_scale,
                layer.aclnn_input_scale_reciprocal,
                layer.aclnn_input_offset,
            )
        return torch_npu.npu_matmul_compress_dequant(
            x_q,
            layer.weight,
            layer.index,
            layer.quant_bias,
            layer.deq_scale,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded parameters into the layouts ``apply`` expects.

        - Repeats the one-element ``input_scale`` / ``input_offset`` to
          ``self.input_size`` elements and precomputes the reciprocal scale.
          The offset is cast to the scale's dtype.
        - Adds a leading dimension to ``deq_scale`` (cast to uint64) and to
          ``quant_bias``.
        - On ``AscendRowParallelLinear`` layers, replaces ``quant_bias`` with
          zeros on every tensor-parallel rank other than 0, so the bias is
          only contributed by rank 0. Presumably this is because the
          row-parallel partial outputs are summed across ranks. The bias is
          zeroed rather than dropped because
          ``torch_npu.npu_matmul_compress_dequant`` does not accept ``None``
          for ``quant_bias``.

        NOTE(review): ``self.input_size`` comes from the latest
        ``get_weight`` call. This is only correct if each layer owns its own
        scheme instance — confirm.
        """
        n_channels = self.input_size

        expanded_scale = layer.input_scale.data.repeat(n_channels)
        layer.aclnn_input_scale = expanded_scale
        layer.aclnn_input_scale_reciprocal = 1.0 / expanded_scale
        expanded_offset = layer.input_offset.data.repeat(n_channels)
        layer.aclnn_input_offset = expanded_offset.to(expanded_scale.dtype)

        layer.deq_scale.data = layer.deq_scale.data.unsqueeze(dim=0).to(torch.uint64)
        layer.quant_bias.data = layer.quant_bias.data.unsqueeze(dim=0)

        # The isinstance check comes first so the TP rank is only queried for
        # row-parallel layers.
        if not isinstance(layer, AscendRowParallelLinear):
            return
        if get_tensor_model_parallel_rank() != 0:
            layer.quant_bias.data = torch.zeros_like(layer.quant_bias)