[BugFix] Add async communication check for capturing mode (#8149)
### What this PR does / why we need it? Introduce a check to not using asynchronous communication under `enable_dsa_cp_with_layer_shard` branch on capturing mode. This change prevents potential stream and event issues when operating in graph/capturing mode, ensuring safer communication practices. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? E2E test with dsv32 + FC1 + FULL_DECODE_ONLY + kv_transfer_config(kv_both) --------- Signed-off-by: chenchuw886 <chenchuw@huawei.com> Co-authored-by: chenchuw886 <chenchuw@huawei.com>
This commit is contained in:
@@ -438,6 +438,36 @@ class TestNPUPlatform(TestBase):
|
||||
|
||||
self.assertEqual(vllm_config.cache_config.block_size, 512)
|
||||
|
||||
def test_validate_layer_sharding_config_accepts_single_node(self):
|
||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||
vllm_config.additional_config = {"layer_sharding": ["q_b_proj", "o_proj"]}
|
||||
vllm_config.kv_transfer_config = None
|
||||
|
||||
self.platform._validate_layer_sharding_config(vllm_config)
|
||||
|
||||
def test_validate_layer_sharding_config_accepts_kv_producer(self):
|
||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||
vllm_config.additional_config = {"layer_sharding": ["q_b_proj", "o_proj"]}
|
||||
vllm_config.kv_transfer_config = MagicMock(is_kv_producer=True, kv_role="kv_producer")
|
||||
|
||||
self.platform._validate_layer_sharding_config(vllm_config)
|
||||
|
||||
def test_validate_layer_sharding_config_rejects_non_kv_producer(self):
|
||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||
vllm_config.additional_config = {"layer_sharding": ["q_b_proj", "o_proj"]}
|
||||
vllm_config.kv_transfer_config = MagicMock(is_kv_producer=False, kv_role="kv_consumer")
|
||||
|
||||
with pytest.raises(ValueError, match="layer_sharding is only supported on P nodes"):
|
||||
self.platform._validate_layer_sharding_config(vllm_config)
|
||||
|
||||
def test_validate_layer_sharding_config_rejects_kv_both(self):
|
||||
vllm_config = TestNPUPlatform.mock_vllm_config()
|
||||
vllm_config.additional_config = {"layer_sharding": ["q_b_proj", "o_proj"]}
|
||||
vllm_config.kv_transfer_config = MagicMock(is_kv_producer=True, kv_role="kv_both")
|
||||
|
||||
with pytest.raises(ValueError, match="layer_sharding is only supported on P nodes"):
|
||||
self.platform._validate_layer_sharding_config(vllm_config)
|
||||
|
||||
@patch("vllm_ascend.quantization.utils.maybe_auto_detect_quantization")
|
||||
@patch("vllm_ascend.utils.get_ascend_device_type", return_value=AscendDeviceType.A3)
|
||||
@patch("vllm_ascend.ascend_config.init_ascend_config")
|
||||
|
||||
@@ -35,6 +35,7 @@ class TestUtils(TestBase):
|
||||
|
||||
from vllm_ascend import platform
|
||||
importlib.reload(platform)
|
||||
utils.enable_dsa_cp_with_layer_shard.cache_clear()
|
||||
|
||||
def test_nd_to_nz_2d(self):
|
||||
# can be divided by 16
|
||||
@@ -131,6 +132,30 @@ class TestUtils(TestBase):
|
||||
with mock.patch("torch.npu.current_stream") as mock_current_stream:
|
||||
self.assertEqual(utils.current_stream(), mock_current_stream())
|
||||
|
||||
def test_enable_dsa_cp_with_layer_shard_accepts_kv_producer(self):
|
||||
mock_vllm_config = mock.MagicMock()
|
||||
mock_vllm_config.kv_transfer_config = mock.MagicMock(kv_role="kv_producer")
|
||||
|
||||
with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
|
||||
self.assertTrue(utils.enable_dsa_cp_with_layer_shard())
|
||||
|
||||
def test_enable_dsa_cp_with_layer_shard_rejects_kv_both(self):
|
||||
mock_vllm_config = mock.MagicMock()
|
||||
mock_vllm_config.kv_transfer_config = mock.MagicMock(kv_role="kv_both", is_kv_producer=True)
|
||||
|
||||
with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
|
||||
self.assertFalse(utils.enable_dsa_cp_with_layer_shard())
|
||||
|
||||
def test_enable_dsa_cp_with_layer_shard_rejects_missing_kv_transfer(self):
|
||||
mock_vllm_config = mock.MagicMock()
|
||||
mock_vllm_config.kv_transfer_config = None
|
||||
|
||||
with mock.patch("vllm.config.get_current_vllm_config", return_value=mock_vllm_config), \
|
||||
mock.patch("vllm_ascend.utils.enable_dsa_cp", return_value=True):
|
||||
self.assertFalse(utils.enable_dsa_cp_with_layer_shard())
|
||||
|
||||
def test_vllm_version_is(self):
|
||||
with mock.patch.dict(os.environ, {"VLLM_VERSION": "1.0.0"}):
|
||||
with mock.patch("vllm.__version__", "1.0.0"):
|
||||
|
||||
Reference in New Issue
Block a user