diff --git a/tests/ut/attention/test_attention_cp.py b/tests/ut/attention/test_attention_cp.py
index 890a4794..3df7cb16 100644
--- a/tests/ut/attention/test_attention_cp.py
+++ b/tests/ut/attention/test_attention_cp.py
@@ -1,71 +1,19 @@
-from functools import wraps
 from typing import List
 from unittest.mock import MagicMock, patch
 
 import torch
-from vllm.distributed.parallel_state import GroupCoordinator, all_gather_fake
 
+from tests.ut.attention.utils import patch_distributed_groups
 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_cp import AscendAttentionCPImpl
 from vllm_ascend.attention.attention_v1 import (AscendMetadata,
                                                 AscendMetadataForPrefill)
 
 
-def patch_distributed_groups(dcp_size=1, dcp_rank=0, pcp_size=1, pcp_rank=0):
-    """
-    Decorator to patch common distributed group mocks with configuration
-
-    Args:
-        dcp_size: DCP world size (default: 1)
-        dcp_rank: DCP rank (default: 0)
-        pcp_size: PCP world size (default: 1)
-        pcp_rank: PCP rank (default: 0)
-    """
-
-    def decorator(func):
-
-        @wraps(func)
-        @patch('torch.distributed.all_to_all_single')
-        @patch('vllm.distributed.parallel_state._PCP')
-        def wrapper(self, mock_pcp, mock_all_to_all_single, *args, **kwargs):
-            mock_pcp.world_size = pcp_size
-            mock_pcp.rank_in_group = pcp_rank
-
-            mock_pcp.rank_in_group = pcp_rank
-            mock_pcp.world_size = pcp_size
-            mock_pcp.device_group = MagicMock()
-            mock_pcp.all_gather = MagicMock()
-            mock_pcp.all_gather.side_effect = lambda input_, dim: all_gather_fake(
-                input_, dim, pcp_size, "mock")
-
-            mock_all_to_all_single.side_effect = lambda output, input, *a, **kw: output.copy_(
-                input)
-
-            return func(self, mock_all_to_all_single, mock_pcp, *args,
-                        **kwargs)
-
-        return wrapper
-
-    return decorator
-
-
 class TestAscendAttentionCPImpl(TestBase):
 
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
-           return_value=1)
-    def setUp(self, mock_get_dcp_size, mock_dcp, mock_pcp):
-        mock_dcp.world_size = 2
-        mock_dcp.rank_in_group = 0
-        mock_dcp.device_group = MagicMock()
-
-        mock_pcp.world_size = 2
-        mock_pcp.rank_in_group = 0
-        mock_pcp.device_group = MagicMock()
-
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
+    def setUp(self):
         self.layer = MagicMock()
         self.layer.layer_name = "test_layer"
         self.layer._k_scale_float = 1.0
@@ -134,41 +82,12 @@ class TestAscendAttentionCPImpl(TestBase):
         self.assertEqual(output.shape[1], 4)
         self.assertEqual(output.shape[2], 128)
 
-    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP')
-    @patch('vllm_ascend.attention.attention_cp.get_dcp_group')
-    @patch('vllm.distributed.parallel_state._DCP')
     @patch("torch_npu.npu_fused_infer_attention_score")
-    @patch("torch.distributed.all_gather")
-    @patch("torch.distributed.all_to_all_single")
     @patch('vllm_ascend.attention.attention_cp.get_forward_context')
-    def test_forward_decode_pcp_dcp(self, mock_get_forward_context,
-                                    mock_all_to_all_single, mock_all_gather,
-                                    mock_npu_fused_infer_attention_score,
-                                    mock_dcp, mock_get_dcp_group, mock_pcp,
-                                    mock_pcp_group):
-
-        def mock_all_gather_func(tensor, dim):
-            return torch.cat([tensor, tensor], dim=dim)
-
-        mock_dcp.world_size = 2
-        mock_dcp.rank_in_group = 0
-        dcp_group = MagicMock(spec=GroupCoordinator)
-        dcp_group.rank_in_group = 0
-        dcp_group.world_size = 2
-        dcp_group.device_group = MagicMock()
-        dcp_group.all_gather = mock_all_gather_func
-        mock_get_dcp_group.return_value = dcp_group
-
-        mock_pcp.world_size = 2
-        mock_pcp.rank_in_group = 0
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.rank_in_group = 0
-        pcp_group.world_size = 2
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.all_gather = mock_all_gather_func
-        mock_pcp_group.return_value = pcp_group
-
+    @patch_distributed_groups(dcp_size=2, pcp_size=2)
+    def test_forward_decode_pcp_dcp(self, mock_all2all, mock_dcp, mock_pcp,
+                                    mock_get_forward_context,
+                                    mock_npu_fused_infer_attention_score):
         query = torch.randn(2, 4, 128)
         self.impl.key_cache = torch.randn(100, 128, 1, 128)
         self.impl.value_cache = torch.randn(100, 128, 1, 128)
@@ -185,15 +104,6 @@ class TestAscendAttentionCPImpl(TestBase):
 
         mock_get_forward_context.return_value = MagicMock(capturing=False)
 
-        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
-            input)
-
-        def mock_all_gather_func1(tensor_list, tensor, group=None):
-            tensor_list[0] = tensor
-            tensor_list[1] = tensor.clone()
-
-        mock_all_gather.side_effect = mock_all_gather_func1
-
         def mock_npu_fused_infer_attention_score_func(query, k_nope, value,
                                                       **common_kwargs):
             mock_output = torch.randn_like(query)
@@ -213,25 +123,10 @@ class TestAscendAttentionCPImpl(TestBase):
         self.assertEqual(output.shape[1], 4)
         self.assertEqual(output.shape[2], 128)
 
-    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP')
-    @patch('vllm_ascend.attention.attention_cp.get_dcp_group')
-    @patch('vllm.distributed.parallel_state._DCP')
-    def test_prefill_query_all_gather(self, mock_dcp, mock_get_dcp_group,
-                                      mock_pcp, mock_get_pcp_group):
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
+    def test_prefill_query_all_gather(self):
         query = torch.randn(2, 4, 128)
 
-        def mock_all_gather_func(tensor, dim):
-            return torch.cat([tensor, tensor], dim=dim)
-
-        dcp_group = MagicMock(spec=GroupCoordinator)
-        dcp_group.all_gather = mock_all_gather_func
-        mock_get_dcp_group.return_value = dcp_group
-
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.all_gather = mock_all_gather_func
-        mock_get_pcp_group.return_value = pcp_group
-
         attn_metadata = MagicMock()
         attn_metadata.prefill = MagicMock()
         attn_metadata.prefill.chunked_context = MagicMock()
@@ -243,9 +138,9 @@ class TestAscendAttentionCPImpl(TestBase):
         self.assertEqual(output.shape[1], 8)
         self.assertEqual(output.shape[2], 128)
 
-    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
     @patch('torch.ops.npu.npu_fused_infer_attention_score')
-    def test_compute_prefill_context(self, mock_npu_attention, mock_pcp_group):
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
+    def test_compute_prefill_context(self, mock_npu_attention):
         block_num = 100
         block_size = 128
 
@@ -284,13 +179,6 @@ class TestAscendAttentionCPImpl(TestBase):
         self.impl._load_kv_for_chunk = MagicMock()
         self.impl._load_kv_for_chunk.side_effect = mock_load_kv_for_chunk
 
-        def mock_all_gather_func(tensor, dim):
-            return torch.cat([tensor, tensor], dim=dim)
-
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.all_gather = mock_all_gather_func
-        mock_pcp_group.return_value = pcp_group
-
         mock_npu_attention.return_value = torch.randn(batch_size, num_heads,
                                                       head_size), torch.randn(
                                                           batch_size,
@@ -341,11 +229,9 @@ class TestAscendAttentionCPImpl(TestBase):
         self.assertEqual(value.shape[1], num_heads)
         self.assertEqual(value.shape[2], head_size)
 
-    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP')
     @patch('torch_npu._npu_reshape_and_cache')
-    def test_reshape_and_cache(self, mock_npu_reshape_and_cache, mock_pcp,
-                               mock_get_pcp_group):
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
+    def test_reshape_and_cache(self, mock_npu_reshape_and_cache):
         num_tokens = 4
         block_num = 100
         block_size = 128
@@ -369,13 +255,6 @@ class TestAscendAttentionCPImpl(TestBase):
         key = torch.randn(num_tokens, num_heads, head_size)
         value = torch.randn(num_tokens, num_heads, head_size)
 
-        def mock_all_gather_func(tensor, dim):
-            return torch.cat([tensor, tensor], dim=dim)
-
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.all_gather = mock_all_gather_func
-        mock_get_pcp_group.return_value = pcp_group
-
         key, value = self.impl.reshape_and_cache(key, value, kv_cache,
                                                  attn_metadata)
         self.assertEqual(key.shape[0], num_tokens * self.impl.pcp_size)
@@ -388,30 +267,8 @@ class TestAscendAttentionCPImpl(TestBase):
 
 class TestUpdateNpuAttnOutLse(TestBase):
 
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state.get_dcp_group')
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
-           return_value=1)
-    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
-              mock_get_pcp_group):
-        mock_dcp.world_size = 1
-        dcp_group = MagicMock(spec=GroupCoordinator)
-        dcp_group.rank_in_group = 0
-        dcp_group.world_size = 1
-        dcp_group.device_group = MagicMock()
-        mock_get_dcp_group.return_value = dcp_group
-
-        mock_pcp.world_size = 1
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.rank_in_group = 0
-        pcp_group.world_size = 1
-        pcp_group.device_group = MagicMock()
-        mock_get_pcp_group.return_value = pcp_group
-
+    @patch_distributed_groups(needs_mocks=False)
+    def setUp(self):
         self.layer = MagicMock()
         self.layer.layer_name = "test_layer"
         self.layer._k_scale_float = 1.0
@@ -443,6 +300,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
             kv_sharing_target_layer_name=None)
 
         self.impl.pcp_size = 1
+        self.impl.dcp_size = 1
 
         self.batch_size = 2
 
@@ -730,7 +588,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
 
     @patch_distributed_groups(dcp_size=2, pcp_size=3)
     def test_update_chunk_attn_out_lse_dcp2_pcp3(self, mock_all_to_all_single,
-                                                 mock_pcp):
+                                                 mock_dcp, mock_pcp):
         # Mock input data
         prefix_chunk_output = torch.randn(2, 4, 8)
         prefix_chunk_lse = torch.randn(2, 4, 1)
@@ -759,7 +617,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
 
     @patch_distributed_groups(dcp_size=2)
     def test_update_chunk_attn_out_lse_dcp2_pcp1(self, mock_all_to_all_single,
-                                                 mock_pcp):
+                                                 mock_dcp, mock_pcp):
         # Mock input data
         prefix_chunk_output = torch.randn(2, 4, 8)
         prefix_chunk_lse = torch.randn(2, 4, 1)
@@ -789,7 +647,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
 
    @patch_distributed_groups(pcp_size=2)
    def test_update_chunk_attn_out_lse_dcp1_pcp2(self, mock_all_to_all_single,
-                                                 mock_pcp):
+                                                 mock_dcp, mock_pcp):
         # Mock input data
         prefix_chunk_output = torch.randn(2, 4, 8)
         prefix_chunk_lse = torch.randn(2, 4, 1)
diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index 5746099c..4b82320b 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -1,7 +1,6 @@
 from unittest.mock import MagicMock, patch
 
 import torch
-from vllm.distributed.parallel_state import GroupCoordinator
 
 from tests.ut.base import TestBase
 from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
@@ -50,30 +49,7 @@ class TestAscendAttentionBackend(TestBase):
 
 class TestAscendAttentionMetadataBuilder(TestBase):
 
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state.get_dcp_group')
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
-           return_value=1)
-    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
-              mock_get_pcp_group):
-        mock_dcp.world_size = 1
-        dcp_group = MagicMock(spec=GroupCoordinator)
-        dcp_group.rank_in_group = 0
-        dcp_group.world_size = 1
-        dcp_group.device_group = MagicMock()
-        mock_get_dcp_group.return_value = dcp_group
-
-        mock_pcp.world_size = 1
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.rank_in_group = 0
-        pcp_group.world_size = 1
-        pcp_group.device_group = MagicMock()
-        mock_get_pcp_group.return_value = pcp_group
-
+    def setUp(self):
         self.mock_vllm_config = MagicMock()
         self.mock_vllm_config.speculative_config = None
         self.mock_vllm_config.model_config.max_model_len = 640
@@ -126,30 +102,7 @@ class TestAscendAttentionMetadataBuilder(TestBase):
 
 class TestAscendAttentionBackendImpl(TestBase):
 
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state.get_dcp_group')
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
-           return_value=1)
-    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
-              mock_get_pcp_group):
-        mock_dcp.world_size = 1
-        dcp_group = MagicMock(spec=GroupCoordinator)
-        dcp_group.rank_in_group = 0
-        dcp_group.world_size = 1
-        dcp_group.device_group = MagicMock()
-        mock_get_dcp_group.return_value = dcp_group
-
-        mock_pcp.world_size = 1
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.rank_in_group = 0
-        pcp_group.world_size = 1
-        pcp_group.device_group = MagicMock()
-        mock_get_pcp_group.return_value = pcp_group
-
+    def setUp(self):
         self.layer = MagicMock()
         self.layer.layer_name = "test_layer"
         self.layer._k_scale_float = 1.0
diff --git a/tests/ut/attention/test_mla_cp.py b/tests/ut/attention/test_mla_cp.py
index a40662d6..ce50773a 100755
--- a/tests/ut/attention/test_mla_cp.py
+++ b/tests/ut/attention/test_mla_cp.py
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
 import torch
 from vllm.distributed.parallel_state import GroupCoordinator
 
+from tests.ut.attention.utils import patch_distributed_groups
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -159,29 +160,18 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
 
 class TestAscendMLAImpl(TestBase):
 
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_decode_context_model_parallel_world_size", - return_value=1) @patch('vllm.distributed.parallel_state._TP', new_callable=lambda: MagicMock(spec=GroupCoordinator)) @patch("vllm.distributed.get_tensor_model_parallel_world_size", return_value=2) @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config") @patch("vllm_ascend.attention.mla_v1.get_ascend_config") + @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size, - mock_tp, mock_get_dcp_size, mock_dcp, mock_pcp): + mock_tp): mock_tp.world_size = 2 mock_tp.rank_in_group = MagicMock() mock_tp.device_group = MagicMock() - mock_dcp.world_size = 2 - mock_dcp.rank_in_group = MagicMock() - mock_dcp.device_group = MagicMock() - mock_pcp.world_size = 2 - mock_pcp.rank_in_group = MagicMock() - mock_pcp.device_group = MagicMock() vllm_config = MagicMock() speculative_config = MagicMock() model_config = MagicMock() @@ -252,12 +242,11 @@ class TestAscendMLAImpl(TestBase): self.assertEqual(self.impl.pcp_size, 2) self.assertEqual(self.impl.dcp_size, 2) - @patch('vllm_ascend.attention.mla_cp.get_dcp_group') @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad") @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch") + @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) def test_mla_preprocess_dcp(self, magic_npu_fetch, - mock_maybe_all_gather_and_maybe_unpad, - mock_get_dcp_group): + mock_maybe_all_gather_and_maybe_unpad): self.impl.num_kv_heads = 1 self.impl.num_heads = 16 @@ -278,14 +267,6 @@ class TestAscendMLAImpl(TestBase): self.impl.qk_rope_head_dim) kv_cache = (kv_cache0, kv_cache1) - mock_dcp_group = MagicMock() - - def mock_all_gather_func(tensor, dim): - return torch.cat([tensor, tensor], dim=dim) - - mock_dcp_group.all_gather = mock_all_gather_func - mock_get_dcp_group.return_value = mock_dcp_group - attn_metadata = MagicMock() attn_metadata.num_decodes = 2 attn_metadata.num_prefills = 0 @@ -337,12 +318,11 @@ class TestAscendMLAImpl(TestBase): self.assertIsNone(prefill_res) @patch('torch_npu._npu_reshape_and_cache') - @patch('vllm_ascend.attention.mla_cp.get_pcp_group') @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad") @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch") + @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) def test_mla_preprocess_pcp(self, magic_npu_fetch, mock_maybe_all_gather_and_maybe_unpad, - mock_get_pcp_group, mock_npu_reshape_and_cache): self.impl.num_kv_heads = 1 self.impl.num_heads = 16 @@ -363,14 +343,6 @@ class TestAscendMLAImpl(TestBase): self.impl.qk_rope_head_dim) kv_cache = (kv_cache0, kv_cache1) - mock_pcp_group = MagicMock() - - def mock_all_gather_func(tensor, dim): - return torch.cat([tensor, tensor], dim=dim) - - mock_pcp_group.all_gather = mock_all_gather_func - mock_get_pcp_group.return_value = mock_pcp_group - attn_metadata = MagicMock() attn_metadata.num_decodes = 0 attn_metadata.num_prefills = 2 @@ -451,10 +423,8 @@ class TestAscendMLAImpl(TestBase): self.assertIsNone(decode_res) self.assertIsNotNone(prefill_res) - @patch('vllm.distributed.parallel_state._PCP', - new_callable=lambda: MagicMock(spec=GroupCoordinator)) - @patch("torch.distributed.all_to_all_single") - def test_process_attn_out_lse(self, mock_all_to_all_single, mock_pcp): + @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False) + def test_process_attn_out_lse(self): self.impl.dcp_size = 2 self.impl.pcp_size = 2 @@ -465,14 +435,6 @@ class 
@@ -465,14 +435,6 @@ class TestAscendMLAImpl(TestBase):
         attn_output = torch.randn(B, N, self.impl.kv_lora_rank)
         softmax_lse = torch.randn(B, N, 1)
 
-        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
-            input)
-
-        def make_all_gather(ws):
-            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
-
-        mock_pcp.all_gather = MagicMock(side_effect=make_all_gather(2))
-
         decode_metadata = MagicMock()
         decode_metadata.actual_seq_lengths_q = MagicMock()
         decode_metadata.seq_lens_list = MagicMock()
@@ -486,16 +448,13 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(result.shape[1], N)
         self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)
 
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch("torch.distributed.all_to_all_single")
     @patch('vllm_ascend.attention.mla_cp.get_forward_context')
     @patch("torch_npu.atb.npu_multi_head_latent_attention")
     @patch('torch_npu.npu_attention_update')
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
                                     mock_npu_multi_head_latent_attention,
-                                    mock_get_forward_context,
-                                    mock_all_to_all_single, mock_pcp):
+                                    mock_get_forward_context):
         self.impl.dcp_size = 2
         self.impl.pcp_size = 2
         self.impl.num_kv_heads = 1
@@ -531,14 +490,6 @@ class TestAscendMLAImpl(TestBase):
         ]
         mock_get_forward_context.return_value = MagicMock(capturing=False)
 
-        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
-            input)
-
-        def make_all_gather(ws):
-            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
-
-        mock_pcp.all_gather = MagicMock(side_effect=make_all_gather(2))
-
         self.impl._v_up_proj = MagicMock()
         self.impl._v_up_proj.return_value = torch.randn(
             B, self.impl.v_head_dim)
@@ -549,17 +500,12 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(result.shape[0], B)
         self.assertEqual(result.shape[1], self.impl.v_head_dim)
 
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
     @patch("torch_npu.atb.npu_paged_cache_load")
     @patch("torch_npu.atb.npu_ring_mla")
-    def test_compute_prefill_context_with_dcp_pcp(self, mock_ring, mock_load,
-                                                  mock_dcp, mock_pcp):
-
-        def mock_all_gather(ws):
-            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
+    @patch_distributed_groups(dcp_size=2, pcp_size=2)
+    def test_compute_prefill_context_with_dcp_pcp(self, mock_all2all, mock_dcp,
+                                                  mock_pcp, mock_ring,
+                                                  mock_load):
 
         def mock_ring_attn(q_nope, q_rope, k_nope, k_rope, value, mask, seqlen,
                            head_num, kv_head_num, pre_out, prev_lse, qk_scale,
@@ -620,10 +566,8 @@ class TestAscendMLAImpl(TestBase):
                 torch.ones(10, 10, dtype=torch.float16), 1)
         for test_case in test_cases:
             pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens_of_pcp_dcp = test_case
-            mock_dcp.all_gather = MagicMock(
-                side_effect=mock_all_gather(dcp_size))
-            mock_pcp.all_gather = MagicMock(
-                side_effect=mock_all_gather(pcp_size))
+            mock_dcp.world_size = dcp_size
+            mock_pcp.world_size = pcp_size
             assert len(nums_tokens_per_rank) == len(nums_all_rank_context)
             nums_context_per_rank = []
             for num_all_rank_context in nums_all_rank_context:
@@ -687,18 +631,9 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(out.shape, prefix_out.shape)
         self.assertEqual(lse.shape, prefix_lse.shape)
 
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state.get_dcp_group')
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    def test_reorg_kvcache_with_dcp_pcp(self, mock_dcp, mock_get_dcp_group,
-                                        mock_pcp, mock_get_pcp_group):
-
-        def mock_all_gather(ws):
-            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
-
+    @patch_distributed_groups(dcp_size=2, pcp_size=2)
+    def test_reorg_kvcache_with_dcp_pcp(self, mock_all2all, mock_dcp,
+                                        mock_pcp):
         BLOCK_SIZE = 128  # fixed
         max_model_len = 4096
         max_num_seqs = 25
@@ -714,11 +649,12 @@ class TestAscendMLAImpl(TestBase):
             if pcp_size * dcp_size == 1:
                 continue
             self.impl.dcp_size = dcp_size
+            mock_dcp.world_size = dcp_size
+            mock_dcp.all_gather.reset_mock()
             self.impl.pcp_size = pcp_size
-            mock_dcp.all_gather = MagicMock(
-                side_effect=mock_all_gather(dcp_size))
-            mock_pcp.all_gather = MagicMock(
-                side_effect=mock_all_gather(pcp_size))
+            mock_pcp.world_size = pcp_size
+            mock_pcp.all_gather.reset_mock()
+
             chunked_prefill_workspace_size = min(
                 max(8 * max_model_len, 4 * max_num_seqs * BLOCK_SIZE),
                 128 * 1024)
@@ -918,17 +854,17 @@ class TestAscendMLAImpl(TestBase):
                 1 + (kv_with_q_tail_nomask_idx.shape[0] != 0))
             mock_npu_ring_mla.reset_mock()
 
-    @patch("torch.distributed.all_to_all_single")
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    def test_process_attn_out_lse_with_dcp_pcp(self, mock_pcp,
-                                               mock_all_to_all):
+    @patch_distributed_groups(dcp_size=2, pcp_size=2)
+    def test_process_attn_out_lse_with_dcp_pcp(self, mock_all_to_all, mock_dcp,
+                                               mock_pcp):
         B, H, D = 4, self.impl.num_heads, self.impl.v_head_dim  # total: [4, 4, 8]
         test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (4, 4)]
         for test_case in test_cases:
             print(test_case)
             self.impl.dcp_size = test_case[0]
             self.impl.pcp_size = test_case[1]
+            mock_dcp.world_size = test_case[0]
+            mock_pcp.world_size = test_case[1]
             # Inputs
             attn_output = torch.randn(B, H, D)
             softmax_lse = torch.randn(B, H, 1)
@@ -936,17 +872,6 @@ class TestAscendMLAImpl(TestBase):
             decode_meta = MagicMock()
             decode_meta.batch_seq_mask = batch_seq_mask
 
-            def mock_all_to_all_side_effect(output, input, group=None):
-                output.copy_(input)
-
-            mock_all_to_all.side_effect = mock_all_to_all_side_effect
-
-            def mock_all_gather(ws):
-                return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
-
-            mock_pcp.all_gather = MagicMock(
-                side_effect=mock_all_gather(self.impl.pcp_size))
-
             result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
                                                      decode_meta)
             # [PCP * S, DCP * H, D + 1]
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index ae51a875..9f2f61aa 100755
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -190,15 +190,7 @@ class TestAscendMLAMetadata(TestBase):
 
 class TestAscendMLAMetadataBuilder(TestBase):
 
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state.get_dcp_group')
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    def test_ascend_mla_metadata_builder_default(self, mock_dcp,
-                                                 mock_get_dcp_group, mock_pcp,
-                                                 mock_get_pcp_group):
+    def test_ascend_mla_metadata_builder_default(self):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -209,22 +201,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
         mock_vllm_config.scheduler_config.enable_chunked_prefill = False
         mock_device = 'cpu'
 
-        mock_dcp.world_size = 2
-        mock_dcp.rank_in_group = 0
-        dcp_group = MagicMock(spec=GroupCoordinator)
-        dcp_group.rank_in_group = 0
-        dcp_group.world_size = 2
-        dcp_group.device_group = MagicMock()
-        mock_get_dcp_group.return_value = dcp_group
-
-        mock_pcp.world_size = 2
-        mock_pcp.rank_in_group = 0
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.rank_in_group = 0
-        pcp_group.world_size = 2
-        pcp_group.device_group = MagicMock()
-        mock_get_pcp_group.return_value = pcp_group
-
         mock_vllm_config.speculative_config = None
 
         ascend_config = MagicMock()
@@ -239,16 +215,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
             builder.chunked_prefill_enabled,
             mock_vllm_config.scheduler_config.enable_chunked_prefill)
 
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state.get_dcp_group')
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    def test_ascend_mla_metadata_builder_spec_decode(self, mock_dcp,
-                                                     mock_get_dcp_group,
-                                                     mock_pcp,
-                                                     mock_get_pcp_group):
+    def test_ascend_mla_metadata_builder_spec_decode(self):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -259,20 +226,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
         mock_vllm_config.scheduler_config.enable_chunked_prefill = False
         mock_device = 'cpu'
 
-        mock_dcp.world_size = 1
-        dcp_group = MagicMock(spec=GroupCoordinator)
-        dcp_group.rank_in_group = 0
-        dcp_group.world_size = 1
-        dcp_group.device_group = MagicMock()
-        mock_get_dcp_group.return_value = dcp_group
-
-        mock_pcp.world_size = 1
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.rank_in_group = 0
-        pcp_group.world_size = 1
-        pcp_group.device_group = MagicMock()
-        mock_get_pcp_group.return_value = pcp_group
-
         mock_spec_config = MagicMock()
         mock_spec_config.num_speculative_tokens = 3
         mock_vllm_config.speculative_config = mock_spec_config
@@ -290,15 +243,8 @@ class TestAscendMLAMetadataBuilder(TestBase):
             mock_vllm_config.scheduler_config.enable_chunked_prefill)
 
     @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state.get_dcp_group')
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
     def test_ascend_mla_metadata_builder_build_full_graph(
-            self, mock_dcp, mock_get_dcp_group, mock_pcp, mock_get_pcp_group,
-            mock_get_cos_and_sin_mla):
+            self, mock_get_cos_and_sin_mla):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -310,20 +256,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
         mock_device = 'cpu'
         torch.Tensor.pin_memory = lambda x: x  # noqa
 
-        mock_dcp.world_size = 1
-        dcp_group = MagicMock(spec=GroupCoordinator)
-        dcp_group.rank_in_group = 0
-        dcp_group.world_size = 1
-        dcp_group.device_group = MagicMock()
-        mock_get_dcp_group.return_value = dcp_group
-
-        mock_pcp.world_size = 1
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.rank_in_group = 0
-        pcp_group.world_size = 1
-        pcp_group.device_group = MagicMock()
-        mock_get_pcp_group.return_value = pcp_group
-
         mock_spec_config = MagicMock()
         mock_spec_config.num_speculative_tokens = 1
         mock_spec_config.disable_padded_drafter_batch = True
@@ -352,14 +284,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
                          [1, 2, 4, 5, 6, 6, 7, 8])
         self.assertEqual(metadata.decode.block_table.shape[0], 8)
 
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state.get_dcp_group')
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    def test_reorder_batch(self, mock_dcp, mock_get_dcp_group, mock_pcp,
-                           mock_get_pcp_group):
+    def test_reorder_batch(self):
         ascend_config = MagicMock()
 
         mock_vllm_config = MagicMock()
@@ -370,20 +295,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
         mock_vllm_config.scheduler_config.enable_chunked_prefill = False
         mock_device = 'cpu'
 
-        mock_dcp.world_size = 1
-        dcp_group = MagicMock(spec=GroupCoordinator)
-        dcp_group.rank_in_group = 0
-        dcp_group.world_size = 1
-        dcp_group.device_group = MagicMock()
-        mock_get_dcp_group.return_value = dcp_group
-
-        mock_pcp.world_size = 1
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.rank_in_group = 0
-        pcp_group.world_size = 1
-        pcp_group.device_group = MagicMock()
-        mock_get_pcp_group.return_value = pcp_group
-
         mock_vllm_config.speculative_config = None
 
         with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
@@ -411,16 +322,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
         self.assertTrue(modified)
         input_batch.swap_states.assert_called_once_with(1, 2)
 
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state.get_dcp_group')
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    def test_pad_actual_seq_lens_q_mtp_disable_pad(self, mock_dcp,
-                                                   mock_get_dcp_group,
-                                                   mock_pcp,
-                                                   mock_get_pcp_group):
+    def test_pad_actual_seq_lens_q_mtp_disable_pad(self):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -432,20 +334,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
         mock_device = 'cpu'
         mock_vllm_config.speculative_config = None
 
-        mock_dcp.world_size = 1
-        dcp_group = MagicMock(spec=GroupCoordinator)
-        dcp_group.rank_in_group = 0
-        dcp_group.world_size = 1
-        dcp_group.device_group = MagicMock()
-        mock_get_dcp_group.return_value = dcp_group
-
-        mock_pcp.world_size = 1
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.rank_in_group = 0
-        pcp_group.world_size = 1
-        pcp_group.device_group = MagicMock()
-        mock_get_pcp_group.return_value = pcp_group
-
         builder = AscendMLAMetadataBuilder(None, None, mock_vllm_config,
                                            mock_device)
         input_seq_lens = [1, 2, 4, 5]
@@ -456,15 +344,7 @@ class TestAscendMLAMetadataBuilder(TestBase):
             num_reqs_pad_size, num_reqs, input_seq_lens)
         self.assertEqual(output_seq_lens, expect_output)
 
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state.get_dcp_group')
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    def test_pad_actual_seq_lens_q_mtp_enable_pad(self, mock_dcp,
-                                                  mock_get_dcp_group, mock_pcp,
-                                                  mock_get_pcp_group):
+    def test_pad_actual_seq_lens_q_mtp_enable_pad(self):
         mock_vllm_config = MagicMock()
         mock_vllm_config.model_config.max_model_len = 1024
         mock_vllm_config.model_config.get_head_size.return_value = 64
@@ -476,20 +356,6 @@ class TestAscendMLAMetadataBuilder(TestBase):
         mock_device = 'cpu'
         mock_vllm_config.speculative_config = None
 
-        mock_dcp.world_size = 1
-        dcp_group = MagicMock(spec=GroupCoordinator)
-        dcp_group.rank_in_group = 0
-        dcp_group.world_size = 1
-        dcp_group.device_group = MagicMock()
-        mock_get_dcp_group.return_value = dcp_group
-
-        mock_pcp.world_size = 1
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.rank_in_group = 0
-        pcp_group.world_size = 1
-        pcp_group.device_group = MagicMock()
-        mock_get_pcp_group.return_value = pcp_group
-
         common_metadata = MagicMock()
         common_metadata.actual_seq_lengths_q = [2, 4, 6, 8]
 
@@ -530,22 +396,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
         self.kv_cache_spec.num_heads = 32
 
     @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
-           return_value=1)
     @patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
     @patch("torch.Tensor.npu", new=lambda self: self)
     @patch("torch.npu.is_available")
     def test_build_prefix_no_cache_metadata(self, mock_npu_available,
-                                            mock_zeros, mock_dcp_world_size,
-                                            mock_get_pcp_group,
+                                            mock_zeros,
                                             mock_get_cos_and_sin_mla):
         mock_npu_available.return_value = False
-        mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.world_size = 1
-        mock_get_pcp_group.return_value = pcp_group
 
         def zeros_override(*args, **kwargs):
             kwargs.pop('pin_memory', None)
@@ -596,22 +454,14 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
     @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
-           return_value=1)
     @patch("vllm_ascend.attention.mla_v1.torch.zeros", wraps=torch.zeros)
     @patch("torch.Tensor.npu", new=lambda self: self)
     @patch("torch.npu.is_available")
     def test_build_chunked_prefix_metadata(self, mock_npu_available,
-                                           mock_zeros, mock_dcp_world_size,
-                                           mock_get_pcp_group,
+                                           mock_zeros,
                                            mock_get_cos_and_sin_mla):
         mock_npu_available.return_value = False
-        mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.world_size = 1
-        mock_get_pcp_group.return_value = pcp_group
 
         def zeros_override(*args, **kwargs):
             kwargs.pop('pin_memory', None)
@@ -663,18 +513,9 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
     @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
-           return_value=1)
-    def test_build_decode_only_metadata(self, mock_dcp_world_size,
-                                        mock_get_pcp_group,
-                                        mock_get_cos_and_sin_mla):
-        mock_dcp_world_size.return_value = 1
+    def test_build_decode_only_metadata(self, mock_get_cos_and_sin_mla):
         torch.Tensor.pin_memory = lambda x: x  # noqa
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.world_size = 1
-        mock_get_pcp_group.return_value = pcp_group
 
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=torch.tensor([0, 1, 2, 3]),
             query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
@@ -718,18 +559,10 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
     @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
-           return_value=1)
-    def test_build_for_graph_capture_decode_only(self, mock_dcp_world_size,
-                                                 mock_get_pcp_group,
+    def test_build_for_graph_capture_decode_only(self,
                                                  mock_get_cos_and_sin_mla):
-        mock_dcp_world_size.return_value = 1
         torch.Tensor.pin_memory = lambda x: x  # noqa
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.world_size = 1
-        mock_get_pcp_group.return_value = pcp_group
 
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=torch.tensor([0, 1, 2, 3]),
             query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
@@ -774,17 +607,8 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
         self.assertEqual(metadata.head_dim, self.kv_cache_spec.head_size)
 
     @patch("vllm_ascend.attention.mla_v1.get_cos_and_sin_mla")
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
-           return_value=1)
-    def test_build_for_graph_capture_prefill(self, mock_dcp_world_size,
-                                             mock_get_pcp_group,
-                                             mock_get_cos_and_sin_mla):
-        mock_dcp_world_size.return_value = 1
+    def test_build_for_graph_capture_prefill(self, mock_get_cos_and_sin_mla):
         torch.Tensor.pin_memory = lambda x: x  # noqa
-        pcp_group = MagicMock(spec=GroupCoordinator)
-        pcp_group.world_size = 1
-        mock_get_pcp_group.return_value = pcp_group
         common_attn_metadata = AscendCommonAttentionMetadata(
             query_start_loc=torch.tensor([0, 3, 7]),
             query_start_loc_cpu=torch.tensor([0, 3, 7]),
@@ -820,23 +644,13 @@ class TestAscendMLAMetadataBuilderBuild(TestBase):
 
 class TestAscendMLAImpl(TestBase):
 
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
     @patch('vllm.distributed.parallel_state._TP',
            new_callable=lambda: MagicMock(spec=GroupCoordinator))
     @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
-    def setUp(self, get_current_vllm_config, mock_tp, mock_dcp, mock_pcp):
+    def setUp(self, get_current_vllm_config, mock_tp):
         mock_tp.world_size = 2
         mock_tp.rank_in_group = MagicMock()
         mock_tp.device_group = MagicMock()
-        mock_dcp.world_size = 1
-        mock_dcp.rank_in_group = MagicMock()
-        mock_dcp.device_group = MagicMock()
-        mock_pcp.world_size = 1
-        mock_pcp.rank_in_group = MagicMock()
-        mock_pcp.device_group = MagicMock()
         vllm_config = MagicMock()
         speculative_config = MagicMock()
         model_config = MagicMock()
diff --git a/tests/ut/attention/utils.py b/tests/ut/attention/utils.py
new file mode 100644
index 00000000..fcc50371
--- /dev/null
+++ b/tests/ut/attention/utils.py
@@ -0,0 +1,80 @@
+from functools import wraps
+from unittest.mock import MagicMock, patch
+
+from vllm.distributed.parallel_state import all_gather_fake
+
+
+def patch_distributed_groups(dcp_size=1,
+                             dcp_rank=0,
+                             pcp_size=1,
+                             pcp_rank=0,
+                             needs_mocks=True):
+    """
+    Decorator that patches the common distributed-group mocks (_DCP, _PCP,
+    and torch.distributed.all_to_all_single) with the given configuration.
+
+    Args:
+        dcp_size: DCP world size (default: 1)
+        dcp_rank: DCP rank (default: 0)
+        pcp_size: PCP world size (default: 1)
+        pcp_rank: PCP rank (default: 0)
+        needs_mocks: Whether to pass the mock objects as the first
+            arguments after 'self' to the decorated function.
+            If True, the decorated function receives:
+            func(self, mock_all_to_all_single, mock_dcp, mock_pcp, *args, **kwargs)
+            If False, the mocks are not passed and the function receives:
+            func(self, *args, **kwargs)
+            (default: True)
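+
+    Example (an illustrative sketch, not a test from this suite; the
+    class name and assertion below are hypothetical):
+
+        class DummyTest(TestBase):
+
+            @patch_distributed_groups(dcp_size=2, pcp_size=2)
+            def test_example(self, mock_all2all, mock_dcp, mock_pcp):
+                # The groups behave like world-size-2 coordinators and
+                # all_to_all_single is mocked as an identity copy.
+                self.assertEqual(mock_dcp.world_size, 2)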
+    """
+
+    def decorator(func):
+
+        # Stacked @patch decorators inject their mocks bottom-up, so the
+        # wrapper receives (mock_dcp, mock_pcp, mock_all_to_all_single);
+        # mocks from any @patch stacked above the decorated test are
+        # appended after these and forwarded through *args.
+        @wraps(func)
+        @patch('torch.distributed.all_to_all_single')
+        @patch('vllm.distributed.parallel_state._PCP')
+        @patch('vllm.distributed.parallel_state._DCP')
+        def wrapper(self, mock_dcp, mock_pcp, mock_all_to_all_single, *args,
+                    **kwargs):
+            mock_dcp.rank_in_group = dcp_rank
+            mock_dcp.world_size = dcp_size
+            mock_dcp.device_group = MagicMock()
+
+            mock_dcp.all_gather = MagicMock()
+            mock_dcp.all_gather.side_effect = lambda input_, dim: all_gather_fake(
+                input_, dim, mock_dcp.world_size, "mock_dcp_group")
+
+            mock_pcp.rank_in_group = pcp_rank
+            mock_pcp.world_size = pcp_size
+            mock_pcp.device_group = MagicMock()
+
+            mock_pcp.all_gather = MagicMock()
+            mock_pcp.all_gather.side_effect = lambda input_, dim: all_gather_fake(
+                input_, dim, mock_pcp.world_size, "mock_pcp_group")
+
+            mock_all_to_all_single.side_effect = lambda output, input, *a, **kw: output.copy_(
+                input)
+
+            if needs_mocks:
+                return func(self, mock_all_to_all_single, mock_dcp, mock_pcp,
+                            *args, **kwargs)
+            else:
+                return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py
index 0b7ff138..71e6f1a4 100644
--- a/tests/ut/kv_connector/test_mooncake_connector.py
+++ b/tests/ut/kv_connector/test_mooncake_connector.py
@@ -12,7 +12,6 @@ from unittest.mock import MagicMock, patch
 
 import msgspec
 import zmq
-from vllm.distributed.parallel_state import GroupCoordinator
 from vllm.utils.network_utils import make_zmq_path
 
 fake_engine = types.ModuleType("mooncake.engine")
@@ -23,6 +22,7 @@ _mock_ascend_config = MagicMock(enable_kv_nz=False)
 _mock_pp_group = MagicMock(rank_in_group=0, world_size=1)
 _mock_tp_group = MagicMock(rank_in_group=0, world_size=4)
 _mock_pcp_group = MagicMock(rank_in_group=0, world_size=1)
+_mock_dcp_group = MagicMock(rank_in_group=0, world_size=1)
 patch('vllm_ascend.distributed.mooncake_connector.get_pp_group',
       return_value=_mock_pp_group).start()
 patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
@@ -35,6 +35,7 @@ patch(
     return_value=0).start()
 patch('vllm_ascend.distributed.mooncake_connector.get_pcp_group',
       return_value=_mock_pcp_group).start()
+patch('vllm.distributed.parallel_state._DCP', _mock_dcp_group).start()
 
 from vllm_ascend.distributed.mooncake_connector import (  # noqa: E402
     KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker,
@@ -1098,17 +1099,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
         self.mock_transfer_engine.get_rpc_port.return_value = 9090
         self.mock_transfer_engine.initialize.return_value = 0
         self.mock_transfer_engine.register_memory.return_value = 0
-        self.mock_dcp_group = MagicMock(spec=GroupCoordinator)
-        self.mock_dcp_group.rank_in_group = 0
-        self.mock_dcp_group.world_size = 1
-        self.mock_dcp_group.device_group = MagicMock()
-        self.mock_dcp = MagicMock()
-        self.mock_dcp.world_size = 1
-
-        self.mock_pcp_group = MagicMock(spec=GroupCoordinator)
-        self.mock_pcp_group.rank_in_group = 0
-        self.mock_pcp_group.world_size = 1
-        self.mock_pcp_group.device_group = MagicMock()
 
         self.patches = [
             patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
@@ -1143,13 +1133,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
                 MagicMock()),
             patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
                   MagicMock()),
-            patch('vllm.distributed.parallel_state.get_dcp_group',
-                  return_value=self.mock_dcp_group),
-            patch('vllm.distributed.parallel_state._DCP',
-                  return_value=self.mock_dcp),
-            patch(
-                'vllm_ascend.distributed.mooncake_connector.get_decode_context_model_parallel_world_size',
-                return_value=1),
             patch(
                 'vllm_ascend.distributed.mooncake_connector.get_ascend_config',
                 return_value=MagicMock()),