From 363febb6cb0fb2db3937ccc0acb635adb3a55797 Mon Sep 17 00:00:00 2001 From: wangxiaoteng888 <56506195+wangxiaoteng888@users.noreply.github.com> Date: Sat, 18 Apr 2026 18:06:42 +0800 Subject: [PATCH] [BugFix][v0.18.0] Gate recompute/balance/fused_mc2 by PD mode (#8374) ### What this PR does / why we need it? - Enforce recompute scheduler only in PD-disaggregated mode. - Enforce balance scheduling only in PD-mixed mode. - Enforce fused MC2 only on PD-disaggregated D-side (kv_consumer). ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? By CI. --------- Signed-off-by: wangxiaoteng --- .../configuration/additional_config.md | 2 +- tests/ut/test_platform.py | 251 ++++++++++++++++++ vllm_ascend/ascend_config.py | 1 + vllm_ascend/envs.py | 8 +- vllm_ascend/platform.py | 29 ++ 5 files changed, 288 insertions(+), 3 deletions(-) diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index e96e7592..25e396a1 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -37,7 +37,7 @@ The following table lists additional configuration options available in vLLM Asc | `enable_shared_expert_dp` | bool | `False` | When the expert is shared in DP, it delivers better performance but consumes more memory. Currently only DeepSeek series models are supported. | | `multistream_overlap_shared_expert` | bool | `False` | Whether to enable multi-stream shared expert. This option only takes effect on MoE models with shared experts. | | `multistream_overlap_gate` | bool | `False` | Whether to enable multi-stream overlap gate. This option only takes effect on MoE models with shared experts. | -| `recompute_scheduler_enable` | bool | `False` | Whether to enable recompute scheduler. | +| `recompute_scheduler_enable` | bool | `False` | Whether to enable the recompute scheduler. **Only valid in PD-disaggregated mode** (`kv_role` is `kv_producer` or `kv_consumer`). **Do not enable in PD-mixed mode** (no `kv_transfer_config`, or `kv_role` is `kv_both`); startup will fail with a clear error. | | `enable_cpu_binding` | bool | `True` | Whether to enable CPU binding. Only takes effect on ARM CPUs; A3 uses the global-slicing CPU allocation strategy and other device types use the topo-affinity CPU allocation strategy. | | `SLO_limits_for_dynamic_batch` | int | `-1` | SLO limits for dynamic batch. This is new scheduler to support dynamic batch feature | | `enable_npugraph_ex` | bool | `False` | Whether to enable npugraph_ex graph mode. 
| diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index 9ca4d6c2..d5a8d7cd 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -418,6 +418,257 @@ class TestNPUPlatform(TestBase): self.assertEqual(vllm_config.cache_config.block_size, 128) + @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization") + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) + @patch("vllm_ascend.ascend_config.init_ascend_config") + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") + def test_check_and_update_config_recompute_scheduler_rejects_pd_mixed_no_kv_transfer( + self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect + ): + mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config() + mock_ascend_config.recompute_scheduler_enable = True + mock_init_ascend.return_value = mock_ascend_config + + vllm_config = TestNPUPlatform.mock_vllm_config() + vllm_config.kv_transfer_config = None + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 + vllm_config.parallel_config.tensor_parallel_size = 1 + vllm_config.scheduler_config = MagicMock() + mock_init_recompute.return_value = MagicMock() + + from vllm_ascend import platform + + importlib.reload(platform) + self.platform = platform.NPUPlatform() + + with pytest.raises(ValueError, match=r"recompute_scheduler_enable.*PD-disaggregated.*PD-mixed"): + with patch.object(platform.NPUPlatform, "_fix_incompatible_config"): + self.platform.check_and_update_config(vllm_config) + + @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization") + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) + @patch("vllm_ascend.ascend_config.init_ascend_config") + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") + def test_check_and_update_config_recompute_scheduler_rejects_pd_mixed_kv_both( + self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect + ): + mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config() + mock_ascend_config.recompute_scheduler_enable = True + mock_init_ascend.return_value = mock_ascend_config + + vllm_config = TestNPUPlatform.mock_vllm_config() + vllm_config.kv_transfer_config = MagicMock(kv_role="kv_both", engine_id="engine0") + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 + vllm_config.parallel_config.tensor_parallel_size = 1 + vllm_config.scheduler_config = MagicMock() + mock_init_recompute.return_value = MagicMock() + + from vllm_ascend import platform + + importlib.reload(platform) + self.platform = platform.NPUPlatform() + + with pytest.raises(ValueError, match=r"recompute_scheduler_enable.*PD-disaggregated.*PD-mixed"): + with ( + patch.object(platform.NPUPlatform, "_fix_incompatible_config"), + patch.object(platform, "check_kv_extra_config"), + ): + self.platform.check_and_update_config(vllm_config) + + @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization") + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) + @patch("vllm_ascend.ascend_config.init_ascend_config") + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") + def test_check_and_update_config_balance_scheduler_rejects_pd_disaggregated_kv_producer( + self, mock_init_recompute, 
mock_init_ascend, mock_soc_version, mock_auto_detect + ): + mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config() + mock_ascend_config.recompute_scheduler_enable = False + mock_init_ascend.return_value = mock_ascend_config + + vllm_config = TestNPUPlatform.mock_vllm_config() + vllm_config.kv_transfer_config = MagicMock(kv_role="kv_producer", engine_id="engine0") + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 + vllm_config.parallel_config.tensor_parallel_size = 1 + vllm_config.scheduler_config = MagicMock() + mock_init_recompute.return_value = MagicMock() + + from vllm_ascend import platform + + importlib.reload(platform) + self.platform = platform.NPUPlatform() + + with patch("vllm_ascend.platform.envs_ascend.VLLM_ASCEND_BALANCE_SCHEDULING", True, create=True): + with pytest.raises(ValueError, match=r"VLLM_ASCEND_BALANCE_SCHEDULING.*PD-mixed.*PD-disaggregated"): + with ( + patch.object(platform.NPUPlatform, "_fix_incompatible_config"), + patch.object(platform, "check_kv_extra_config"), + ): + self.platform.check_and_update_config(vllm_config) + + @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization") + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) + @patch("vllm_ascend.ascend_config.init_ascend_config") + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") + def test_check_and_update_config_balance_scheduler_rejects_pd_disaggregated_kv_consumer( + self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect + ): + mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config() + mock_ascend_config.recompute_scheduler_enable = False + mock_init_ascend.return_value = mock_ascend_config + + vllm_config = TestNPUPlatform.mock_vllm_config() + vllm_config.kv_transfer_config = MagicMock(kv_role="kv_consumer", engine_id="engine0") + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 + vllm_config.parallel_config.tensor_parallel_size = 1 + vllm_config.scheduler_config = MagicMock() + mock_init_recompute.return_value = MagicMock() + + from vllm_ascend import platform + + importlib.reload(platform) + self.platform = platform.NPUPlatform() + + with patch("vllm_ascend.platform.envs_ascend.VLLM_ASCEND_BALANCE_SCHEDULING", True, create=True): + with pytest.raises(ValueError, match=r"VLLM_ASCEND_BALANCE_SCHEDULING.*PD-mixed.*PD-disaggregated"): + with ( + patch.object(platform.NPUPlatform, "_fix_incompatible_config"), + patch.object(platform, "check_kv_extra_config"), + ): + self.platform.check_and_update_config(vllm_config) + + @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization") + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) + @patch("vllm_ascend.ascend_config.init_ascend_config") + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") + def test_check_and_update_config_fused_mc2_rejects_pd_mixed_no_kv_transfer( + self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect + ): + mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config() + mock_ascend_config.recompute_scheduler_enable = False + mock_ascend_config.enable_mc2_hierarchy_comm = False + mock_init_ascend.return_value = mock_ascend_config + + vllm_config = TestNPUPlatform.mock_vllm_config() + vllm_config.kv_transfer_config = None + 
vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 + vllm_config.parallel_config.tensor_parallel_size = 1 + vllm_config.scheduler_config = MagicMock() + mock_init_recompute.return_value = MagicMock() + + from vllm_ascend import platform + + importlib.reload(platform) + self.platform = platform.NPUPlatform() + + with patch("vllm_ascend.platform.envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2", 1, create=True): + with pytest.raises(ValueError, match=r"VLLM_ASCEND_ENABLE_FUSED_MC2.*kv_role='kv_consumer'.*PD-mixed"): + with patch.object(platform.NPUPlatform, "_fix_incompatible_config"): + self.platform.check_and_update_config(vllm_config) + + @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization") + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) + @patch("vllm_ascend.ascend_config.init_ascend_config") + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") + def test_check_and_update_config_fused_mc2_rejects_pd_mixed_kv_both( + self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect + ): + mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config() + mock_ascend_config.recompute_scheduler_enable = False + mock_ascend_config.enable_mc2_hierarchy_comm = False + mock_init_ascend.return_value = mock_ascend_config + + vllm_config = TestNPUPlatform.mock_vllm_config() + vllm_config.kv_transfer_config = MagicMock(kv_role="kv_both", engine_id="engine0") + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 + vllm_config.parallel_config.tensor_parallel_size = 1 + vllm_config.scheduler_config = MagicMock() + mock_init_recompute.return_value = MagicMock() + + from vllm_ascend import platform + + importlib.reload(platform) + self.platform = platform.NPUPlatform() + + with patch("vllm_ascend.platform.envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2", 1, create=True): + with pytest.raises(ValueError, match=r"VLLM_ASCEND_ENABLE_FUSED_MC2.*kv_role='kv_consumer'.*kv_role='kv_both'"): + with ( + patch.object(platform.NPUPlatform, "_fix_incompatible_config"), + patch.object(platform, "check_kv_extra_config"), + ): + self.platform.check_and_update_config(vllm_config) + + @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization") + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) + @patch("vllm_ascend.ascend_config.init_ascend_config") + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") + def test_check_and_update_config_fused_mc2_rejects_pd_disaggregated_kv_producer( + self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect + ): + mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config() + mock_ascend_config.recompute_scheduler_enable = False + mock_ascend_config.enable_mc2_hierarchy_comm = False + mock_init_ascend.return_value = mock_ascend_config + + vllm_config = TestNPUPlatform.mock_vllm_config() + vllm_config.kv_transfer_config = MagicMock(kv_role="kv_producer", engine_id="engine0") + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 + vllm_config.parallel_config.tensor_parallel_size = 1 + vllm_config.scheduler_config = MagicMock() + mock_init_recompute.return_value = MagicMock() + + from vllm_ascend import platform + + importlib.reload(platform) + self.platform = 
platform.NPUPlatform() + + with patch("vllm_ascend.platform.envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2", 1, create=True): + with pytest.raises(ValueError, match=r"VLLM_ASCEND_ENABLE_FUSED_MC2.*kv_role='kv_consumer'.*kv_role='kv_producer'"): + with ( + patch.object(platform.NPUPlatform, "_fix_incompatible_config"), + patch.object(platform, "check_kv_extra_config"), + ): + self.platform.check_and_update_config(vllm_config) + + @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization") + @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) + @patch("vllm_ascend.ascend_config.init_ascend_config") + @patch("vllm_ascend.core.recompute_scheduler.RecomputeSchedulerConfig.initialize_from_config") + def test_check_and_update_config_fused_mc2_allows_pd_disaggregated_kv_consumer( + self, mock_init_recompute, mock_init_ascend, mock_soc_version, mock_auto_detect + ): + mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config() + mock_ascend_config.recompute_scheduler_enable = False + mock_ascend_config.enable_mc2_hierarchy_comm = False + mock_init_ascend.return_value = mock_ascend_config + + vllm_config = TestNPUPlatform.mock_vllm_config() + vllm_config.kv_transfer_config = MagicMock(kv_role="kv_consumer", engine_id="engine0") + vllm_config.parallel_config.decode_context_parallel_size = 1 + vllm_config.parallel_config.prefill_context_parallel_size = 1 + vllm_config.parallel_config.tensor_parallel_size = 1 + vllm_config.scheduler_config = MagicMock() + mock_init_recompute.return_value = MagicMock() + + from vllm_ascend import platform + + importlib.reload(platform) + self.platform = platform.NPUPlatform() + + with patch("vllm_ascend.platform.envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2", 1, create=True): + with ( + patch.object(platform.NPUPlatform, "_fix_incompatible_config"), + patch.object(platform, "check_kv_extra_config"), + ): + self.platform.check_and_update_config(vllm_config) + def test_update_block_size_for_backend_preserves_hybrid_block_size(self): vllm_config = TestNPUPlatform.mock_vllm_config() vllm_config.model_config.is_hybrid = True diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index a18e6549..a92a1c36 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -86,6 +86,7 @@ class AscendConfig: ) self.multistream_overlap_shared_expert = additional_config.get("multistream_overlap_shared_expert", False) self.multistream_overlap_gate = additional_config.get("multistream_overlap_gate", False) + # PD-disaggregated only (kv_producer/kv_consumer); invalid in PD-mixed (kv_both / no kv_transfer_config). self.recompute_scheduler_enable = additional_config.get("recompute_scheduler_enable", False) self.enable_cpu_binding = additional_config.get("enable_cpu_binding", True) diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index b161220e..de408d36 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -93,7 +93,9 @@ env_variables: dict[str, Callable[[], Any]] = { "VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_CONTEXT_PARALLEL", "0"))), # Whether to anbale dynamic EPLB "DYNAMIC_EPLB": lambda: os.getenv("DYNAMIC_EPLB", "false").lower(), - # Whether to enable fused mc2(`dispatch_gmm_combine_decode`/`dispatch_ffn_combine` operator) + # Whether to enable fused MC2 (`dispatch_gmm_combine_decode` / `dispatch_ffn_combine`). + # Platform validation: only PD-disaggregated **decode** instances (`kv_role='kv_consumer'`). 
+ # Not supported in PD-mixed mode (`kv_both` or no kv_transfer_config) or on prefill nodes (`kv_producer`). # 0, or not set: default ALLTOALL and MC2 will be used. # 1: ALLTOALL and MC2 might be replaced by `dispatch_ffn_combine` operator. # `dispatch_ffn_combine` can be used only for moe layer with W8A8, EP<=32, non-mtp, non-dynamic-eplb. @@ -101,7 +103,9 @@ # `dispatch_gmm_combine_decode` can be used only for **decode node** moe layer # with W8A8. And MTP layer must be W8A8. "VLLM_ASCEND_ENABLE_FUSED_MC2": lambda: int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0")), - # Whether to anbale balance scheduling + # Whether to enable balance scheduling in the v1 scheduler. + # Platform validation: only PD-mixed mode (`kv_role='kv_both'` or no kv_transfer_config). + # Not supported in PD-disaggregated mode (`kv_producer` / `kv_consumer`). "VLLM_ASCEND_BALANCE_SCHEDULING": lambda: bool(int(os.getenv("VLLM_ASCEND_BALANCE_SCHEDULING", "0"))), # use fused op transpose_kv_cache_by_block, default is True "VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK": lambda: bool( diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index a92922f5..23df1669 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -448,7 +448,36 @@ class NPUPlatform(Platform): if get_ascend_device_type() != AscendDeviceType._310P: compilation_config.custom_ops = ["all"] + if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2: + kv_transfer_config = vllm_config.kv_transfer_config + kv_role = getattr(kv_transfer_config, "kv_role", None) + if kv_transfer_config is None or kv_role != "kv_consumer": + raise ValueError( + "VLLM_ASCEND_ENABLE_FUSED_MC2 (fused MC2) only supports PD-disaggregated " + "decode nodes (D-side) with kv_role='kv_consumer'. It is not supported " + "in PD-mixed mode (no kv_transfer_config / kv_role='kv_both') nor on " + "prefill nodes (P-side) with kv_role='kv_producer'." + ) + + if envs_ascend.VLLM_ASCEND_BALANCE_SCHEDULING: + kv_transfer_config = vllm_config.kv_transfer_config + kv_role = getattr(kv_transfer_config, "kv_role", None) + if kv_transfer_config is not None and kv_role != "kv_both": + raise ValueError( + "VLLM_ASCEND_BALANCE_SCHEDULING (balance scheduling) only supports PD-mixed mode " + "(kv_role='kv_both' or no kv_transfer_config), and is not supported in " + "PD-disaggregated mode (kv_role='kv_producer'/'kv_consumer')." + ) + if ascend_config.recompute_scheduler_enable: + kv_transfer_config = vllm_config.kv_transfer_config + kv_role = getattr(kv_transfer_config, "kv_role", None) + if kv_transfer_config is None or kv_role == "kv_both": + raise ValueError( + "recompute_scheduler_enable can only be enabled in PD-disaggregated mode " + "(kv_role='kv_producer' or 'kv_consumer'), and is not supported in PD-mixed mode." + ) + from vllm_ascend.core.recompute_scheduler import RecomputeSchedulerConfig recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(vllm_config)
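
For reviewers, the three checks added to `check_and_update_config` reduce to a small decision table over `kv_role`. The sketch below restates that logic as a standalone function; the helper name `validate_pd_gating` and its flat signature are illustrative only and do not appear in this patch — the real checks run inline against `vllm_config` and `envs_ascend`.

```python
# Hypothetical standalone restatement of the gating rules this patch enforces.
# `kv_role=None` models a deployment without kv_transfer_config (PD-mixed).


def validate_pd_gating(
    kv_role: str | None,
    fused_mc2: bool,            # mirrors VLLM_ASCEND_ENABLE_FUSED_MC2
    balance_scheduling: bool,   # mirrors VLLM_ASCEND_BALANCE_SCHEDULING
    recompute_scheduler: bool,  # mirrors recompute_scheduler_enable
) -> None:
    pd_mixed = kv_role is None or kv_role == "kv_both"

    # Fused MC2: only PD-disaggregated decode nodes (D-side, kv_consumer).
    if fused_mc2 and kv_role != "kv_consumer":
        raise ValueError("fused MC2 requires kv_role='kv_consumer'")

    # Balance scheduling: only PD-mixed deployments.
    if balance_scheduling and not pd_mixed:
        raise ValueError("balance scheduling requires PD-mixed mode")

    # Recompute scheduler: only PD-disaggregated deployments (P- or D-side).
    if recompute_scheduler and pd_mixed:
        raise ValueError("recompute scheduler requires PD-disaggregated mode")


# A D-side decode instance may combine fused MC2 with the recompute scheduler,
validate_pd_gating("kv_consumer", True, False, True)
# while a PD-mixed instance may enable balance scheduling only.
validate_pd_gating(None, False, True, False)
```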