[refactor](UT,PCP,DCP) refactor pcp&dcp patches in UTs (#5505)
### What this PR does / why we need it?
Refactor the PCP & DCP patches in the unit tests: merge the communication-group and communication-function patches into a shared `patch_distributed_groups` helper so they can be reused across tests, reducing code duplication.
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: v0.13.0
- vLLM main: 45c1ca1ca1
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
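
The shared helper this diff imports from `tests.ut.attention.utils` is not itself part of the diff. Below is a minimal sketch of what `patch_distributed_groups` could look like, assuming it stacks the same `_DCP`/`_PCP` patches and communication fakes the old tests set up by hand. All names and the `needs_mocks` behavior are inferred from the call sites, not from the real helper:

```python
# Hypothetical sketch of tests/ut/attention/utils.patch_distributed_groups.
# Inferred from how the refactored tests use it; not the actual implementation.
import functools
from unittest.mock import MagicMock, patch

import torch
from vllm.distributed.parallel_state import GroupCoordinator


def patch_distributed_groups(dcp_size: int, pcp_size: int,
                             needs_mocks: bool = True):
    """Patch the DCP/PCP coordinators and their collectives in one place."""

    def make_group(world_size: int) -> MagicMock:
        group = MagicMock(spec=GroupCoordinator)
        group.world_size = world_size
        # Shape-faithful all_gather: tile the local tensor world_size times.
        group.all_gather = MagicMock(
            side_effect=lambda t, dim: torch.cat([t] * world_size, dim=dim))
        return group

    def decorator(fn):

        @functools.wraps(fn)
        def wrapper(self, *args, **kwargs):
            mock_dcp = make_group(dcp_size)
            mock_pcp = make_group(pcp_size)
            with patch('vllm.distributed.parallel_state._DCP', mock_dcp), \
                 patch('vllm.distributed.parallel_state._PCP', mock_pcp), \
                 patch('torch.distributed.all_to_all_single',
                       # Identity all-to-all: copy input into the output buffer.
                       side_effect=lambda out, inp, *a, **kw: out.copy_(inp)
                       ) as mock_all2all:
                if needs_mocks:
                    # Hand the fakes to the test as leading arguments.
                    return fn(self, mock_all2all, mock_dcp, mock_pcp,
                              *args, **kwargs)
                return fn(self, *args, **kwargs)

        return wrapper

    return decorator
```

With `needs_mocks=True` the fakes arrive as leading arguments (`mock_all2all, mock_dcp, mock_pcp`), matching signatures like `test_reorg_kvcache_with_dcp_pcp(self, mock_all2all, mock_dcp, mock_pcp)`; with `needs_mocks=False` the patches apply without injecting arguments, as in `setUp`. A real helper used from `setUp` would likely need `patcher.start()` plus `addCleanup` so the patches outlive `setUp`; this sketch elides that.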
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch
 import torch
 from vllm.distributed.parallel_state import GroupCoordinator
 
+from tests.ut.attention.utils import patch_distributed_groups
 from tests.ut.base import TestBase
 from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
@@ -159,29 +160,18 @@ def get_chunk_metadata(pcp_size, dcp_size, num_prefills, num_decodes,
 
 class TestAscendMLAImpl(TestBase):
 
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
-           return_value=1)
     @patch('vllm.distributed.parallel_state._TP',
            new_callable=lambda: MagicMock(spec=GroupCoordinator))
     @patch("vllm.distributed.get_tensor_model_parallel_world_size",
            return_value=2)
     @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
     @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
-              mock_tp, mock_get_dcp_size, mock_dcp, mock_pcp):
+              mock_tp):
         mock_tp.world_size = 2
         mock_tp.rank_in_group = MagicMock()
         mock_tp.device_group = MagicMock()
-        mock_dcp.world_size = 2
-        mock_dcp.rank_in_group = MagicMock()
-        mock_dcp.device_group = MagicMock()
-        mock_pcp.world_size = 2
-        mock_pcp.rank_in_group = MagicMock()
-        mock_pcp.device_group = MagicMock()
         vllm_config = MagicMock()
         speculative_config = MagicMock()
         model_config = MagicMock()
@@ -252,12 +242,11 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(self.impl.pcp_size, 2)
         self.assertEqual(self.impl.dcp_size, 2)
 
-    @patch('vllm_ascend.attention.mla_cp.get_dcp_group')
     @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
     @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def test_mla_preprocess_dcp(self, magic_npu_fetch,
-                                mock_maybe_all_gather_and_maybe_unpad,
-                                mock_get_dcp_group):
+                                mock_maybe_all_gather_and_maybe_unpad):
 
         self.impl.num_kv_heads = 1
         self.impl.num_heads = 16
@@ -278,14 +267,6 @@ class TestAscendMLAImpl(TestBase):
                                          self.impl.qk_rope_head_dim)
         kv_cache = (kv_cache0, kv_cache1)
 
-        mock_dcp_group = MagicMock()
-
-        def mock_all_gather_func(tensor, dim):
-            return torch.cat([tensor, tensor], dim=dim)
-
-        mock_dcp_group.all_gather = mock_all_gather_func
-        mock_get_dcp_group.return_value = mock_dcp_group
-
         attn_metadata = MagicMock()
         attn_metadata.num_decodes = 2
         attn_metadata.num_prefills = 0
@@ -337,12 +318,11 @@ class TestAscendMLAImpl(TestBase):
         self.assertIsNone(prefill_res)
 
     @patch('torch_npu._npu_reshape_and_cache')
-    @patch('vllm_ascend.attention.mla_cp.get_pcp_group')
     @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
     @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def test_mla_preprocess_pcp(self, magic_npu_fetch,
                                 mock_maybe_all_gather_and_maybe_unpad,
-                                mock_get_pcp_group,
                                 mock_npu_reshape_and_cache):
         self.impl.num_kv_heads = 1
         self.impl.num_heads = 16
@@ -363,14 +343,6 @@ class TestAscendMLAImpl(TestBase):
                                          self.impl.qk_rope_head_dim)
         kv_cache = (kv_cache0, kv_cache1)
 
-        mock_pcp_group = MagicMock()
-
-        def mock_all_gather_func(tensor, dim):
-            return torch.cat([tensor, tensor], dim=dim)
-
-        mock_pcp_group.all_gather = mock_all_gather_func
-        mock_get_pcp_group.return_value = mock_pcp_group
-
         attn_metadata = MagicMock()
         attn_metadata.num_decodes = 0
         attn_metadata.num_prefills = 2
@@ -451,10 +423,8 @@ class TestAscendMLAImpl(TestBase):
         self.assertIsNone(decode_res)
         self.assertIsNotNone(prefill_res)
 
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch("torch.distributed.all_to_all_single")
-    def test_process_attn_out_lse(self, mock_all_to_all_single, mock_pcp):
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
+    def test_process_attn_out_lse(self):
         self.impl.dcp_size = 2
         self.impl.pcp_size = 2
 
@@ -465,14 +435,6 @@ class TestAscendMLAImpl(TestBase):
         attn_output = torch.randn(B, N, self.impl.kv_lora_rank)
         softmax_lse = torch.randn(B, N, 1)
 
-        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
-            input)
-
-        def make_all_gather(ws):
-            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
-
-        mock_pcp.all_gather = MagicMock(side_effect=make_all_gather(2))
-
         decode_metadata = MagicMock()
         decode_metadata.actual_seq_lengths_q = MagicMock()
         decode_metadata.seq_lens_list = MagicMock()
@@ -486,16 +448,13 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(result.shape[1], N)
         self.assertEqual(result.shape[2], self.impl.kv_lora_rank + 1)
 
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch("torch.distributed.all_to_all_single")
     @patch('vllm_ascend.attention.mla_cp.get_forward_context')
     @patch("torch_npu.atb.npu_multi_head_latent_attention")
     @patch('torch_npu.npu_attention_update')
+    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
     def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
                                     mock_npu_multi_head_latent_attention,
-                                    mock_get_forward_context,
-                                    mock_all_to_all_single, mock_pcp):
+                                    mock_get_forward_context):
         self.impl.dcp_size = 2
         self.impl.pcp_size = 2
         self.impl.num_kv_heads = 1
@@ -531,14 +490,6 @@ class TestAscendMLAImpl(TestBase):
         ]
         mock_get_forward_context.return_value = MagicMock(capturing=False)
 
-        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
-            input)
-
-        def make_all_gather(ws):
-            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
-
-        mock_pcp.all_gather = MagicMock(side_effect=make_all_gather(2))
-
         self.impl._v_up_proj = MagicMock()
         self.impl._v_up_proj.return_value = torch.randn(
             B, self.impl.v_head_dim)
@@ -549,17 +500,12 @@ class TestAscendMLAImpl(TestBase):
         self.assertEqual(result.shape[0], B)
         self.assertEqual(result.shape[1], self.impl.v_head_dim)
 
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
     @patch("torch_npu.atb.npu_paged_cache_load")
     @patch("torch_npu.atb.npu_ring_mla")
-    def test_compute_prefill_context_with_dcp_pcp(self, mock_ring, mock_load,
-                                                  mock_dcp, mock_pcp):
-
-        def mock_all_gather(ws):
-            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
+    @patch_distributed_groups(dcp_size=2, pcp_size=2)
+    def test_compute_prefill_context_with_dcp_pcp(self, mock_all2all, mock_dcp,
+                                                  mock_pcp, mock_ring,
+                                                  mock_load):
 
         def mock_ring_attn(q_nope, q_rope, k_nope, k_rope, value, mask, seqlen,
                            head_num, kv_head_num, pre_out, prev_lse, qk_scale,
@@ -620,10 +566,8 @@ class TestAscendMLAImpl(TestBase):
                 torch.ones(10, 10, dtype=torch.float16), 1)
         for test_case in test_cases:
             pcp_size, dcp_size, nums_tokens_per_rank, nums_all_rank_context, num_prefills, num_decodes, num_seqs, cp_local_block_size, num_computed_tokens_of_pcp_dcp = test_case
-            mock_dcp.all_gather = MagicMock(
-                side_effect=mock_all_gather(dcp_size))
-            mock_pcp.all_gather = MagicMock(
-                side_effect=mock_all_gather(pcp_size))
+            mock_dcp.world_size = dcp_size
+            mock_pcp.world_size = pcp_size
             assert len(nums_tokens_per_rank) == len(nums_all_rank_context)
             nums_context_per_rank = []
             for num_all_rank_context in nums_all_rank_context:
@@ -687,18 +631,9 @@ class TestAscendMLAImpl(TestBase):
             self.assertEqual(out.shape, prefix_out.shape)
             self.assertEqual(lse.shape, prefix_lse.shape)
 
-    @patch('vllm.distributed.parallel_state.get_pcp_group')
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    @patch('vllm.distributed.parallel_state.get_dcp_group')
-    @patch('vllm.distributed.parallel_state._DCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    def test_reorg_kvcache_with_dcp_pcp(self, mock_dcp, mock_get_dcp_group,
-                                        mock_pcp, mock_get_pcp_group):
-
-        def mock_all_gather(ws):
-            return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
-
+    @patch_distributed_groups(dcp_size=2, pcp_size=2)
+    def test_reorg_kvcache_with_dcp_pcp(self, mock_all2all, mock_dcp,
+                                        mock_pcp):
         BLOCK_SIZE = 128  # fixed
         max_model_len = 4096
         max_num_seqs = 25
@@ -714,11 +649,12 @@ class TestAscendMLAImpl(TestBase):
             if pcp_size * dcp_size == 1:
                 continue
            self.impl.dcp_size = dcp_size
+            mock_dcp.world_size = dcp_size
+            mock_dcp.all_gather.reset_mock()
             self.impl.pcp_size = pcp_size
-            mock_dcp.all_gather = MagicMock(
-                side_effect=mock_all_gather(dcp_size))
-            mock_pcp.all_gather = MagicMock(
-                side_effect=mock_all_gather(pcp_size))
+            mock_pcp.world_size = pcp_size
+            mock_pcp.all_gather.reset_mock()
+
             chunked_prefill_workspace_size = min(
                 max(8 * max_model_len, 4 * max_num_seqs * BLOCK_SIZE),
                 128 * 1024)
@@ -918,17 +854,17 @@ class TestAscendMLAImpl(TestBase):
                 1 + (kv_with_q_tail_nomask_idx.shape[0] != 0))
             mock_npu_ring_mla.reset_mock()
 
-    @patch("torch.distributed.all_to_all_single")
-    @patch('vllm.distributed.parallel_state._PCP',
-           new_callable=lambda: MagicMock(spec=GroupCoordinator))
-    def test_process_attn_out_lse_with_dcp_pcp(self, mock_pcp,
-                                               mock_all_to_all):
+    @patch_distributed_groups(dcp_size=2, pcp_size=2)
+    def test_process_attn_out_lse_with_dcp_pcp(self, mock_all_to_all, mock_dcp,
+                                               mock_pcp):
         B, H, D = 4, self.impl.num_heads, self.impl.v_head_dim  # total: [4, 4, 8]
         test_cases = [(1, 1), (1, 2), (2, 1), (2, 2), (4, 4)]
         for test_case in test_cases:
             print(test_case)
             self.impl.dcp_size = test_case[0]
             self.impl.pcp_size = test_case[1]
+            mock_dcp.world_size = test_case[0]
+            mock_pcp.world_size = test_case[1]
             # Inputs
             attn_output = torch.randn(B, H, D)
             softmax_lse = torch.randn(B, H, 1)
@@ -936,17 +872,6 @@ class TestAscendMLAImpl(TestBase):
             decode_meta = MagicMock()
             decode_meta.batch_seq_mask = batch_seq_mask
 
-            def mock_all_to_all_side_effect(output, input, group=None):
-                output.copy_(input)
-
-            mock_all_to_all.side_effect = mock_all_to_all_side_effect
-
-            def mock_all_gather(ws):
-                return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)
-
-            mock_pcp.all_gather = MagicMock(
-                side_effect=mock_all_gather(self.impl.pcp_size))
-
             result = self.impl._process_attn_out_lse(attn_output, softmax_lse,
                                                      decode_meta)
             # [PCP * S, DCP * H, D + 1]
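
The communication fakes removed above (and now presumably supplied by `patch_distributed_groups`) never run real collectives; they only need to be shape-faithful. A standalone illustration of the two behaviors they emulate, in plain PyTorch with no vLLM dependency:

```python
import torch

# all_gather across world_size ranks is faked by tiling the local tensor
# world_size times along the gather dimension, so downstream code sees the
# gathered shape it expects.
world_size = 2
x = torch.randn(4, 16, 8)  # e.g. [batch, heads, dim] on one rank
gathered = torch.cat([x] * world_size, dim=1)
assert gathered.shape == (4, 16 * world_size, 8)

# all_to_all_single is faked as an identity: the "received" buffer is just a
# copy of the local input, so reshapes and concatenations still see valid data.
out = torch.empty_like(x)
out.copy_(x)
assert torch.equal(out, x)
```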