[BugFix][DSv32] Fix DSA-CP PD role gating for deepseek v3.2 (v0.18.0) (#8291)
### What this PR does / why we need it? This PR backports the DSA-CP PD role gating fix to `releases/v0.18.0`. The existing helper logic on the release branch does not handle the PD mixed-role case correctly when deciding whether layer sharding or TP `o_proj` handling should be enabled. Layer sharding should only run on the P-side instance, while TP `o_proj` handling should stay enabled for normal non-PD deployments and for the PD mixed-role (`kv_both`) instance. This patch makes those conditions explicit and adds unit coverage for the allowed and disallowed combinations, including the DSA-CP-disabled path. Such wrong condition lead to **vllm serve failures** in case: **FC1 + PD-colocated KV pooling + no layer_sharding**, specifically causing: 1. insufficient Available KV cache memory 2. o_proj shape error in sfa_v1 attention module ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? E2E test with dsv32 + FC1 + FULL_DECODE_ONLY + kv_transfer_config(kv_both) + no layer_sharding --------- Signed-off-by: chenchuw886 <chenchuw@huawei.com> Co-authored-by: chenchuw886 <chenchuw@huawei.com>
This commit is contained in:
@@ -36,6 +36,7 @@ class TestUtils(TestBase):
|
|||||||
from vllm_ascend import platform
|
from vllm_ascend import platform
|
||||||
importlib.reload(platform)
|
importlib.reload(platform)
|
||||||
utils.enable_dsa_cp_with_layer_shard.cache_clear()
|
utils.enable_dsa_cp_with_layer_shard.cache_clear()
|
||||||
|
utils.enable_dsa_cp_with_o_proj_tp.cache_clear()
|
||||||
|
|
||||||
def test_nd_to_nz_2d(self):
|
def test_nd_to_nz_2d(self):
|
||||||
# can be divided by 16
|
# can be divided by 16
|
||||||
@@ -134,7 +135,9 @@ class TestUtils(TestBase):
|
|||||||
|
|
||||||
def test_enable_dsa_cp_with_layer_shard_accepts_kv_producer(self):
|
def test_enable_dsa_cp_with_layer_shard_accepts_kv_producer(self):
|
||||||
mock_vllm_config = mock.MagicMock()
|
mock_vllm_config = mock.MagicMock()
|
||||||
mock_vllm_config.kv_transfer_config = mock.MagicMock(kv_role="kv_producer")
|
mock_vllm_config.kv_transfer_config = mock.MagicMock(
|
||||||
|
kv_role="kv_producer", is_kv_producer=True, is_kv_consumer=False
|
||||||
|
)
|
||||||
|
|
||||||
with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
|
with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||||
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
|
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
|
||||||
@@ -142,7 +145,9 @@ class TestUtils(TestBase):
|
|||||||
|
|
||||||
def test_enable_dsa_cp_with_layer_shard_rejects_kv_both(self):
|
def test_enable_dsa_cp_with_layer_shard_rejects_kv_both(self):
|
||||||
mock_vllm_config = mock.MagicMock()
|
mock_vllm_config = mock.MagicMock()
|
||||||
mock_vllm_config.kv_transfer_config = mock.MagicMock(kv_role="kv_both", is_kv_producer=True)
|
mock_vllm_config.kv_transfer_config = mock.MagicMock(
|
||||||
|
kv_role="kv_both", is_kv_producer=True, is_kv_consumer=True
|
||||||
|
)
|
||||||
|
|
||||||
with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
|
with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||||
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
|
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
|
||||||
@@ -156,6 +161,42 @@ class TestUtils(TestBase):
|
|||||||
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
|
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
|
||||||
self.assertFalse(utils.enable_dsa_cp_with_layer_shard())
|
self.assertFalse(utils.enable_dsa_cp_with_layer_shard())
|
||||||
|
|
||||||
|
def test_enable_dsa_cp_with_layer_shard_rejects_when_dsa_cp_disabled(self):
|
||||||
|
with mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=False):
|
||||||
|
self.assertFalse(utils.enable_dsa_cp_with_layer_shard())
|
||||||
|
|
||||||
|
def test_enable_dsa_cp_with_o_proj_tp_accepts_missing_kv_transfer(self):
|
||||||
|
mock_vllm_config = mock.MagicMock()
|
||||||
|
mock_vllm_config.kv_transfer_config = None
|
||||||
|
|
||||||
|
with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||||
|
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
|
||||||
|
self.assertTrue(utils.enable_dsa_cp_with_o_proj_tp())
|
||||||
|
|
||||||
|
def test_enable_dsa_cp_with_o_proj_tp_accepts_kv_both(self):
|
||||||
|
mock_vllm_config = mock.MagicMock()
|
||||||
|
mock_vllm_config.kv_transfer_config = mock.MagicMock(
|
||||||
|
kv_role="kv_both", is_kv_producer=True, is_kv_consumer=True
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||||
|
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
|
||||||
|
self.assertTrue(utils.enable_dsa_cp_with_o_proj_tp())
|
||||||
|
|
||||||
|
def test_enable_dsa_cp_with_o_proj_tp_rejects_single_role_pd(self):
|
||||||
|
mock_vllm_config = mock.MagicMock()
|
||||||
|
mock_vllm_config.kv_transfer_config = mock.MagicMock(
|
||||||
|
kv_role="kv_producer", is_kv_producer=True, is_kv_consumer=False
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||||
|
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
|
||||||
|
self.assertFalse(utils.enable_dsa_cp_with_o_proj_tp())
|
||||||
|
|
||||||
|
def test_enable_dsa_cp_with_o_proj_tp_rejects_when_dsa_cp_disabled(self):
|
||||||
|
with mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=False):
|
||||||
|
self.assertFalse(utils.enable_dsa_cp_with_o_proj_tp())
|
||||||
|
|
||||||
def test_vllm_version_is(self):
|
def test_vllm_version_is(self):
|
||||||
with mock.patch.dict(os.environ, {"VLLM_VERSION": "1.0.0"}):
|
with mock.patch.dict(os.environ, {"VLLM_VERSION": "1.0.0"}):
|
||||||
with mock.patch("vllm.__version__", "1.0.0"):
|
with mock.patch("vllm.__version__", "1.0.0"):
|
||||||
|
|||||||
@@ -1252,11 +1252,10 @@ def enable_dsa_cp_with_layer_shard() -> bool:
|
|||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
# because the broadcast in layer sharding needs to be overlapped with a heavy compute stream to be
|
kv_transfer_config = vllm_config.kv_transfer_config
|
||||||
# effectively hidden, it is enabled only during the prefill stage.
|
# Layer sharding broadcast only pays off when it can be hidden by the
|
||||||
is_prefill_instance = (
|
# heavier prefill-stage compute, so enable it only on the P-side instance.
|
||||||
vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.kv_role == "kv_producer"
|
is_prefill_instance = kv_transfer_config is not None and kv_transfer_config.kv_role == "kv_producer"
|
||||||
)
|
|
||||||
return is_prefill_instance
|
return is_prefill_instance
|
||||||
|
|
||||||
|
|
||||||
@@ -1267,9 +1266,12 @@ def enable_dsa_cp_with_o_proj_tp() -> bool:
|
|||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
|
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
# if is PD mix stage, using original TP o_proj weight, and also need to
|
kv_transfer_config = vllm_config.kv_transfer_config
|
||||||
# full gather for o_proj weight for prefill stage.
|
|
||||||
return vllm_config.kv_transfer_config is None
|
# In PD-mixed mode, keep the original TP o_proj weight when:
|
||||||
|
# 1) KV pooling is disabled, or
|
||||||
|
# 2) KV pooling is enabled with kv_role == "kv_both".
|
||||||
|
return kv_transfer_config is None or kv_transfer_config.kv_role == "kv_both"
|
||||||
|
|
||||||
|
|
||||||
def check_gdn_layer(vllm_config) -> bool:
|
def check_gdn_layer(vllm_config) -> bool:
|
||||||
|
|||||||
Reference in New Issue
Block a user