diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py
index 6db32e6..49a9ed2 100644
--- a/tests/ut/ops/test_fused_ops.py
+++ b/tests/ut/ops/test_fused_ops.py
@@ -50,6 +50,10 @@ def mock_dp_and_tp_group(mocker):
     return mock_group
 
 
+def mock_npu_format_cast(weight_data, format):
+    return weight_data
+
+
 @pytest.fixture
 def mock_dist_env(mocker: MockerFixture):
     # init dist env patch
@@ -310,12 +314,14 @@ class TestAscendUnquantizedFusedMoEMethod:
         layer.w13_weight.data = torch.randn(16, 32)
         layer.w2_weight.data = torch.randn(16, 32)
 
-        moe_method.process_weights_after_loading(layer)
+        with patch('torch_npu.npu_format_cast', mock_npu_format_cast), \
+                patch('vllm_ascend.utils.is_310p', return_value=False):
+            moe_method.process_weights_after_loading(layer)
 
-        assert isinstance(layer.w13_weight, torch.nn.Parameter)
-        assert isinstance(layer.w2_weight, torch.nn.Parameter)
-        assert not layer.w13_weight.requires_grad
-        assert not layer.w2_weight.requires_grad
+            assert isinstance(layer.w13_weight, torch.nn.Parameter)
+            assert isinstance(layer.w2_weight, torch.nn.Parameter)
+            assert not layer.w13_weight.requires_grad
+            assert not layer.w2_weight.requires_grad
 
     @pytest.mark.parametrize("others_param",
                              [[256, 4], [128, 1], [128, 1], [128, 4]])
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index f02d146..b8509f0 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -49,8 +49,8 @@
 from vllm_ascend.ops.layers.experts_selector import select_experts
 from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
     MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
 from vllm_ascend.ops.sequence_parallel import MetadataForPadding
-from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
-                               get_all_reduce_merge_state,
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
+                               dispose_tensor, get_all_reduce_merge_state,
                                get_ascend_soc_version,
                                get_rm_router_logits_state, is_310p)
@@ -866,6 +866,11 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
             layer.w2_weight.data),
                                              requires_grad=False)
+        if not is_310p():
+            layer.w13_weight.data = torch_npu.npu_format_cast(
+                layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
+            layer.w2_weight.data = torch_npu.npu_format_cast(
+                layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
 
     def apply(
         self,
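
Note on the test change above: process_weights_after_loading now calls torch_npu.npu_format_cast, which requires Ascend hardware, so the unit test stubs both the cast and the is_310p check to keep the test runnable on CPU-only CI hosts. Below is a minimal, self-contained sketch of the same unittest.mock.patch pattern; the module "mylib" and its format_cast/is_310p attributes are hypothetical stand-ins for torch_npu.npu_format_cast and vllm_ascend.utils.is_310p, not part of either library.

import sys
import types
from unittest.mock import patch

# Hypothetical stand-in module so the sketch runs without torch_npu installed.
mylib = types.ModuleType("mylib")
mylib.format_cast = lambda tensor, fmt: tensor  # pretend device-specific cast
mylib.is_310p = lambda: True
sys.modules["mylib"] = mylib


def process_weights(weight):
    # Mirrors the shape of the production logic: skip the NZ cast on 310P SoCs.
    import mylib
    if not mylib.is_310p():
        weight = mylib.format_cast(weight, 29)  # 29 ~ FRACTAL_NZ format id (assumed)
    return weight


# The test patches the device-specific pieces so the non-310P branch is
# exercised deterministically on any host, as in the diff above.
with patch("mylib.format_cast", lambda w, fmt: w), \
        patch("mylib.is_310p", return_value=False):
    assert process_weights([1.0, 2.0]) == [1.0, 2.0]

The design choice here matches the diff: patch at the import location the code under test actually reads from, and use a pass-through stub for the cast so the Parameter/requires_grad assertions still see the original tensor data.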