# [BugFix][310p][Cherry-pick] Handle null quantization config in ShardedStateLoader310 & [Feature][310P] Support W8A8 dynamic linear method (#8296)
### What this PR does / why we need it?

This PR implements the `AscendW8A8DynamicLinearMethod310` quantization scheme specifically for 310P hardware. It includes the logic for weight retrieval, per-channel parameter generation, and the application of dynamic quantization using NPU-specific kernels. Additionally, it updates `ShardedStateLoader310` to handle quantization configurations more robustly when generating parameter type maps.

Feedback from the review identified two critical issues in the implementation:

1. The tensor squeezing logic in the `apply` method incorrectly handles 2D inputs, which may lead to shape mismatches in subsequent layers.
2. The weight tensor in `process_weights_after_loading` is transposed after being converted to the private NZ format; the transpose should instead be performed on the ND tensor before conversion to ensure the correct physical layout.

Cherry-picked from: #7546 #7725

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New unit tests were added in `tests/ut/_310p/quantization/test_w8a8_dynamic_310.py` to verify the quantization method, and `tests/ut/_310p/test_sharded_state_loader_310p.py` was updated to test the state loader changes.

---------

Signed-off-by: csoulnd <daidaicurry@foxmail.com>
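For context, W8A8 dynamic quantization quantizes each activation row to int8 on the fly (one scale per token), multiplies it against the per-channel int8 weight, and rescales the int32 accumulator. The PR does this with the NPU kernels `torch_npu.npu_dynamic_quant` and `torch_npu.npu_quant_matmul`; below is a minimal pure-PyTorch sketch of the same arithmetic for reference only (the function names here are illustrative, not part of the PR):

```python
import torch

def dynamic_quant_ref(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Per-token symmetric int8 quantization: one scale per row of x.
    scale = (x.abs().amax(dim=-1) / 127.0).clamp(min=1e-8)
    x_q = torch.clamp(torch.round(x / scale[:, None]), -128, 127).to(torch.int8)
    return x_q, scale

def w8a8_dynamic_linear_ref(x: torch.Tensor, weight_q: torch.Tensor,
                            weight_scale: torch.Tensor) -> torch.Tensor:
    # x: [tokens, in] float; weight_q: [out, in] int8; weight_scale: [out] float.
    x_q, x_scale = dynamic_quant_ref(x)
    acc = x_q.to(torch.int32) @ weight_q.to(torch.int32).T  # exact int32 accumulate
    return acc.to(torch.float32) * x_scale[:, None] * weight_scale[None, :]

x = torch.randn(4, 16)
w_q = torch.randint(-127, 128, (8, 16), dtype=torch.int8)
w_s = torch.full((8,), 0.02)
y = w8a8_dynamic_linear_ref(x, w_q, w_s)  # [4, 8], approximates x @ (w_q * w_s).T
```

The diff below wires up the imports and registers the new scheme.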
```diff
@@ -19,6 +19,7 @@ from collections.abc import Callable
 from typing import Any
 
 import torch
+import torch_npu
 from vllm.config import get_current_vllm_config
 from vllm.distributed import get_ep_group
 
@@ -26,7 +27,8 @@ from vllm_ascend._310p.fused_moe.experts_selector import select_experts
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
 from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
-from vllm_ascend.quantization.methods.base import AscendMoEScheme, QuantType
+from vllm_ascend.quantization.methods.base import AscendLinearScheme, AscendMoEScheme, QuantType
+from vllm_ascend.utils import maybe_trans_nz
 
 from .registry import register_scheme
```
```diff
@@ -154,3 +156,66 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):
         layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(layer.w13_weight_offset.data.shape[0], -1)
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1)
         layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1)
+
+
+@register_scheme("W8A8_DYNAMIC", "linear")
+class AscendW8A8DynamicLinearMethod310(AscendLinearScheme):
+    """310P-only W8A8 dynamic linear scheme.
+
+    Notes:
+    - This scheme is discovered via 310P local registry.
+    """
+
+    def get_weight(
+        self,
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype = torch.float16,
+    ) -> dict[str, Any]:
+        return {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
+
+    def get_perchannel_param(
+        self,
+        output_size: int,
+        params_dtype: torch.dtype,
+    ) -> dict[str, Any]:
+        params: dict[str, Any] = {}
+        params["weight_scale"] = torch.empty(output_size, 1, dtype=torch.float32)
+        params["weight_offset"] = torch.empty(output_size, 1, dtype=torch.float32)
+        return params
```
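`get_weight` and `get_perchannel_param` only allocate placeholder tensors that the loader later fills from the checkpoint. A quick shape sanity check (the sizes are hypothetical, and a no-argument constructor is assumed):

```python
import torch

method = AscendW8A8DynamicLinearMethod310()  # assumption: constructible without args
w = method.get_weight(input_size=4096, output_size=11008)
p = method.get_perchannel_param(output_size=11008, params_dtype=torch.float16)

assert w["weight"].shape == (11008, 4096) and w["weight"].dtype == torch.int8
assert p["weight_scale"].shape == (11008, 1)   # one fp32 scale per output channel
assert p["weight_offset"].shape == (11008, 1)  # zeros under symmetric quantization
```

The runtime path is the `apply` method: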
```diff
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: torch.Tensor | None = None,
+        tp_rank: int | None = 0,
+    ) -> torch.Tensor:
+        # NOTE(310P):
+        # - There is an accuracy issue currently, which is expected to be fixed in the next version.
+        quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(x)
+        need_unsqz = False
+        if pertoken_scale.dim() == 2:
+            need_unsqz = True
+            quantized_x = quantized_x.squeeze(dim=1)
+            pertoken_scale = pertoken_scale.squeeze(dim=1)
+
+        # NOTE(310P):
+        # - Currently, W8A8 dynamic quantization supports only symmetric quantization.
+        output = torch_npu.npu_quant_matmul(
+            quantized_x,
+            layer.weight.data,
+            layer.weight_scale,
+            pertoken_scale=pertoken_scale,
+            bias=bias,
+            output_dtype=x.dtype,
+        )
+        if need_unsqz:
+            output = output.unsqueeze(dim=1)
+        return output
```
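Review issue 1 lives in this squeeze/unsqueeze bookkeeping: for a plain 2D input, the trailing `unsqueeze(dim=1)` can hand a 3D tensor to the next layer. One input-driven alternative keys the reshape off `x` itself rather than off `pertoken_scale.dim()` — a sketch of the fix the review asks for, not the committed code:

```python
# Sketch only: flatten to 2D tokens for the kernels, then restore x's own
# leading shape, so 2D inputs come back 2D and 3D inputs come back 3D.
lead_shape = x.shape[:-1]
quantized_x, pertoken_scale = torch_npu.npu_dynamic_quant(x)
quantized_x = quantized_x.reshape(-1, x.shape[-1])  # [tokens, hidden]
pertoken_scale = pertoken_scale.reshape(-1)         # [tokens]
output = torch_npu.npu_quant_matmul(
    quantized_x,
    layer.weight.data,
    layer.weight_scale,
    pertoken_scale=pertoken_scale,
    bias=bias,
    output_dtype=x.dtype,
)
return output.reshape(*lead_shape, -1)              # [..., output_size]
```

Weight post-processing, where review issue 2 was raised, follows: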
```diff
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # cast quantized weight tensors in NZ format for higher inference speed
+        layer.weight.data = maybe_trans_nz(layer.weight.data).transpose(0, 1)
+        layer.weight_scale.data = layer.weight_scale.data.flatten()
+        layer.weight_offset.data = layer.weight_offset.data.flatten()
```
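Per review issue 2, the transpose should act on the ND tensor before the NZ cast, so the conversion sees the final physical layout. A sketch of the intended ordering (the `.contiguous()` call is an assumption):

```python
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    # Transpose while still in ND format, then cast the transposed tensor to NZ.
    weight_nd = layer.weight.data.transpose(0, 1).contiguous()  # assumed copy step
    layer.weight.data = maybe_trans_nz(weight_nd)
    layer.weight_scale.data = layer.weight_scale.data.flatten()
    layer.weight_offset.data = layer.weight_offset.data.flatten()
```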