[Quantization]300I Duo support w8a8 quantization (#1560)
### What this PR does / why we need it? This pr supports w8a8 on 300I Duo platform. The main change is to use `npu_quant_grouped_matmul_dequant` to replace `npu_grouped_matmul`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? offline inference on 310p runs normally. --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Signed-off-by: tianyitang <tangtianyi4@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com> Co-authored-by: tianyitang <tangtianyi4@huawei.com>
This commit is contained in:
@@ -31,6 +31,7 @@ from torch_npu.npu.streams import Event
|
||||
from vllm.logger import logger
|
||||
|
||||
import vllm_ascend.envs as envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
try:
|
||||
# Recent release of torchair has moved these ops to `.scope`.
|
||||
@@ -175,6 +176,28 @@ def aligned_16(tensor: torch.Tensor):
|
||||
return new_tensor
|
||||
|
||||
|
||||
def maybe_converting_weight_acl_format(model, format=ACL_FORMAT_FRACTAL_NZ):
|
||||
# currently, there are some operations which do not support ACL_FORMAT_FRACTAL_NZ
|
||||
# in eager mode but support it in torchair graph mode. since ACL_FORMAT_FRACTAL_NZ
|
||||
# is much more preferred than ACL_FORMAT_FRACTAL_ND on 300I Duo, we add this
|
||||
# conversion when using torchair graph mode on 300I Duo platform.
|
||||
# TODO: we will remove this conversion if npu_quant_grouped_matmul_dequant
|
||||
# accepts weight format of ACL_FORMAT_FRACTAL_NZ in eager mode.
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
|
||||
use_torchair = get_ascend_config().torchair_graph_config.enabled
|
||||
if not is_310p() or not use_torchair:
|
||||
return
|
||||
for module in model.modules():
|
||||
if isinstance(module, FusedMoE):
|
||||
if torch_npu.get_npu_format(module.w13_weight.data) == format:
|
||||
return
|
||||
module.w13_weight.data = torch_npu.npu_format_cast(
|
||||
module.w13_weight.data, format)
|
||||
module.w2_weight.data = torch_npu.npu_format_cast(
|
||||
module.w2_weight.data, format)
|
||||
|
||||
|
||||
def try_register_lib(lib_name: str, lib_info: str = ""):
|
||||
import importlib
|
||||
import importlib.util
|
||||
|
||||
Reference in New Issue
Block a user