[BugFix] Add async communication check for capturing mode (#8149)

### What this PR does / why we need it?
Introduce a check to not using asynchronous communication under
`enable_dsa_cp_with_layer_shard` branch on capturing mode. This change
prevents potential stream and event issues when operating in
graph/capturing mode, ensuring safer communication practices.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
E2E test with dsv32 + FC1 + FULL_DECODE_ONLY +
kv_transfer_config(kv_both)

---------

Signed-off-by: chenchuw886 <chenchuw@huawei.com>
Co-authored-by: chenchuw886 <chenchuw@huawei.com>
This commit is contained in:
Frank Chen
2026-04-12 21:52:54 +08:00
committed by GitHub
parent c1f323ee46
commit 31186a3a9d
8 changed files with 85 additions and 3 deletions

View File

@@ -35,6 +35,7 @@ class TestUtils(TestBase):
from vllm_ascend import platform
importlib.reload(platform)
utils.enable_dsa_cp_with_layer_shard.cache_clear()
def test_nd_to_nz_2d(self):
# can be divided by 16
@@ -131,6 +132,30 @@ class TestUtils(TestBase):
with mock.patch("torch.npu.current_stream") as mock_current_stream:
self.assertEqual(utils.current_stream(), mock_current_stream())
def test_enable_dsa_cp_with_layer_shard_accepts_kv_producer(self):
mock_vllm_config = mock.MagicMock()
mock_vllm_config.kv_transfer_config = mock.MagicMock(kv_role="kv_producer")
with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
self.assertTrue(utils.enable_dsa_cp_with_layer_shard())
def test_enable_dsa_cp_with_layer_shard_rejects_kv_both(self):
mock_vllm_config = mock.MagicMock()
mock_vllm_config.kv_transfer_config = mock.MagicMock(kv_role="kv_both", is_kv_producer=True)
with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
self.assertFalse(utils.enable_dsa_cp_with_layer_shard())
def test_enable_dsa_cp_with_layer_shard_rejects_missing_kv_transfer(self):
mock_vllm_config = mock.MagicMock()
mock_vllm_config.kv_transfer_config = None
with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
self.assertFalse(utils.enable_dsa_cp_with_layer_shard())
def test_vllm_version_is(self):
with mock.patch.dict(os.environ, {"VLLM_VERSION": "1.0.0"}):
with mock.patch("vllm.__version__", "1.0.0"):