[Feat.]: support 310p w8a8 (#6454)
### What this PR does / why we need it?
Introduced 310P W8A8 Quantization Support: New modules and methods have
been added to enable W8A8 static quantization specifically for the
Ascend 310P platform.
Platform-Specific Quantization Configuration Loading: The system now
dynamically loads the appropriate quantization configurations
(AscendCompressedTensorsConfig, AscendModelSlimConfig) based on whether
the current hardware is an Ascend 310P device.
Implemented AscendW8A8LinearMethod310P: A dedicated linear quantization
method for 310P is provided, handling the specifics of weight and
activation quantization, including input parameter broadcasting and
weight data manipulation.
Extended AscendModelSlimConfig for 310P: A specialized configuration
class for 310P integrates the new W8A8 linear method for both standard
linear layers and vocabulary parallel embeddings, ensuring proper
quantization application.
- vLLM version: v0.14.1
- vLLM main:
dc917cceb8
---------
Signed-off-by: Tflowers-0129 <2906339855@qq.com>
Signed-off-by: Shaoxu Cheng <2906339855@qq.com>
This commit is contained in:
4
.github/workflows/_e2e_test.yaml
vendored
4
.github/workflows/_e2e_test.yaml
vendored
@@ -464,4 +464,6 @@ jobs:
|
||||
PYTORCH_NPU_ALLOC_CONF: max_split_size_mb:256
|
||||
VLLM_WORKER_MULTIPROC_METHOD: spawn
|
||||
run: |
|
||||
pytest -sv --durations=0 tests/e2e/310p/test_offline_inference_parallel_310p.py
|
||||
pytest -sv --durations=0 \
|
||||
tests/e2e/310p/test_offline_inference_parallel_310p.py \
|
||||
tests/e2e/310p/test_offline_inference_w8a8_310p.py
|
||||
|
||||
22
tests/e2e/310p/test_offline_inference_w8a8_310p.py
Normal file
22
tests/e2e/310p/test_offline_inference_w8a8_310p.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import pytest
|
||||
|
||||
from tests.e2e.conftest import VllmRunner
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["float16"])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
def test_qwen3_w8a8_e2e_310p(dtype: str, max_tokens: int) -> None:
|
||||
example_prompts = [
|
||||
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
|
||||
]
|
||||
|
||||
with VllmRunner(
|
||||
"vllm-ascend/Qwen3-32B-W8A8",
|
||||
tensor_parallel_size=4,
|
||||
dtype=dtype,
|
||||
max_model_len=8192,
|
||||
enforce_eager=True,
|
||||
quantization="ascend",
|
||||
enable_prefix_caching=False,
|
||||
) as vllm_model:
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
22
vllm_ascend/_310p/quantization/__init__.py
Normal file
22
vllm_ascend/_310p/quantization/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from vllm_ascend._310p.quantization.modelslim_config import AscendModelSlimConfig310
|
||||
|
||||
__all__ = [
|
||||
"AscendModelSlimConfig310",
|
||||
]
|
||||
22
vllm_ascend/_310p/quantization/methods/__init__.py
Normal file
22
vllm_ascend/_310p/quantization/methods/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from . import w8a8_static # noqa: F401
|
||||
|
||||
# Future extensions:
|
||||
# from . import w8a8_dynamic # noqa: F401
|
||||
# from . import w4a16 # noqa: F401
|
||||
41
vllm_ascend/_310p/quantization/methods/registry.py
Normal file
41
vllm_ascend/_310p/quantization/methods/registry.py
Normal file
@@ -0,0 +1,41 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from typing import Any
|
||||
|
||||
# 310P-local registry: maps (quant_type, layer_type) -> SchemeClass
|
||||
_SCHEME_REGISTRY: dict[tuple[str, str], type[Any]] = {}
|
||||
|
||||
|
||||
def register_scheme(quant_type: str, layer_type: str):
|
||||
"""Decorator to register a 310P quantization scheme."""
|
||||
|
||||
def decorator(cls: type[Any]) -> type[Any]:
|
||||
key = (quant_type, layer_type)
|
||||
if key in _SCHEME_REGISTRY:
|
||||
raise ValueError(
|
||||
f"[310P] Scheme already registered for {quant_type}/{layer_type}: {_SCHEME_REGISTRY[key].__name__}"
|
||||
)
|
||||
_SCHEME_REGISTRY[key] = cls
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_scheme_class(quant_type: str, layer_type: str) -> type[Any] | None:
|
||||
"""Get 310P scheme class for given quant_type and layer_type."""
|
||||
return _SCHEME_REGISTRY.get((quant_type, layer_type))
|
||||
107
vllm_ascend/_310p/quantization/methods/w8a8_static.py
Normal file
107
vllm_ascend/_310p/quantization/methods/w8a8_static.py
Normal file
@@ -0,0 +1,107 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.quantization.methods.base import AscendLinearScheme
|
||||
|
||||
from .registry import register_scheme
|
||||
|
||||
|
||||
@register_scheme("W8A8", "linear")
|
||||
class AscendW8A8LinearMethod310P(AscendLinearScheme):
|
||||
"""310P-only W8A8 static linear scheme.
|
||||
|
||||
Notes:
|
||||
- This scheme is discovered via 310P local registry.
|
||||
"""
|
||||
|
||||
def get_weight(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype = torch.float16,
|
||||
) -> dict[str, Any]:
|
||||
return {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
|
||||
|
||||
def get_pertensor_param(self, params_dtype: torch.dtype) -> dict[str, Any]:
|
||||
return {
|
||||
"input_scale": torch.empty(1, dtype=params_dtype),
|
||||
"input_offset": torch.empty(1, dtype=torch.int8),
|
||||
}
|
||||
|
||||
def get_perchannel_param(self, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
|
||||
params: dict[str, Any] = {}
|
||||
params["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
|
||||
|
||||
# NOTE: keep identical to your current working behavior.
|
||||
if params_dtype == torch.bfloat16:
|
||||
params["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
|
||||
else:
|
||||
params["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
|
||||
|
||||
params["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
|
||||
params["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
|
||||
return params
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
tp_rank: int | None = 0,
|
||||
) -> torch.Tensor:
|
||||
if x.dtype != torch.int8:
|
||||
x = torch.ops.vllm.quantize(
|
||||
x,
|
||||
layer.aclnn_input_scale,
|
||||
layer.aclnn_input_scale_reciprocal,
|
||||
layer.aclnn_input_offset,
|
||||
)
|
||||
|
||||
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
||||
|
||||
return torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
layer.weight,
|
||||
layer.deq_scale,
|
||||
bias=quant_bias,
|
||||
output_dtype=layer.params_dtype,
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
expanding_factor = layer.weight.data.shape[1]
|
||||
layer.aclnn_input_scale = torch.nn.Parameter(
|
||||
layer.input_scale.data.repeat(expanding_factor),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.aclnn_input_scale_reciprocal = torch.nn.Parameter(
|
||||
1.0 / layer.aclnn_input_scale.data,
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.aclnn_input_offset = torch.nn.Parameter(
|
||||
layer.input_offset.data.repeat(expanding_factor),
|
||||
requires_grad=False,
|
||||
).to(layer.aclnn_input_scale.dtype)
|
||||
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
|
||||
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
||||
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
||||
162
vllm_ascend/_310p/quantization/modelslim_config.py
Normal file
162
vllm_ascend/_310p/quantization/modelslim_config.py
Normal file
@@ -0,0 +1,162 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
UnquantizedEmbeddingMethod,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
|
||||
# Important: trigger 310P method registrations (register into 310P-local registry)
|
||||
from vllm_ascend._310p.quantization import methods as _methods_310p # noqa: F401
|
||||
from vllm_ascend._310p.quantization.methods.registry import get_scheme_class as get_scheme_class_310p
|
||||
from vllm_ascend.quantization.method_adapters import (
|
||||
AscendLinearMethod,
|
||||
)
|
||||
from vllm_ascend.quantization.modelslim_config import (
|
||||
AscendModelSlimConfig,
|
||||
packed_modules_model_mapping,
|
||||
)
|
||||
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def create_scheme_for_layer_310p(
|
||||
cfg: AscendModelSlimConfig,
|
||||
quant_description: dict[str, Any],
|
||||
prefix: str,
|
||||
layer_type: str,
|
||||
packed_modules_mapping: dict[str, Any] | None = None,
|
||||
):
|
||||
"""Create 310P quant scheme (mainline-like).
|
||||
|
||||
- If quant_type cannot be determined: raise ValueError
|
||||
- If quant_type is determined but not supported on 310P: raise NotImplementedError
|
||||
"""
|
||||
logger.info_once("Using 310P ModelSlim Quantization routing.")
|
||||
|
||||
if layer_type != "linear":
|
||||
raise NotImplementedError(f"310P quantization: layer_type={layer_type} is not supported yet (TODO).")
|
||||
|
||||
quant_type = cfg._get_linear_quant_type(prefix)
|
||||
if quant_type is None:
|
||||
raise ValueError(f"310P quantization: could not determine quant_type for layer={prefix}.")
|
||||
|
||||
scheme_cls = get_scheme_class_310p(quant_type, "linear")
|
||||
if scheme_cls is None:
|
||||
raise NotImplementedError(f"310P quantization: quant_type={quant_type} for linear is not supported yet (TODO).")
|
||||
|
||||
return scheme_cls()
|
||||
|
||||
|
||||
@register_quantization_config(ASCEND_QUANTIZATION_METHOD)
|
||||
class AscendModelSlimConfig310(AscendModelSlimConfig):
|
||||
"""310P override for ModelSlim quantization config.
|
||||
|
||||
- Uses 310P-local scheme registry to create scheme by (quant_type, layer_type).
|
||||
- MUST keep packed_modules_mapping behavior consistent with base, otherwise
|
||||
fused modules (qkv_proj / gate_up_proj) will miss and fallback to base,
|
||||
causing NZ/transpose issues on 310P.
|
||||
"""
|
||||
|
||||
def _get_linear_quant_type(self, prefix: str) -> str | None:
|
||||
"""Packed-aware quant type lookup.
|
||||
|
||||
ModelSlim may describe fused modules by their shards.
|
||||
Example:
|
||||
prefix = "...qkv_proj" -> shards "...q_proj.weight", "...k_proj.weight", "...v_proj.weight"
|
||||
"""
|
||||
fused_mapping = getattr(self, "packed_modules_mapping", {}) or {}
|
||||
proj_name = prefix.split(".")[-1]
|
||||
|
||||
if proj_name in fused_mapping:
|
||||
shard_prefixes = [
|
||||
prefix.replace(proj_name, shard_proj_name) for shard_proj_name in fused_mapping[proj_name]
|
||||
]
|
||||
quant_types: list[str] = []
|
||||
for sp in shard_prefixes:
|
||||
qt = self.quant_description.get(sp + ".weight")
|
||||
if isinstance(qt, str):
|
||||
quant_types.append(qt)
|
||||
|
||||
if not quant_types:
|
||||
return None
|
||||
|
||||
first = quant_types[0]
|
||||
if any(q != first for q in quant_types[1:]):
|
||||
raise ValueError(
|
||||
f"310P quantization: not all shards of fused layer '{prefix}' "
|
||||
f"share the same quant type. shards={shard_prefixes}, types={quant_types}"
|
||||
)
|
||||
return first
|
||||
|
||||
qt = self.quant_description.get(prefix + ".weight")
|
||||
return qt if isinstance(qt, str) else None
|
||||
|
||||
def get_quant_method(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
prefix: str,
|
||||
) -> QuantizeMethodBase | None:
|
||||
vllm_config = get_current_vllm_config()
|
||||
model_type = vllm_config.model_config.hf_config.model_type
|
||||
|
||||
if model_type in packed_modules_model_mapping:
|
||||
self.packed_modules_mapping = packed_modules_model_mapping[model_type]
|
||||
|
||||
prefix = self.quant_prefix_mapper(model_type, prefix)
|
||||
if prefix.startswith("language_model"):
|
||||
prefix = prefix.split(".", 1)[-1]
|
||||
|
||||
if isinstance(layer, LinearBase):
|
||||
packed = getattr(self, "packed_modules_mapping", {})
|
||||
if self.is_layer_skipped_ascend(prefix, packed):
|
||||
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
|
||||
|
||||
return AscendUnquantizedLinearMethod()
|
||||
|
||||
scheme = create_scheme_for_layer_310p(
|
||||
cfg=self,
|
||||
quant_description=self.quant_description,
|
||||
prefix=prefix,
|
||||
layer_type="linear",
|
||||
packed_modules_mapping=packed,
|
||||
)
|
||||
return AscendLinearMethod(scheme)
|
||||
|
||||
if isinstance(layer, VocabParallelEmbedding):
|
||||
return UnquantizedEmbeddingMethod()
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
raise NotImplementedError(
|
||||
"310P quantization: FusedMoE is not supported yet. "
|
||||
"TODO: add 310P MoE quant schemes and routing. "
|
||||
"Workaround: use a non-MoE model."
|
||||
)
|
||||
|
||||
return super().get_quant_method(layer, prefix)
|
||||
@@ -150,7 +150,10 @@ class NPUPlatform(Platform):
|
||||
if ASCEND_QUANTIZATION_METHOD not in quant_action.choices:
|
||||
quant_action.choices.append(ASCEND_QUANTIZATION_METHOD)
|
||||
|
||||
from vllm_ascend.quantization import AscendCompressedTensorsConfig, AscendModelSlimConfig # noqa: F401
|
||||
if not is_310p():
|
||||
from vllm_ascend.quantization import AscendCompressedTensorsConfig, AscendModelSlimConfig # noqa: F401
|
||||
else:
|
||||
from vllm_ascend._310p.quantization import AscendModelSlimConfig310 # noqa: F401
|
||||
|
||||
config_deprecated_logging()
|
||||
|
||||
|
||||
@@ -138,24 +138,13 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
|
||||
if ascend_quant_method == COMPRESSED_TENSORS_METHOD:
|
||||
quant_bias = bias
|
||||
|
||||
if get_ascend_device_type() == AscendDeviceType._310P:
|
||||
# On 300I Duo platform, we need transpose again if
|
||||
# using nz. This transpose can be skipped in torchair.
|
||||
output = torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
layer.weight.data.transpose(1, 0),
|
||||
layer.deq_scale,
|
||||
bias=quant_bias,
|
||||
output_dtype=layer.params_dtype,
|
||||
)
|
||||
else:
|
||||
output = torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
layer.weight,
|
||||
layer.deq_scale,
|
||||
bias=quant_bias,
|
||||
output_dtype=layer.params_dtype,
|
||||
)
|
||||
output = torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
layer.weight,
|
||||
layer.deq_scale,
|
||||
bias=quant_bias,
|
||||
output_dtype=layer.params_dtype,
|
||||
)
|
||||
return output
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
@@ -169,8 +158,8 @@ class AscendW8A8LinearMethod(AscendLinearScheme):
|
||||
layer.aclnn_input_offset = torch.nn.Parameter(
|
||||
layer.input_offset.data.repeat(expanding_factor),
|
||||
requires_grad=False).to(layer.aclnn_input_scale.dtype)
|
||||
if get_ascend_device_type() != AscendDeviceType._310P:
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
|
||||
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
||||
layer.weight.data = maybe_trans_nz(layer.weight.data)
|
||||
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
|
||||
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
|
||||
|
||||
Reference in New Issue
Block a user