[refactor](UT,PCP,DCP) refactor pcp&dcp patches in UTs (#5505)
### What this PR does / why we need it?
Refactor the PCP & DCP patches in the unit tests (UTs): merge and reuse the
patches for communication groups and communication functions to reduce code
duplication.
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: v0.13.0
- vLLM main: 45c1ca1ca1
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -190,15 +190,7 @@ class TestAscendMLAMetadata(TestBase):
|
||||
|
||||
class TestAscendMLAMetadataBuilder(TestBase):
|
||||
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
def test_ascend_mla_metadata_builder_default(self, mock_dcp,
|
||||
mock_get_dcp_group, mock_pcp,
|
||||
mock_get_pcp_group):
|
||||
def test_ascend_mla_metadata_builder_default(self):
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config.max_model_len = 1024
|
||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||
@@ -209,22 +201,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||
mock_device = 'cpu'
|
||||
|
||||
mock_dcp.world_size = 2
|
||||
mock_dcp.rank_in_group = 0
|
||||
dcp_group = MagicMock(spec=GroupCoordinator)
|
||||
dcp_group.rank_in_group = 0
|
||||
dcp_group.world_size = 2
|
||||
dcp_group.device_group = MagicMock()
|
||||
mock_get_dcp_group.return_value = dcp_group
|
||||
|
||||
mock_pcp.world_size = 2
|
||||
mock_pcp.rank_in_group = 0
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.rank_in_group = 0
|
||||
pcp_group.world_size = 2
|
||||
pcp_group.device_group = MagicMock()
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
ascend_config = MagicMock()
|
||||
@@ -239,16 +215,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
builder.chunked_prefill_enabled,
|
||||
mock_vllm_config.scheduler_config.enable_chunked_prefill)
|
||||
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
def test_ascend_mla_metadata_builder_spec_decode(self, mock_dcp,
|
||||
mock_get_dcp_group,
|
||||
mock_pcp,
|
||||
mock_get_pcp_group):
|
||||
def test_ascend_mla_metadata_builder_spec_decode(self):
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config.max_model_len = 1024
|
||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||
@@ -259,20 +226,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||
mock_device = 'cpu'
|
||||
|
||||
mock_dcp.world_size = 1
|
||||
dcp_group = MagicMock(spec=GroupCoordinator)
|
||||
dcp_group.rank_in_group = 0
|
||||
dcp_group.world_size = 1
|
||||
dcp_group.device_group = MagicMock()
|
||||
mock_get_dcp_group.return_value = dcp_group
|
||||
|
||||
mock_pcp.world_size = 1
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.rank_in_group = 0
|
||||
pcp_group.world_size = 1
|
||||
pcp_group.device_group = MagicMock()
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
|
||||
mock_spec_config = MagicMock()
|
||||
mock_spec_config.num_speculative_tokens = 3
|
||||
mock_vllm_config.speculative_config = mock_spec_config
|
||||
@@ -290,15 +243,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
mock_vllm_config.scheduler_config.enable_chunked_prefill)
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
def test_ascend_mla_metadata_builder_build_full_graph(
|
||||
self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group,
|
||||
mock_get_cos_and_sin_mla):
|
||||
self, mock_get_cos_and_sin_mla):
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config.max_model_len = 1024
|
||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||
@@ -310,20 +256,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
mock_device = 'cpu'
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
|
||||
mock_dcp.world_size = 1
|
||||
dcp_group = MagicMock(spec=GroupCoordinator)
|
||||
dcp_group.rank_in_group = 0
|
||||
dcp_group.world_size = 1
|
||||
dcp_group.device_group = MagicMock()
|
||||
mock_get_dcp_group.return_value = dcp_group
|
||||
|
||||
mock_pcp.world_size = 1
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.rank_in_group = 0
|
||||
pcp_group.world_size = 1
|
||||
pcp_group.device_group = MagicMock()
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
|
||||
mock_spec_config = MagicMock()
|
||||
mock_spec_config.num_speculative_tokens = 1
|
||||
mock_spec_config.disable_padded_drafter_batch = True
|
||||
@@ -352,14 +284,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
[1, 2, 4, 5, 6, 6, 7, 8])
|
||||
self.assertEqual(metadata.decode.block_table.shape[0], 8)
|
||||
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
def test_reorder_batch(self, mock_dcp, mock_get_dcp_group, mock_pcp,
|
||||
mock_get_pcp_group):
|
||||
def test_reorder_batch(self):
|
||||
ascend_config = MagicMock()
|
||||
|
||||
mock_vllm_config = MagicMock()
|
||||
@@ -370,20 +295,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
mock_vllm_config.scheduler_config.enable_chunked_prefill = False
|
||||
mock_device = 'cpu'
|
||||
|
||||
mock_dcp.world_size = 1
|
||||
dcp_group = MagicMock(spec=GroupCoordinator)
|
||||
dcp_group.rank_in_group = 0
|
||||
dcp_group.world_size = 1
|
||||
dcp_group.device_group = MagicMock()
|
||||
mock_get_dcp_group.return_value = dcp_group
|
||||
|
||||
mock_pcp.world_size = 1
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.rank_in_group = 0
|
||||
pcp_group.world_size = 1
|
||||
pcp_group.device_group = MagicMock()
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
|
||||
@@ -411,16 +322,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
self.assertTrue(modified)
|
||||
input_batch.swap_states.assert_called_once_with(1, 2)
|
||||
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_dcp,
|
||||
mock_get_dcp_group,
|
||||
mock_pcp,
|
||||
mock_get_pcp_group):
|
||||
def test_pad_actual_seq_lens_q_mtp_disable_pad(self):
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config.max_model_len = 1024
|
||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||
@@ -432,20 +334,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
mock_device = 'cpu'
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
mock_dcp.world_size = 1
|
||||
dcp_group = MagicMock(spec=GroupCoordinator)
|
||||
dcp_group.rank_in_group = 0
|
||||
dcp_group.world_size = 1
|
||||
dcp_group.device_group = MagicMock()
|
||||
mock_get_dcp_group.return_value = dcp_group
|
||||
|
||||
mock_pcp.world_size = 1
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.rank_in_group = 0
|
||||
pcp_group.world_size = 1
|
||||
pcp_group.device_group = MagicMock()
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
|
||||
builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
|
||||
mock_device)
|
||||
input_seq_lens = [1, 2, 4, 5]
|
||||
@@ -456,15 +344,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
num_reqs_pad_size, num_reqs, input_seq_lens)
|
||||
self.assertEqual(output_seq_lens, expect_output)
|
||||
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch('vllm.distributed.parallel_state._PCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch('vllm.distributed.parallel_state.get_dcp_group')
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_dcp,
|
||||
mock_get_dcp_group, mock_pcp,
|
||||
mock_get_pcp_group):
|
||||
def test_pad_actual_seq_lens_q_mtp_enable_pad(self):
|
||||
mock_vllm_config = MagicMock()
|
||||
mock_vllm_config.model_config.max_model_len = 1024
|
||||
mock_vllm_config.model_config.get_head_size.return_value = 64
|
||||
@@ -476,20 +356,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
|
||||
mock_device = 'cpu'
|
||||
mock_vllm_config.speculative_config = None
|
||||
|
||||
mock_dcp.world_size = 1
|
||||
dcp_group = MagicMock(spec=GroupCoordinator)
|
||||
dcp_group.rank_in_group = 0
|
||||
dcp_group.world_size = 1
|
||||
dcp_group.device_group = MagicMock()
|
||||
mock_get_dcp_group.return_value = dcp_group
|
||||
|
||||
mock_pcp.world_size = 1
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.rank_in_group = 0
|
||||
pcp_group.world_size = 1
|
||||
pcp_group.device_group = MagicMock()
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
|
||||
common_metadata = MagicMock()
|
||||
common_metadata.actual_seq_lengths_q = [2, 4, 6, 8]
|
||||
|
||||
@@ -530,22 +396,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
self.kv_cache_spec.num_heads = 32
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
||||
@patch("torch.Tensor.npu", new=lambda self: self)
|
||||
@patch("torch.npu.is_available")
|
||||
def test_build_prefix_no_cache_metadata(self, mock_npu_available,
|
||||
mock_zeros, mock_dcp_world_size,
|
||||
mock_get_pcp_group,
|
||||
mock_zeros,
|
||||
mock_get_cos_and_sin_mla):
|
||||
mock_npu_available.return_value = False
|
||||
mock_dcp_world_size.return_value = 1
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
|
||||
def zeros_override(*args, **kwargs):
|
||||
kwargs.pop('pin_memory', None)
|
||||
@@ -596,22 +454,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
@patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
|
||||
@patch("torch.Tensor.npu", new=lambda self: self)
|
||||
@patch("torch.npu.is_available")
|
||||
def test_build_chunked_prefix_metadata(self, mock_npu_available,
|
||||
mock_zeros, mock_dcp_world_size,
|
||||
mock_get_pcp_group,
|
||||
mock_zeros,
|
||||
mock_get_cos_and_sin_mla):
|
||||
mock_npu_available.return_value = False
|
||||
mock_dcp_world_size.return_value = 1
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
|
||||
def zeros_override(*args, **kwargs):
|
||||
kwargs.pop('pin_memory', None)
|
||||
@@ -663,18 +513,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_build_decode_only_metadata(self, mock_dcp_world_size,
|
||||
mock_get_pcp_group,
|
||||
mock_get_cos_and_sin_mla):
|
||||
mock_dcp_world_size.return_value = 1
|
||||
def test_build_decode_only_metadata(self, mock_get_cos_and_sin_mla):
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=torch.tensor([0, 1, 2, 3]),
|
||||
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
|
||||
@@ -718,18 +559,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size,
|
||||
mock_get_pcp_group,
|
||||
def test_build_for_graph_capture_decode_only(self,
|
||||
mock_get_cos_and_sin_mla):
|
||||
mock_dcp_world_size.return_value = 1
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=torch.tensor([0, 1, 2, 3]),
|
||||
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
|
||||
@@ -774,17 +607,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
|
||||
|
||||
@patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
|
||||
@patch('vllm.distributed.parallel_state.get_pcp_group')
|
||||
@patch("vllm.distributed.get_decode_context_model_parallel_world_size",
|
||||
return_value=1)
|
||||
def test_build_for_graph_capture_prefill(self, mock_dcp_world_size,
|
||||
mock_get_pcp_group,
|
||||
mock_get_cos_and_sin_mla):
|
||||
mock_dcp_world_size.return_value = 1
|
||||
def test_build_for_graph_capture_prefill(self, mock_get_cos_and_sin_mla):
|
||||
torch.Tensor.pin_memory = lambda x: x # noqa
|
||||
pcp_group = MagicMock(spec=GroupCoordinator)
|
||||
pcp_group.world_size = 1
|
||||
mock_get_pcp_group.return_value = pcp_group
|
||||
common_attn_metadata = AscendCommonAttentionMetadata(
|
||||
query_start_loc=torch.tensor([0, 3, 7]),
|
||||
query_start_loc_cpu=torch.tensor([0, 3, 7]),
|
||||
@@ -820,23 +644,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
|
||||
|
||||
class TestAscendMLAImpl(TestBase):
|
||||
|
||||
@patch('vllm.distributed.parallel_state._PCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch('vllm.distributed.parallel_state._DCP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch('vllm.distributed.parallel_state._TP',
|
||||
new_callable=lambda: MagicMock(spec=GroupCoordinator))
|
||||
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
|
||||
def setUp(self, get_current_vllm_config, mock_tp, mock_dcp, mock_pcp):
|
||||
def setUp(self, get_current_vllm_config, mock_tp):
|
||||
mock_tp.world_size = 2
|
||||
mock_tp.rank_in_group = MagicMock()
|
||||
mock_tp.device_group = MagicMock()
|
||||
mock_dcp.world_size = 1
|
||||
mock_dcp.rank_in_group = MagicMock()
|
||||
mock_dcp.device_group = MagicMock()
|
||||
mock_pcp.world_size = 1
|
||||
mock_pcp.rank_in_group = MagicMock()
|
||||
mock_pcp.device_group = MagicMock()
|
||||
vllm_config = MagicMock()
|
||||
speculative_config = MagicMock()
|
||||
model_config = MagicMock()
|
||||
|
||||
Reference in New Issue
Block a user