[refactor](UT,PCP,DCP) refactor pcp&dcp patches in UTs (#5505)

### What this PR does / why we need it?
Refactor the PCP & DCP patches in UTs: merge and reuse the communication-group
and communication-function patches to reduce code duplication.
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
Qiu
2026-01-05 09:05:45 +08:00
committed by GitHub
parent 46c2fc6a3c
commit 96775a27a8
6 changed files with 128 additions and 531 deletions

View File

@@ -190,15 +190,7 @@ class TestAscendMLAMetadata(TestBase):
class TestAscendMLAMetadataBuilder(TestBase):
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
def test_ascend_mla_metadata_builder_default(self, mock_dcp,
mock_get_dcp_group, mock_pcp,
mock_get_pcp_group):
def test_ascend_mla_metadata_builder_default(self):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -209,22 +201,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
mock_device = 'cpu'
mock_dcp.world_size = 2
mock_dcp.rank_in_group = 0
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 2
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
mock_pcp.world_size = 2
mock_pcp.rank_in_group = 0
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 2
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group
mock_vllm_config.speculative_config = None
ascend_config = MagicMock()
@@ -239,16 +215,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.enable_chunked_prefill)
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
def test_ascend_mla_metadata_builder_spec_decode(self, mock_dcp,
mock_get_dcp_group,
mock_pcp,
mock_get_pcp_group):
def test_ascend_mla_metadata_builder_spec_decode(self):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -259,20 +226,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
mock_device = 'cpu'
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
mock_pcp.world_size = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group
mock_spec_config = MagicMock()
mock_spec_config.num_speculative_tokens = 3
mock_vllm_config.speculative_config = mock_spec_config
@@ -290,15 +243,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_vllm_config.scheduler_config.enable_chunked_prefill)
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
def test_ascend_mla_metadata_builder_build_full_graph(
self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group,
mock_get_cos_and_sin_mla):
self, mock_get_cos_and_sin_mla):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -310,20 +256,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_device = 'cpu'
torch.Tensor.pin_memory = lambda x: x # noqa
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
mock_pcp.world_size = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group
mock_spec_config = MagicMock()
mock_spec_config.num_speculative_tokens = 1
mock_spec_config.disable_padded_drafter_batch = True
@@ -352,14 +284,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
[1, 2, 4, 5, 6, 6, 7, 8])
self.assertEqual(metadata.decode.block_table.shape[0], 8)
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
def test_reorder_batch(self, mock_dcp, mock_get_dcp_group, mock_pcp,
mock_get_pcp_group):
def test_reorder_batch(self):
ascend_config = MagicMock()
mock_vllm_config = MagicMock()
@@ -370,20 +295,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
mock_device = 'cpu'
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
mock_pcp.world_size = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group
mock_vllm_config.speculative_config = None
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
@@ -411,16 +322,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
self.assertTrue(modified)
input_batch.swap_states.assert_called_once_with(1, 2)
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_dcp,
mock_get_dcp_group,
mock_pcp,
mock_get_pcp_group):
def test_pad_actual_seq_lens_q_mtp_disable_pad(self):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -432,20 +334,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_device = 'cpu'
mock_vllm_config.speculative_config = None
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
mock_pcp.world_size = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
mock_device)
input_seq_lens = [1, 2, 4, 5]
@@ -456,15 +344,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
num_reqs_pad_size, num_reqs, input_seq_lens)
self.assertEqual(output_seq_lens, expect_output)
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_dcp,
mock_get_dcp_group, mock_pcp,
mock_get_pcp_group):
def test_pad_actual_seq_lens_q_mtp_enable_pad(self):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -476,20 +356,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
mock_device = 'cpu'
mock_vllm_config.speculative_config = None
mock_dcp.world_size = 1
dcp_group = MagicMock(spec=GroupCoordinator)
dcp_group.rank_in_group = 0
dcp_group.world_size = 1
dcp_group.device_group = MagicMock()
mock_get_dcp_group.return_value = dcp_group
mock_pcp.world_size = 1
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.rank_in_group = 0
pcp_group.world_size = 1
pcp_group.device_group = MagicMock()
mock_get_pcp_group.return_value = pcp_group
common_metadata = MagicMock()
common_metadata.actual_seq_lengths_q = [2, 4, 6, 8]
@@ -530,22 +396,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
self.kv_cache_spec.num_heads = 32
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
@patch("torch.Tensor.npu", new=lambda self: self)
@patch("torch.npu.is_available")
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
mock_zeros, mock_dcp_world_size,
mock_get_pcp_group,
mock_zeros,
mock_get_cos_and_sin_mla):
mock_npu_available.return_value = False
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
def zeros_override(*args, **kwargs):
kwargs.pop('pin_memory', None)
@@ -596,22 +454,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
@patch("torch.Tensor.npu", new=lambda self: self)
@patch("torch.npu.is_available")
def test_build_chunked_prefix_metadata(self, mock_npu_available,
mock_zeros, mock_dcp_world_size,
mock_get_pcp_group,
mock_zeros,
mock_get_cos_and_sin_mla):
mock_npu_available.return_value = False
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
def zeros_override(*args, **kwargs):
kwargs.pop('pin_memory', None)
@@ -663,18 +513,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_build_decode_only_metadata(self, mock_dcp_world_size,
mock_get_pcp_group,
mock_get_cos_and_sin_mla):
mock_dcp_world_size.return_value = 1
def test_build_decode_only_metadata(self, mock_get_cos_and_sin_mla):
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 1, 2, 3]),
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
@@ -718,18 +559,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size,
mock_get_pcp_group,
def test_build_for_graph_capture_decode_only(self,
mock_get_cos_and_sin_mla):
mock_dcp_world_size.return_value = 1
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 1, 2, 3]),
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
@@ -774,17 +607,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
@patch('vllm.distributed.parallel_state.get_pcp_group')
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
return_value=1)
def test_build_for_graph_capture_prefill(self, mock_dcp_world_size,
mock_get_pcp_group,
mock_get_cos_and_sin_mla):
mock_dcp_world_size.return_value = 1
def test_build_for_graph_capture_prefill(self, mock_get_cos_and_sin_mla):
torch.Tensor.pin_memory = lambda x: x # noqa
pcp_group = MagicMock(spec=GroupCoordinator)
pcp_group.world_size = 1
mock_get_pcp_group.return_value = pcp_group
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 3, 7]),
query_start_loc_cpu=torch.tensor([0, 3, 7]),
@@ -820,23 +644,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
class TestAscendMLAImpl(TestBase):
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state._DCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch('vllm.distributed.parallel_state._TP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
def setUp(self, get_current_vllm_config, mock_tp, mock_dcp, mock_pcp):
def setUp(self, get_current_vllm_config, mock_tp):
mock_tp.world_size = 2
mock_tp.rank_in_group = MagicMock()
mock_tp.device_group = MagicMock()
mock_dcp.world_size = 1
mock_dcp.rank_in_group = MagicMock()
mock_dcp.device_group = MagicMock()
mock_pcp.world_size = 1
mock_pcp.rank_in_group = MagicMock()
mock_pcp.device_group = MagicMock()
vllm_config = MagicMock()
speculative_config = MagicMock()
model_config = MagicMock()