Files
xc-llm-ascend/vllm_ascend/_310p/quantization/methods/w8a8_dynamic.py
csoulnd 8952fddc7e [BugFix][310p][Cherry-pick] Handle null quantization config in ShardedStateLoader310&[Feature][310P] Support W8A8 dynamic linear method (#8296)
### What this PR does / why we need it?
This PR implements the `AscendW8A8DynamicLinearMethod310` quantization
scheme specifically for 310P hardware. It includes the logic for weight
retrieval, per-channel parameter generation, and the application of
dynamic quantization using NPU-specific kernels. Additionally, it
updates `ShardedStateLoader310` to handle quantization configurations
more robustly when generating parameter type maps.

Feedback from the review identified two critical issues in the
implementation:
1. The tensor squeezing logic in the `apply` method incorrectly handles
2D inputs, which may lead to shape mismatches in subsequent layers.
2. The weight tensor in `process_weights_after_loading` is transposed
after being converted to the private NZ format; the transpose operation
should be performed on the ND tensor before conversion to ensure correct
physical layout.

Cherry-picked from: #7546, #7725
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
New unit tests were added in
`tests/ut/_310p/quantization/test_w8a8_dynamic_310.py` to verify the
quantization method, and
`tests/ut/_310p/test_sharded_state_loader_310p.py` was updated to test
the state loader changes.

---------

Signed-off-by: csoulnd <daidaicurry@foxmail.com>
2026-04-16 16:53:39 +08:00

222 lines
8.5 KiB
Python

#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from collections.abc import Callable
from typing import Any
import torch
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import get_ep_group
from vllm_ascend._310p.fused_moe.experts_selector import select_experts
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.quantization.methods.base import AscendLinearScheme, AscendMoEScheme, QuantType
from vllm_ascend.utils import maybe_trans_nz
from .registry import register_scheme
@register_scheme("W8A8_DYNAMIC", "moe")
class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):
    """FusedMoE scheme for Ascend W8A8_DYNAMIC quantization, restricted to 310P.

    Notes:
        - Registered through the 310P-local scheme registry via ``register_scheme``.
    """

    # Quantization type of this scheme; forwarded to the fused-experts input builder.
    quant_type: QuantType = QuantType.W8A8

    def __init__(self):
        self.ep_group = get_ep_group()
        # Activation dtype of the model; routing weights are cast to it before the expert kernels.
        self.in_dtype = get_current_vllm_config().model_config.dtype

    def get_weight(
        self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
    ) -> dict[str, Any]:
        """Allocate the (uninitialized) int8 expert weight tensors.

        Args:
            num_experts: Number of experts held by this rank.
            intermediate_size_per_partition: Per-partition intermediate size of one expert.
            hidden_sizes: Model hidden size.
            params_dtype: Unused here; weights are always int8.

        Returns:
            ``{"w13_weight": (E, 2*I, H) int8, "w2_weight": (E, H, I) int8}``.
        """
        gate_up_shape = (num_experts, 2 * intermediate_size_per_partition, hidden_sizes)
        down_shape = (num_experts, hidden_sizes, intermediate_size_per_partition)
        return {
            # Fused gate_up_proj (column parallel)
            "w13_weight": torch.empty(*gate_up_shape, dtype=torch.int8),
            # down_proj (row parallel)
            "w2_weight": torch.empty(*down_shape, dtype=torch.int8),
        }

    def get_dynamic_quant_param(
        self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
    ) -> dict[str, Any]:
        """Allocate per-output-channel scale/offset tensors for both expert projections.

        Scales are float32; offsets use ``params_dtype``. Each tensor has a trailing
        singleton dim, which ``process_weights_after_loading`` later folds away.
        """
        w13_rows = 2 * intermediate_size_per_partition
        return {
            "w13_weight_scale": torch.empty(num_experts, w13_rows, 1, dtype=torch.float32),
            "w13_weight_offset": torch.empty(num_experts, w13_rows, 1, dtype=params_dtype),
            "w2_weight_scale": torch.empty(num_experts, hidden_sizes, 1, dtype=torch.float32),
            "w2_weight_offset": torch.empty(num_experts, hidden_sizes, 1, dtype=params_dtype),
        }

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = False,
        log2phy: torch.Tensor | None = None,
        global_redundant_expert_num: int = 0,
        pertoken_scale: Any | None = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        mc2_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Route tokens to experts and run the W8A8 fused experts.

        Flow:
            1. Pick the top-k experts with the 310P ``select_experts``.
            2. If the layer defines zero experts, split them out with ``zero_experts_compute``.
            3. Run the int8 experts through the forward context's MoE comm method.
            4. Add back any zero-expert contribution.

        Several arguments are accepted only for interface compatibility and are not
        read here: ``routed_scaling_factor``, ``is_prefill``, ``enable_force_load_balance``,
        ``log2phy``, ``global_redundant_expert_num``, ``pertoken_scale``, ``activation``,
        ``mc2_mask``.
        """
        # Optional zero-expert configuration; absent attributes disable the feature.
        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts,
        )

        uses_zero_experts = zero_expert_num > 0 and zero_expert_type is not None
        if uses_zero_experts:
            topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
                expert_indices=topk_ids,
                expert_scales=topk_weights,
                num_experts=global_num_experts,
                zero_expert_type=zero_expert_type,
                hidden_states=x,
            )

        topk_weights = topk_weights.to(self.in_dtype)
        comm_method = _EXTRA_CTX.moe_comm_method
        experts_input = build_fused_experts_input(
            hidden_states=x,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            quant_type=self.quant_type,
            dynamic_eplb=False,
            expert_map=expert_map,
            apply_router_weight_on_input=apply_router_weight_on_input,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
        )
        hidden_out = comm_method.fused_experts(fused_experts_input=experts_input)

        if uses_zero_experts:
            hidden_out += zero_expert_result
        return hidden_out

    def process_weights_after_loading(self, layer):
        """Collapse each expert scale/offset tensor to 2-D ``(num_experts, -1)`` in place."""
        for param_name in ("w13_weight_scale", "w13_weight_offset", "w2_weight_scale", "w2_weight_offset"):
            param = getattr(layer, param_name)
            param.data = param.data.view(param.data.shape[0], -1)
@register_scheme("W8A8_DYNAMIC", "linear")
class AscendW8A8DynamicLinearMethod310(AscendLinearScheme):
    """310P-only W8A8 dynamic linear scheme.

    Weights are stored as int8 with per-output-channel float32 scales. Activations
    are quantized per token at runtime with ``torch_npu.npu_dynamic_quant``.

    Notes:
        - This scheme is discovered via 310P local registry.
    """

    def get_weight(
        self,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.float16,
    ) -> dict[str, Any]:
        """Allocate the int8 weight in ``(output_size, input_size)`` layout.

        ``params_dtype`` is accepted for interface compatibility but not used.
        """
        return {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}

    def get_perchannel_param(
        self,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> dict[str, Any]:
        """Allocate per-output-channel float32 ``weight_scale`` and ``weight_offset`` of shape (N, 1).

        ``weight_offset`` is loaded but not consumed by ``apply`` (symmetric quantization only).
        """
        params: dict[str, Any] = {}
        params["weight_scale"] = torch.empty(output_size, 1, dtype=torch.float32)
        params["weight_offset"] = torch.empty(output_size, 1, dtype=torch.float32)
        return params

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
        tp_rank: int | None = 0,
    ) -> torch.Tensor:
        """Compute ``x @ W^T`` (+ bias) with per-token dynamic int8 activation quantization.

        Args:
            layer: Module holding ``weight`` and ``weight_scale``, as prepared by
                ``process_weights_after_loading``.
            x: Activation of shape ``(..., input_size)``. Any number of leading dims is accepted.
            bias: Optional bias forwarded to ``npu_quant_matmul``.
            tp_rank: Unused; kept for interface compatibility.

        Returns:
            Tensor of shape ``(..., output_size)`` with ``x.dtype``.
        """
        # NOTE(310P):
        # - There is an accuracy issue currently, which is expected to be fixed in the next version.
        #
        # Collapse every leading dim into one token dim so the quant and matmul kernels
        # always see a 2-D (tokens, hidden) activation. Per-token quantization works row by row,
        # so flattening does not change the per-token scales.
        # The previous squeeze(dim=1) approach only worked when dim 1 had size 1.
        leading_shape = x.shape[:-1]
        x_2d = x.reshape(-1, x.shape[-1])
        quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(x_2d)
        # NOTE(310P):
        # - Currently, W8A8 dynamic quantization supports only symmetric quantization.
        output = torch_npu.npu_quant_matmul(
            quantized_x,
            layer.weight.data,
            layer.weight_scale,
            pertoken_scale=pertoken_scale,
            bias=bias,
            output_dtype=x.dtype,
        )
        # Restore the caller's leading dims. This is a no-op for the common 2-D input.
        return output.reshape(*leading_shape, output.shape[-1])

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the weight to NZ format and flatten the per-channel scale/offset to 1-D."""
        # cast quantized weight tensors in NZ format for higher inference speed
        # NOTE(review): the weight is cast with maybe_trans_nz in its (N, K) layout and then exposed
        # as a transposed (K, N) view; review feedback suggested transposing the ND tensor before the
        # NZ cast instead. Kept as-is — confirm on 310P hardware which layout npu_quant_matmul expects.
        layer.weight.data = maybe_trans_nz(layer.weight.data).transpose(0, 1)
        layer.weight_scale.data = layer.weight_scale.data.flatten()
        layer.weight_offset.data = layer.weight_offset.data.flatten()