From 31186a3a9dee42b659d587236028d44ebe0d0ef7 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Sun, 12 Apr 2026 21:52:54 +0800 Subject: [PATCH] [BugFix] Add async communication check for capturing mode (#8149) ### What this PR does / why we need it? Introduce a check to not using asynchronous communication under `enable_dsa_cp_with_layer_shard` branch on capturing mode. This change prevents potential stream and event issues when operating in graph/capturing mode, ensuring safer communication practices. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? E2E test with dsv32 + FC1 + FULL_DECODE_ONLY + kv_transfer_config(kv_both) --------- Signed-off-by: chenchuw886 Co-authored-by: chenchuw886 --- docs/source/tutorials/models/DeepSeek-V3.2.md | 2 ++ docs/source/tutorials/models/GLM5.md | 1 + .../configuration/additional_config.md | 2 +- .../feature_guide/layer_sharding.md | 8 ++++- tests/ut/test_platform.py | 30 +++++++++++++++++++ tests/ut/test_utils.py | 25 ++++++++++++++++ vllm_ascend/platform.py | 16 ++++++++++ vllm_ascend/utils.py | 4 ++- 8 files changed, 85 insertions(+), 3 deletions(-) diff --git a/docs/source/tutorials/models/DeepSeek-V3.2.md b/docs/source/tutorials/models/DeepSeek-V3.2.md index 124e3757..65782c2f 100644 --- a/docs/source/tutorials/models/DeepSeek-V3.2.md +++ b/docs/source/tutorials/models/DeepSeek-V3.2.md @@ -161,6 +161,8 @@ vllm serve /root/.cache/modelscope/hub/models/vllm-ascend/DeepSeek-V3.2-W8A8 \ ``` +In PD-disaggregated deployments, `layer_sharding` is supported only on prefill/P nodes with `kv_role="kv_producer"`. Do not enable it on decode/D nodes or `kv_role="kv_both"` nodes. + ### Multi-node Deployment - `DeepSeek-V3.2-w8a8`: require at least 2 Atlas 800 A2 (64G × 8). diff --git a/docs/source/tutorials/models/GLM5.md b/docs/source/tutorials/models/GLM5.md index 31e8db3d..e2c2e66a 100644 --- a/docs/source/tutorials/models/GLM5.md +++ b/docs/source/tutorials/models/GLM5.md @@ -743,6 +743,7 @@ Before you start, please 2. prepare the script `run_dp_template.sh` on each node. To support a 200k context window on the stage of prefill, the parameter `"layer_sharding": ["q_b_proj"]` needs to be added to `--additional_config` on each prefill node. + In PD-disaggregated deployment, `layer_sharding` is supported only on prefill/P nodes with `kv_role="kv_producer"`; do not enable it on decode/D nodes or `kv_role="kv_both"` nodes. 1. Prefill node 0 ```shell diff --git a/docs/source/user_guide/configuration/additional_config.md b/docs/source/user_guide/configuration/additional_config.md index 0375fdd4..a6b10c56 100644 --- a/docs/source/user_guide/configuration/additional_config.md +++ b/docs/source/user_guide/configuration/additional_config.md @@ -43,7 +43,7 @@ The following table lists additional configuration options available in vLLM Asc | `enable_npugraph_ex` | bool | `False` | Whether to enable npugraph_ex graph mode. | | `pa_shape_list` | list | `[]` | The custom shape list of page attention ops. | | `enable_kv_nz` | bool | `False` | Whether to enable KV cache NZ layout. This option only takes effects on models using MLA (e.g., DeepSeek). | -| `layer_sharding` | dict | `{}` | Configuration options for Layer Sharding Linear | +| `layer_sharding` | dict | `{}` | Configuration options for Layer Sharding Linear. In PD-disaggregated deployments, it is supported only on P nodes with `kv_role="kv_producer"`. | | `enable_sparse_c8` | bool | `False` | Whether to enable KV cache C8 in DSA models (e.g., DeepSeekV3.2 and GLM5). Not supported on A5 devices now | | `enable_mc2_hierarchy_comm` | bool | `False` | Enable dispatch/combine op inter-node communication by ROCE. | diff --git a/docs/source/user_guide/feature_guide/layer_sharding.md b/docs/source/user_guide/feature_guide/layer_sharding.md index 62fc3a2c..770a50cb 100644 --- a/docs/source/user_guide/feature_guide/layer_sharding.md +++ b/docs/source/user_guide/feature_guide/layer_sharding.md @@ -37,11 +37,15 @@ To enable **Layer Shard Linear**, specify the target linear layers using the `-- }' ``` +> **Restriction** +> In PD-disaggregated deployments, Layer Sharding can only be enabled on the **P node** with `kv_role="kv_producer"`. +> `kv_role="kv_consumer"` and `kv_role="kv_both"` are not supported. + --- ## Supported Scenarios -This feature can be enabled in any scenario, but delivers the greatest benefit in the following cases: +This feature delivers the greatest benefit in the following cases: ### FlashComm2-enabled @@ -62,6 +66,8 @@ vllm serve \ With [DSA-CP](https://github.com/vllm-project/vllm-ascend/pull/4702), both `q_b_proj` and `o_proj` layers require large weight matrices to be stored per layer. Sharding these layers across NPUs helps fit extremely deep models (e.g., 61-layer architectures) into limited device memory. +In PD-disaggregated deployments, this mode is supported only on the **P node** with `kv_role="kv_producer"`. + **Example configuration:** ```bash diff --git a/tests/ut/test_platform.py b/tests/ut/test_platform.py index 4256e5eb..9ca4d6c2 100644 --- a/tests/ut/test_platform.py +++ b/tests/ut/test_platform.py @@ -438,6 +438,36 @@ class TestNPUPlatform(TestBase): self.assertEqual(vllm_config.cache_config.block_size, 512) + def test_validate_layer_sharding_config_accepts_single_node(self): + vllm_config = TestNPUPlatform.mock_vllm_config() + vllm_config.additional_config = {"layer_sharding": ["q_b_proj", "o_proj"]} + vllm_config.kv_transfer_config = None + + self.platform._validate_layer_sharding_config(vllm_config) + + def test_validate_layer_sharding_config_accepts_kv_producer(self): + vllm_config = TestNPUPlatform.mock_vllm_config() + vllm_config.additional_config = {"layer_sharding": ["q_b_proj", "o_proj"]} + vllm_config.kv_transfer_config = MagicMock(is_kv_producer=True, kv_role="kv_producer") + + self.platform._validate_layer_sharding_config(vllm_config) + + def test_validate_layer_sharding_config_rejects_non_kv_producer(self): + vllm_config = TestNPUPlatform.mock_vllm_config() + vllm_config.additional_config = {"layer_sharding": ["q_b_proj", "o_proj"]} + vllm_config.kv_transfer_config = MagicMock(is_kv_producer=False, kv_role="kv_consumer") + + with pytest.raises(ValueError, match="layer_sharding is only supported on P nodes"): + self.platform._validate_layer_sharding_config(vllm_config) + + def test_validate_layer_sharding_config_rejects_kv_both(self): + vllm_config = TestNPUPlatform.mock_vllm_config() + vllm_config.additional_config = {"layer_sharding": ["q_b_proj", "o_proj"]} + vllm_config.kv_transfer_config = MagicMock(is_kv_producer=True, kv_role="kv_both") + + with pytest.raises(ValueError, match="layer_sharding is only supported on P nodes"): + self.platform._validate_layer_sharding_config(vllm_config) + @patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization") @patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3) @patch("vllm_ascend.ascend_config.init_ascend_config") diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py index 6f4c2500..24711fbb 100644 --- a/tests/ut/test_utils.py +++ b/tests/ut/test_utils.py @@ -35,6 +35,7 @@ class TestUtils(TestBase): from vllm_ascend import platform importlib.reload(platform) + utils.enable_dsa_cp_with_layer_shard.cache_clear() def test_nd_to_nz_2d(self): # can be divided by 16 @@ -131,6 +132,30 @@ class TestUtils(TestBase): with mock.patch("torch.npu.current_stream") as mock_current_stream: self.assertEqual(utils.current_stream(), mock_current_stream()) + def test_enable_dsa_cp_with_layer_shard_accepts_kv_producer(self): + mock_vllm_config = mock.MagicMock() + mock_vllm_config.kv_transfer_config = mock.MagicMock(kv_role="kv_producer") + + with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \ + mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True): + self.assertTrue(utils.enable_dsa_cp_with_layer_shard()) + + def test_enable_dsa_cp_with_layer_shard_rejects_kv_both(self): + mock_vllm_config = mock.MagicMock() + mock_vllm_config.kv_transfer_config = mock.MagicMock(kv_role="kv_both", is_kv_producer=True) + + with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \ + mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True): + self.assertFalse(utils.enable_dsa_cp_with_layer_shard()) + + def test_enable_dsa_cp_with_layer_shard_rejects_missing_kv_transfer(self): + mock_vllm_config = mock.MagicMock() + mock_vllm_config.kv_transfer_config = None + + with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \ + mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True): + self.assertFalse(utils.enable_dsa_cp_with_layer_shard()) + def test_vllm_version_is(self): with mock.patch.dict(os.environ, {"VLLM_VERSION": "1.0.0"}): with mock.patch("vllm.__version__", "1.0.0"): diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 85f19eeb..a92922f5 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -233,6 +233,20 @@ class NPUPlatform(Platform): def set_device(cls, device: torch.device): torch.npu.set_device(device) + @classmethod + def _validate_layer_sharding_config(cls, vllm_config: VllmConfig) -> None: + additional_config = vllm_config.additional_config or {} + layer_sharding = additional_config.get("layer_sharding") or [] + if not layer_sharding: + return + + kv_transfer_config = vllm_config.kv_transfer_config + if kv_transfer_config is not None and kv_transfer_config.kv_role != "kv_producer": + raise ValueError( + "additional_config.layer_sharding is only supported on P nodes " + "(kv_role='kv_producer') when KV transfer is enabled." + ) + @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: from vllm_ascend.quantization.utils import maybe_auto_detect_quantization @@ -240,6 +254,8 @@ class NPUPlatform(Platform): if vllm_config.model_config is not None: maybe_auto_detect_quantization(vllm_config) + cls._validate_layer_sharding_config(vllm_config) + # initialize ascend config from vllm additional_config cls._fix_incompatible_config(vllm_config) diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 766bd526..c79a8454 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1238,7 +1238,9 @@ def enable_dsa_cp_with_layer_shard() -> bool: vllm_config = get_current_vllm_config() # because the broadcast in layer sharding needs to be overlapped with a heavy compute stream to be # effectively hidden, it is enabled only during the prefill stage. - is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer + is_prefill_instance = ( + vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.kv_role == "kv_producer" + ) return is_prefill_instance