diff --git a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py index 278e58c0..31ba28c0 100644 --- a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py +++ b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py @@ -320,6 +320,9 @@ class FusionOp(DecodeMoeOps): self.gmm2_weight_scale_fp32 = [ weight.clone() for weight in gmm2_weight_scale.unbind(dim=0) ] + else: + self.gmm1_weight_scale_fp32 = [torch.ones(1).npu().to(gmm1_weight.dtype)] + self.gmm2_weight_scale_fp32 = [torch.ones(1).npu().to(gmm2_weight.dtype)] else: self.gmm1_weight = [gmm1_weight.clone()] self.gmm2_weight = [gmm2_weight.clone()] @@ -524,7 +527,7 @@ def run_once(local_rank_id, @torch.inference_mode() def test_dispatch_gmm_combine_decode_base(): - custom_kwargs = BASE_KWARGS + custom_kwargs = BASE_KWARGS.copy() custom_kwargs["batch_size"] = 32 custom_kwargs["ep_world_size"] = 8 custom_kwargs["moe_expert_num"] = 32 @@ -539,7 +542,7 @@ def test_dispatch_gmm_combine_decode_base(): @torch.inference_mode() def test_dispatch_gmm_combine_decode_with_mc2_mask(): - custom_kwargs = BASE_KWARGS + custom_kwargs = BASE_KWARGS.copy() custom_kwargs["with_mc2_mask"] = True ep_world_size = custom_kwargs["ep_world_size"] custom_args = tuple(custom_kwargs.values()) @@ -548,7 +551,7 @@ def test_dispatch_gmm_combine_decode_with_mc2_mask(): @torch.inference_mode() def test_dispatch_gmm_combine_decode_dynamic_eplb(): - custom_kwargs = BASE_KWARGS + custom_kwargs = BASE_KWARGS.copy() custom_kwargs["dynamic_eplb"] = True ep_world_size = custom_kwargs["ep_world_size"] custom_args = tuple(custom_kwargs.values())