Files
xc-llm-ascend/vllm_ascend/_310p/quantization/methods/w8a8_static.py
Shaoxu Cheng 39e77fb9e4 [Feat.]: support 310p w8a8 (#6454)
### What this PR does / why we need it?
Introduced 310P W8A8 Quantization Support: New modules and methods have
been added to enable W8A8 static quantization specifically for the
Ascend 310P platform.
Platform-Specific Quantization Configuration Loading: The system now
dynamically loads the appropriate quantization configurations
(AscendCompressedTensorsConfig, AscendModelSlimConfig) based on whether
the current hardware is an Ascend 310P device.
Implemented AscendW8A8LinearMethod310P: A dedicated linear quantization
method for 310P is provided, handling the specifics of weight and
activation quantization, including input parameter broadcasting and
weight data manipulation.
Extended AscendModelSlimConfig for 310P: A specialized configuration
class for 310P integrates the new W8A8 linear method for both standard
linear layers and vocabulary parallel embeddings, ensuring proper
quantization application.

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: Tflowers-0129 <2906339855@qq.com>
Signed-off-by: Shaoxu Cheng <2906339855@qq.com>
2026-02-03 14:13:06 +08:00

108 lines
3.7 KiB
Python

#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import Any
import torch
import torch_npu
from vllm_ascend.quantization.methods.base import AscendLinearScheme
from .registry import register_scheme
@register_scheme("W8A8", "linear")
class AscendW8A8LinearMethod310P(AscendLinearScheme):
    """W8A8 static-quantization linear scheme used only on Ascend 310P.

    Weights are stored as int8. Activations are quantized with a single
    static scale/offset (broadcast in ``process_weights_after_loading``)
    and then passed to ``torch_npu.npu_quant_matmul``.

    Notes:
        - This scheme is discovered through the 310P-local registry
          (``register_scheme("W8A8", "linear")``).
    """

    def get_weight(
        self,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.float16,
    ) -> dict[str, Any]:
        """Declare the int8 weight tensor, laid out as (output_size, input_size).

        ``params_dtype`` is not used here: the weight is always int8.
        """
        int8_weight = torch.empty((output_size, input_size), dtype=torch.int8)
        return {"weight": int8_weight}

    def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
        """Declare the single-element activation quantization parameters.

        Returns ``input_scale`` (``params_dtype``) and ``input_offset`` (int8),
        each holding one value that is later broadcast over the input dim.
        """
        return dict(
            input_scale=torch.empty(1, dtype=params_dtype),
            input_offset=torch.empty(1, dtype=torch.int8),
        )

    def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
        """Declare the per-output-channel parameters.

        Keys, in declaration order:
            ``quant_bias``: int32, shape (output_size,).
            ``deq_scale``: float32 or int64 (see below), shape (output_size,).
            ``weight_scale`` / ``weight_offset``: ``params_dtype``,
            shape (output_size, 1).
        """
        # deq_scale is float32 for bfloat16 models and int64 for any other dtype.
        # NOTE(review): the int64 variant presumably carries an encoded scale in
        # the form npu_quant_matmul expects for non-bf16 output — confirm against
        # the torch_npu docs before changing either dtype.
        deq_dtype = torch.float32 if params_dtype == torch.bfloat16 else torch.int64
        return {
            "quant_bias": torch.empty(output_size, dtype=torch.int32),
            "deq_scale": torch.empty(output_size, dtype=deq_dtype),
            "weight_scale": torch.empty(output_size, 1, dtype=params_dtype),
            "weight_offset": torch.empty(output_size, 1, dtype=params_dtype),
        }

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
        tp_rank: int | None = 0,
    ) -> torch.Tensor:
        """Run the quantized linear layer.

        Args:
            layer: Module holding the tensors declared by the ``get_*`` methods
                and prepared by ``process_weights_after_loading``.
            x: Activations; quantized here unless they are already int8.
            bias: Not read by this scheme; the bias term passed to the matmul
                is ``layer.quant_bias``.
            tp_rank: ``layer.quant_bias`` is passed only when this equals 0.

        Returns:
            The ``npu_quant_matmul`` result in ``layer.params_dtype``.
        """
        if x.dtype == torch.int8:
            x_q = x
        else:
            x_q = torch.ops.vllm.quantize(
                x,
                layer.aclnn_input_scale,
                layer.aclnn_input_scale_reciprocal,
                layer.aclnn_input_offset,
            )

        # NOTE(review): limiting the bias to rank 0 presumably prevents it from
        # being added once per rank after a tensor-parallel reduction — confirm.
        rank0_bias = None
        if tp_rank == 0:
            rank0_bias = layer.quant_bias

        return torch_npu.npu_quant_matmul(
            x_q,
            layer.weight,
            layer.deq_scale,
            bias=rank0_bias,
            output_dtype=layer.params_dtype,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Prepare the loaded tensors for ``apply``.

        - Repeats the single-element ``input_scale`` / ``input_offset`` along
          the weight's dim 1 and stores them, plus ``1 / scale``, as the
          ``aclnn_input_*`` attributes read by ``apply``.
        - Transposes the weight (declared as (output_size, input_size) by
          ``get_weight``) and makes it contiguous.
        - Flattens ``weight_scale`` and ``weight_offset`` to 1-D.
        """
        # Read before the transpose below; with the get_weight layout this is input_size.
        in_features = layer.weight.data.shape[1]

        layer.aclnn_input_scale = torch.nn.Parameter(
            layer.input_scale.data.repeat(in_features),
            requires_grad=False,
        )
        layer.aclnn_input_scale_reciprocal = torch.nn.Parameter(
            1.0 / layer.aclnn_input_scale.data,
            requires_grad=False,
        )

        # The int8 offset is converted to the scale's dtype.
        # NOTE(review): ``.to()`` with a different dtype on a Parameter likely
        # returns a plain tensor, so this attribute is probably not a registered
        # Parameter (unlike the two above) — confirm whether that is intended.
        offset_param = torch.nn.Parameter(
            layer.input_offset.data.repeat(in_features),
            requires_grad=False,
        )
        layer.aclnn_input_offset = offset_param.to(layer.aclnn_input_scale.dtype)

        layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()

        for attr in ("weight_scale", "weight_offset"):
            param = getattr(layer, attr)
            param.data = torch.flatten(param.data)