diff --git a/tests/ut/test_utils.py b/tests/ut/test_utils.py
index 24711fbb..350c1a7c 100644
--- a/tests/ut/test_utils.py
+++ b/tests/ut/test_utils.py
@@ -36,6 +36,7 @@ class TestUtils(TestBase):
         from vllm_ascend import platform
         importlib.reload(platform)
         utils.enable_dsa_cp_with_layer_shard.cache_clear()
+        utils.enable_dsa_cp_with_o_proj_tp.cache_clear()
 
     def test_nd_to_nz_2d(self):
         # can be divided by 16
@@ -134,7 +135,9 @@ class TestUtils(TestBase):
     def test_enable_dsa_cp_with_layer_shard_accepts_kv_producer(self):
         mock_vllm_config = mock.MagicMock()
-        mock_vllm_config.kv_transfer_config = mock.MagicMock(kv_role="kv_producer")
+        mock_vllm_config.kv_transfer_config = mock.MagicMock(
+            kv_role="kv_producer", is_kv_producer=True, is_kv_consumer=False
+        )
 
         with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
                 mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
             self.assertTrue(utils.enable_dsa_cp_with_layer_shard())
@@ -142,7 +145,9 @@ class TestUtils(TestBase):
 
     def test_enable_dsa_cp_with_layer_shard_rejects_kv_both(self):
         mock_vllm_config = mock.MagicMock()
-        mock_vllm_config.kv_transfer_config = mock.MagicMock(kv_role="kv_both", is_kv_producer=True)
+        mock_vllm_config.kv_transfer_config = mock.MagicMock(
+            kv_role="kv_both", is_kv_producer=True, is_kv_consumer=True
+        )
 
         with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
                 mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
@@ -156,6 +161,42 @@ class TestUtils(TestBase):
                 mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
             self.assertFalse(utils.enable_dsa_cp_with_layer_shard())
 
+    def test_enable_dsa_cp_with_layer_shard_rejects_when_dsa_cp_disabled(self):
+        with mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=False):
+            self.assertFalse(utils.enable_dsa_cp_with_layer_shard())
+
+    def test_enable_dsa_cp_with_o_proj_tp_accepts_missing_kv_transfer(self):
+        mock_vllm_config = mock.MagicMock()
+        mock_vllm_config.kv_transfer_config = None
+
+        with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
+                mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
+            self.assertTrue(utils.enable_dsa_cp_with_o_proj_tp())
+
+    def test_enable_dsa_cp_with_o_proj_tp_accepts_kv_both(self):
+        mock_vllm_config = mock.MagicMock()
+        mock_vllm_config.kv_transfer_config = mock.MagicMock(
+            kv_role="kv_both", is_kv_producer=True, is_kv_consumer=True
+        )
+
+        with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
+                mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
+            self.assertTrue(utils.enable_dsa_cp_with_o_proj_tp())
+
+    def test_enable_dsa_cp_with_o_proj_tp_rejects_single_role_pd(self):
+        mock_vllm_config = mock.MagicMock()
+        mock_vllm_config.kv_transfer_config = mock.MagicMock(
+            kv_role="kv_producer", is_kv_producer=True, is_kv_consumer=False
+        )
+
+        with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
+                mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
+            self.assertFalse(utils.enable_dsa_cp_with_o_proj_tp())
+
+    def test_enable_dsa_cp_with_o_proj_tp_rejects_when_dsa_cp_disabled(self):
+        with mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=False):
+            self.assertFalse(utils.enable_dsa_cp_with_o_proj_tp())
+
     def test_vllm_version_is(self):
         with mock.patch.dict(os.environ, {"VLLM_VERSION": "1.0.0"}):
             with mock.patch("vllm.__version__", "1.0.0"):
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 3b0632ba..2d477240 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -1252,11 +1252,10 @@ def enable_dsa_cp_with_layer_shard() -> bool:
     from vllm.config import get_current_vllm_config
 
     vllm_config = get_current_vllm_config()
-    # because the broadcast in layer sharding needs to be overlapped with a heavy compute stream to be
-    # effectively hidden, it is enabled only during the prefill stage.
-    is_prefill_instance = (
-        vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.kv_role == "kv_producer"
-    )
+    kv_transfer_config = vllm_config.kv_transfer_config
+    # Layer sharding broadcast only pays off when it can be hidden by the
+    # heavier prefill-stage compute, so enable it only on the P-side instance.
+    is_prefill_instance = kv_transfer_config is not None and kv_transfer_config.kv_role == "kv_producer"
 
     return is_prefill_instance
 
@@ -1267,9 +1266,12 @@ def enable_dsa_cp_with_o_proj_tp() -> bool:
     from vllm.config import get_current_vllm_config
 
     vllm_config = get_current_vllm_config()
-    # if is PD mix stage, using original TP o_proj weight, and also need to
-    # full gather for o_proj weight for prefill stage.
-    return vllm_config.kv_transfer_config is None
+    kv_transfer_config = vllm_config.kv_transfer_config
+
+    # In PD-mixed mode, keep the original TP o_proj weight when:
+    # 1) KV pooling is disabled, or
+    # 2) KV pooling is enabled with kv_role == "kv_both".
+    return kv_transfer_config is None or kv_transfer_config.kv_role == "kv_both"
 
 
 def check_gdn_layer(vllm_config) -> bool: