diff --git a/tests/ut/quantization/test_w8a8_dynamic.py b/tests/ut/quantization/test_w8a8_dynamic.py index ba7bf6ef..ebd368a7 100644 --- a/tests/ut/quantization/test_w8a8_dynamic.py +++ b/tests/ut/quantization/test_w8a8_dynamic.py @@ -59,3 +59,48 @@ class TestAscendW8A8FusedMoEMethod(TestBase): self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16) self.assertEqual(param_dict["w13_weight_scale"].shape, (self.num_experts, 2 * self.intermediate_size, 1)) + + def build_layer(self): + layer = torch.nn.Module() + layer.w13_weight = torch.nn.Parameter(torch.empty( + self.num_experts, + 2 * self.intermediate_size, + self.hidden_size, + dtype=torch.int8), + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(torch.empty( + self.num_experts, + self.hidden_size, + self.intermediate_size, + dtype=torch.int8), + requires_grad=False) + w13_weight_scale = torch.zeros( + (self.num_experts, 2 * self.intermediate_size, 1), + dtype=torch.float32) + layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, + requires_grad=False) + w13_weight_offset = torch.zeros( + (self.num_experts, 2 * self.intermediate_size, 1), + dtype=torch.float32) + layer.w13_weight_offset = torch.nn.Parameter(w13_weight_offset, + requires_grad=False) + w2_weight_scale = torch.zeros((self.num_experts, self.hidden_size, 1), + dtype=torch.float32) + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, + requires_grad=False) + w2_weight_offset = torch.zeros((self.num_experts, self.hidden_size, 1), + dtype=torch.float32) + layer.w2_weight_offset = torch.nn.Parameter(w2_weight_offset, + requires_grad=False) + return layer + + @patch('torch_npu.npu_format_cast') + def test_process_weights_after_loading(self, mock_npu_format_cast): + + def func_by_args(weight, num_format): + return weight + + mock_npu_format_cast.side_effect = func_by_args + new_layer = self.build_layer() + self.quant_method.process_weights_after_loading(new_layer) + mock_npu_format_cast.assert_called() diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index da6d3a69..1c158d09 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -29,7 +29,7 @@ from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.flash_common3_context import get_flash_common3_context from vllm_ascend.ops.fused_moe.experts_selector import select_experts -from vllm_ascend.utils import maybe_trans_nz +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz class AscendW8A8DynamicLinearMethod: @@ -276,8 +276,12 @@ class AscendW8A8DynamicFusedMoEMethod: 1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2).contiguous() - layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data) - layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data) + # TODO(zzzzwwjj): Currently, `torch_npu.npu_grouped_matmul_swiglu_quant` + # can only support weight nz. + layer.w13_weight.data = torch_npu.npu_format_cast( + layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w2_weight.data = torch_npu.npu_format_cast( + layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( layer.w13_weight_scale.data.shape[0], -1) layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(