diff --git a/tests/ut/ops/test_fused_moe_prepare_and_finalize.py b/tests/ut/ops/test_fused_moe_prepare_and_finalize.py index a4a61a1..ce7970c 100644 --- a/tests/ut/ops/test_fused_moe_prepare_and_finalize.py +++ b/tests/ut/ops/test_fused_moe_prepare_and_finalize.py @@ -8,6 +8,7 @@ from vllm_ascend.ops.moe.fused_moe_prepare_and_finalize import ( FusedMoEPrepareAndFinalizeWithAll2All, FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2, FusedMoEPrepareAndFinalizeWithNaiveMulticast) +from vllm_ascend.utils import vllm_version_is class TestFusedMoEPrepareAndFinalize(unittest.TestCase): @@ -230,8 +231,12 @@ class TestFusedMoEPrepareAndFinalize(unittest.TestCase): mock_get_dp_group): # Mock forward context with DP metadata mock_context = MagicMock() - mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor( - [2, 5, 7]) + if vllm_version_is("0.10.2"): + mock_context.dp_metadata.cu_tokens_across_dp_cpu = torch.tensor( + [2, 5, 7]) + else: + mock_context.dp_metadata.cu_tokens_across_sp.return_value = torch.tensor( + [2, 5, 7]) mock_get_forward_context.return_value = mock_context # Setup DP group mock diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py index 19c6c96..a5bdfe2 100644 --- a/tests/ut/ops/test_fused_ops.py +++ b/tests/ut/ops/test_fused_ops.py @@ -28,7 +28,7 @@ from vllm_ascend.ops.fused_moe import (AscendFusedMoE, AscendUnquantizedFusedMoEMethod) from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp -from vllm_ascend.utils import AscendSocVersion, adapt_patch +from vllm_ascend.utils import AscendSocVersion, adapt_patch, vllm_version_is adapt_patch(True) @@ -93,14 +93,18 @@ def mock_dist_env(mocker: MockerFixture): mock_moe_comm_method.finalize.side_effect = mock_finalize - mock_forward_context_obj = MagicMock( - moe_comm_method=mock_moe_comm_method, - moe_comm_type=MoECommType.MC2, - max_tokens_across_dp=10, - dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]), - mc2_mask=torch.zeros(16, dtype=torch.bool), - padded_num_tokens=16, - with_quant=False) + if vllm_version_is("0.10.2"): + dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10]) + else: + dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5]) + mock_forward_context_obj = MagicMock(moe_comm_method=mock_moe_comm_method, + moe_comm_type=MoECommType.MC2, + max_tokens_across_dp=10, + dp_metadata=dp_metadata, + mc2_mask=torch.zeros( + 16, dtype=torch.bool), + padded_num_tokens=16, + with_quant=False) with patch('torch.distributed.get_rank', return_value=0), \ patch('torch.distributed.get_world_size', return_value=4), \ diff --git a/tests/ut/torchair/ops/test_torchair_fused_moe.py b/tests/ut/torchair/ops/test_torchair_fused_moe.py index a550a67..fb1cd81 100644 --- a/tests/ut/torchair/ops/test_torchair_fused_moe.py +++ b/tests/ut/torchair/ops/test_torchair_fused_moe.py @@ -26,7 +26,8 @@ from vllm_ascend.ascend_forward_context import _get_fused_moe_state from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod from vllm_ascend.torchair.ops.torchair_fused_moe import ( TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod) -from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402 +from vllm_ascend.utils import adapt_patch # noqa E402 +from vllm_ascend.utils import AscendSocVersion, vllm_version_is adapt_patch(True) @@ -53,6 +54,10 @@ def mock_dp_and_tp_group(mocker): @pytest.fixture def mock_dist_env(mocker: MockerFixture): # init dist env patch + if vllm_version_is("0.10.2"): + dp_metadata = MagicMock(cu_tokens_across_dp_cpu=[5, 10]) + else: + dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5]) with patch('torch.distributed.get_rank', return_value=0), \ patch('torch.distributed.get_world_size', return_value=4), \ @@ -80,7 +85,7 @@ def mock_dist_env(mocker: MockerFixture): patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context', return_value=MagicMock( max_tokens_across_dp=10, - dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]) + dp_metadata=dp_metadata, )), \ patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config', return_value=MagicMock( diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py index 711e291..bc0a04e 100644 --- a/vllm_ascend/models/qwen3_moe.py +++ b/vllm_ascend/models/qwen3_moe.py @@ -47,6 +47,7 @@ from vllm.model_executor.models.utils import ( make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.utils import vllm_version_is class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): @@ -169,9 +170,14 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer): quant_config=quant_config, prefix=f"{prefix}.mlp") else: - self.mlp = Qwen3MoeSparseMoeBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + if vllm_version_is("0.10.2"): + self.mlp = Qwen3MoeSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config, + prefix=f"{prefix}.mlp") else: self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, diff --git a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py index 6ed9858..3d800e4 100644 --- a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py +++ b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py @@ -26,6 +26,8 @@ from vllm.distributed.parallel_state import ( from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe import FusedMoEConfig +from vllm_ascend.utils import vllm_version_is + class FusedMoEPrepareAndFinalize(ABC): """ @@ -414,8 +416,12 @@ class FusedMoEPrepareAndFinalizeWithNaiveMulticast(FusedMoEPrepareAndFinalize): self.enable_shared_expert_dp = enable_shared_expert_dp if self.moe_config.dp_size > 1: - self.cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu + if vllm_version_is("0.10.2"): + self.cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + else: + self.cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_sp(1) hidden_states = self._naive_multicast(hidden_states, self.cu_tokens_across_dp_cpu) if rm_router_logits: diff --git a/vllm_ascend/torchair/models/qwen3_moe.py b/vllm_ascend/torchair/models/qwen3_moe.py index 8093ad4..c6aad6a 100644 --- a/vllm_ascend/torchair/models/qwen3_moe.py +++ b/vllm_ascend/torchair/models/qwen3_moe.py @@ -56,6 +56,7 @@ from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.torchair.ops.sequence_parallel import (MetadataForPadding, init_metadata_for_sp) +from vllm_ascend.utils import vllm_version_is class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock): @@ -311,9 +312,14 @@ class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer): quant_config=quant_config, prefix=f"{prefix}.mlp") else: - self.mlp = Qwen3MoeSparseMoeBlock(config=config, - quant_config=quant_config, - prefix=f"{prefix}.mlp") + if vllm_version_is("0.10.2"): + self.mlp = Qwen3MoeSparseMoeBlock( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + else: + self.mlp = Qwen3MoeSparseMoeBlock(vllm_config=vllm_config, + prefix=f"{prefix}.mlp") else: self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index 967aa03..bd25a79 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -1242,8 +1242,12 @@ class TorchairAscendFusedMoE(FusedMoE): router_logits = get_dp_group().all_gather(router_logits, 0) elif fused_moe_state == FusedMoEState.NaiveMulticast: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu + if vllm_version_is("0.10.2"): + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_dp_cpu + else: + cu_tokens_across_dp_cpu = get_forward_context( + ).dp_metadata.cu_tokens_across_sp(1) hidden_states = self.naive_multicast(hidden_states, cu_tokens_across_dp_cpu) if self.rm_router_logits: