[refactor](UT,PCP,DCP) refactor pcp&dcp patches in UTs (#5505)
### What this PR does / why we need it?
Refactor the PCP & DCP patches in UTs: merge the communication-group and
communication-function patches into a shared, reusable decorator to reduce
code duplication across tests.
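A representative before/after pair from the diff below shows the shape of the
change (both signatures are taken verbatim from this commit):

```python
# Before: every test stacked its own group patches and rebuilt the mocks.
@patch('vllm_ascend.attention.attention_cp.get_pcp_group')
@patch('vllm.distributed.parallel_state._PCP')
@patch('vllm_ascend.attention.attention_cp.get_dcp_group')
@patch('vllm.distributed.parallel_state._DCP')
def test_prefill_query_all_gather(self, mock_dcp, mock_get_dcp_group,
                                  mock_pcp, mock_get_pcp_group):
    ...

# After: one shared decorator configures the groups and collectives.
@patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
def test_prefill_query_all_gather(self):
    ...
```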
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: v0.13.0
- vLLM main: 45c1ca1ca1
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
@@ -1,71 +1,19 @@
from functools import wraps
from typing import List
from unittest.mock import MagicMock, patch

import torch
from vllm.distributed.parallel_state import GroupCoordinator, all_gather_fake

from tests.ut.attention.utils import patch_distributed_groups
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_cp import AscendAttentionCPImpl
from vllm_ascend.attention.attention_v1 import (AscendMetadata,
                                                AscendMetadataForPrefill)


def patch_distributed_groups(dcp_size=1, dcp_rank=0, pcp_size=1, pcp_rank=0):
    """
    Decorator to patch common distributed group mocks with configuration.

    Args:
        dcp_size: DCP world size (default: 1)
        dcp_rank: DCP rank (default: 0)
        pcp_size: PCP world size (default: 1)
        pcp_rank: PCP rank (default: 0)
    """

    def decorator(func):

        @wraps(func)
        @patch('torch.distributed.all_to_all_single')
        @patch('vllm.distributed.parallel_state._PCP')
        def wrapper(self, mock_pcp, mock_all_to_all_single, *args, **kwargs):
            mock_pcp.rank_in_group = pcp_rank
            mock_pcp.world_size = pcp_size
            mock_pcp.device_group = MagicMock()
            mock_pcp.all_gather = MagicMock()
            mock_pcp.all_gather.side_effect = lambda input_, dim: all_gather_fake(
                input_, dim, pcp_size, "mock")

            mock_all_to_all_single.side_effect = lambda output, input, *a, **kw: output.copy_(
                input)

            return func(self, mock_all_to_all_single, mock_pcp, *args,
                        **kwargs)

        return wrapper

    return decorator


class TestAscendAttentionCPImpl(TestBase):

    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    def setUp(self, mock_get_dcp_size, mock_dcp, mock_pcp):
        mock_dcp.world_size = 2
        mock_dcp.rank_in_group = 0
        mock_dcp.device_group = MagicMock()

        mock_pcp.world_size = 2
        mock_pcp.rank_in_group = 0
        mock_pcp.device_group = MagicMock()

    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def setUp(self):
        self.layer = MagicMock()
        self.layer.layer_name = "test_layer"
        self.layer._k_scale_float = 1.0
@@ -134,41 +82,12 @@ class TestAscendAttentionCPImpl(TestBase):
        self.assertEqual(output.shape[1], 4)
        self.assertEqual(output.shape[2], 128)

    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP')
    @patch('vllm_ascend.attention.attention_cp.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP')
    @patch("torch_npu.npu_fused_infer_attention_score")
    @patch("torch.distributed.all_gather")
    @patch("torch.distributed.all_to_all_single")
    @patch('vllm_ascend.attention.attention_cp.get_forward_context')
    def test_forward_decode_pcp_dcp(self, mock_get_forward_context,
                                    mock_all_to_all_single, mock_all_gather,
                                    mock_npu_fused_infer_attention_score,
                                    mock_dcp, mock_get_dcp_group, mock_pcp,
                                    mock_pcp_group):

        def mock_all_gather_func(tensor, dim):
            return torch.cat([tensor, tensor], dim=dim)

        mock_dcp.world_size = 2
        mock_dcp.rank_in_group = 0
        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.rank_in_group = 0
        dcp_group.world_size = 2
        dcp_group.device_group = MagicMock()
        dcp_group.all_gather = mock_all_gather_func
        mock_get_dcp_group.return_value = dcp_group

        mock_pcp.world_size = 2
        mock_pcp.rank_in_group = 0
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.rank_in_group = 0
        pcp_group.world_size = 2
        pcp_group.all_gather = mock_all_gather_func
        mock_pcp_group.return_value = pcp_group

    @patch_distributed_groups(dcp_size=2, pcp_size=2)
    def test_forward_decode_pcp_dcp(self, mock_all2all, mock_dcp, mock_pcp,
                                    mock_get_forward_context,
                                    mock_npu_fused_infer_attention_score):
        query = torch.randn(2, 4, 128)
        self.impl.key_cache = torch.randn(100, 128, 1, 128)
        self.impl.value_cache = torch.randn(100, 128, 1, 128)
@@ -185,15 +104,6 @@ class TestAscendAttentionCPImpl(TestBase):

        mock_get_forward_context.return_value = MagicMock(capturing=False)

        mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
            input)

        def mock_all_gather_func1(tensor_list, tensor, group=None):
            tensor_list[0] = tensor
            tensor_list[1] = tensor.clone()

        mock_all_gather.side_effect = mock_all_gather_func1

        def mock_npu_fused_infer_attention_score_func(query, k_nope, value,
                                                      **common_kwargs):
            mock_output = torch.randn_like(query)
@@ -213,25 +123,10 @@ class TestAscendAttentionCPImpl(TestBase):
        self.assertEqual(output.shape[1], 4)
        self.assertEqual(output.shape[2], 128)

    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP')
    @patch('vllm_ascend.attention.attention_cp.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP')
    def test_prefill_query_all_gather(self, mock_dcp, mock_get_dcp_group,
                                      mock_pcp, mock_get_pcp_group):
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_prefill_query_all_gather(self):
        query = torch.randn(2, 4, 128)

        def mock_all_gather_func(tensor, dim):
            return torch.cat([tensor, tensor], dim=dim)

        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.all_gather = mock_all_gather_func
        mock_get_dcp_group.return_value = dcp_group

        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.all_gather = mock_all_gather_func
        mock_get_pcp_group.return_value = pcp_group

        attn_metadata = MagicMock()
        attn_metadata.prefill = MagicMock()
        attn_metadata.prefill.chunked_context = MagicMock()
@@ -243,9 +138,9 @@ class TestAscendAttentionCPImpl(TestBase):
        self.assertEqual(output.shape[1], 8)
        self.assertEqual(output.shape[2], 128)

    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
    @patch('torch.ops.npu.npu_fused_infer_attention_score')
    def test_compute_prefill_context(self, mock_npu_attention, mock_pcp_group):
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_compute_prefill_context(self, mock_npu_attention):

        block_num = 100
        block_size = 128
@@ -284,13 +179,6 @@ class TestAscendAttentionCPImpl(TestBase):
        self.impl._load_kv_for_chunk = MagicMock()
        self.impl._load_kv_for_chunk.side_effect = mock_load_kv_for_chunk

        def mock_all_gather_func(tensor, dim):
            return torch.cat([tensor, tensor], dim=dim)

        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.all_gather = mock_all_gather_func
        mock_pcp_group.return_value = pcp_group

        mock_npu_attention.return_value = torch.randn(batch_size, num_heads,
                                                      head_size), torch.randn(
                                                          batch_size,
@@ -341,11 +229,9 @@ class TestAscendAttentionCPImpl(TestBase):
        self.assertEqual(value.shape[1], num_heads)
        self.assertEqual(value.shape[2], head_size)

    @patch('vllm_ascend.attention.attention_cp.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP')
    @patch('torch_npu._npu_reshape_and_cache')
    def test_reshape_and_cache(self, mock_npu_reshape_and_cache, mock_pcp,
                               mock_get_pcp_group):
    @patch_distributed_groups(dcp_size=2, pcp_size=2, needs_mocks=False)
    def test_reshape_and_cache(self, mock_npu_reshape_and_cache):
        num_tokens = 4
        block_num = 100
        block_size = 128
@@ -369,13 +255,6 @@ class TestAscendAttentionCPImpl(TestBase):
        key = torch.randn(num_tokens, num_heads, head_size)
        value = torch.randn(num_tokens, num_heads, head_size)

        def mock_all_gather_func(tensor, dim):
            return torch.cat([tensor, tensor], dim=dim)

        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.all_gather = mock_all_gather_func
        mock_get_pcp_group.return_value = pcp_group

        key, value = self.impl.reshape_and_cache(key, value, kv_cache,
                                                 attn_metadata)
        self.assertEqual(key.shape[0], num_tokens * self.impl.pcp_size)
@@ -388,30 +267,8 @@

class TestUpdateNpuAttnOutLse(TestBase):

    @patch('vllm.distributed.parallel_state.get_pcp_group')
    @patch('vllm.distributed.parallel_state._PCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch('vllm.distributed.parallel_state.get_dcp_group')
    @patch('vllm.distributed.parallel_state._DCP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_decode_context_model_parallel_world_size",
           return_value=1)
    def setUp(self, mock_get_dcp_size, mock_dcp, mock_get_dcp_group, mock_pcp,
              mock_get_pcp_group):
        mock_dcp.world_size = 1
        dcp_group = MagicMock(spec=GroupCoordinator)
        dcp_group.rank_in_group = 0
        dcp_group.world_size = 1
        dcp_group.device_group = MagicMock()
        mock_get_dcp_group.return_value = dcp_group

        mock_pcp.world_size = 1
        pcp_group = MagicMock(spec=GroupCoordinator)
        pcp_group.rank_in_group = 0
        pcp_group.world_size = 1
        pcp_group.device_group = MagicMock()
        mock_get_pcp_group.return_value = pcp_group

    @patch_distributed_groups(needs_mocks=False)
    def setUp(self):
        self.layer = MagicMock()
        self.layer.layer_name = "test_layer"
        self.layer._k_scale_float = 1.0
@@ -443,6 +300,7 @@ class TestUpdateNpuAttnOutLse(TestBase):
                                 kv_sharing_target_layer_name=None)

        self.impl.pcp_size = 1
        self.impl.dcp_size = 1

        self.batch_size = 2
@@ -730,7 +588,7 @@ class TestUpdateNpuAttnOutLse(TestBase):

    @patch_distributed_groups(dcp_size=2, pcp_size=3)
    def test_update_chunk_attn_out_lse_dcp2_pcp3(self, mock_all_to_all_single,
                                                 mock_pcp):
                                                 mock_dcp, mock_pcp):
        # Mock input data
        prefix_chunk_output = torch.randn(2, 4, 8)
        prefix_chunk_lse = torch.randn(2, 4, 1)
@@ -759,7 +617,7 @@ class TestUpdateNpuAttnOutLse(TestBase):

    @patch_distributed_groups(dcp_size=2)
    def test_update_chunk_attn_out_lse_dcp2_pcp1(self, mock_all_to_all_single,
                                                 mock_pcp):
                                                 mock_dcp, mock_pcp):
        # Mock input data
        prefix_chunk_output = torch.randn(2, 4, 8)
        prefix_chunk_lse = torch.randn(2, 4, 1)
@@ -789,7 +647,7 @@ class TestUpdateNpuAttnOutLse(TestBase):

    @patch_distributed_groups(pcp_size=2)
    def test_update_chunk_attn_out_lse_dcp1_pcp2(self, mock_all_to_all_single,
                                                 mock_pcp):
                                                 mock_dcp, mock_pcp):
        # Mock input data
        prefix_chunk_output = torch.randn(2, 4, 8)
        prefix_chunk_lse = torch.randn(2, 4, 1)
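The shared decorator itself now lives in `tests/ut/attention/utils.py`, which
this diff does not show. Judging from the call sites above, it extends the
removed local version with `_DCP` patching and a `needs_mocks` flag that
controls whether the patched mocks are injected into the wrapped test. A
minimal sketch of that shape, assuming those two extensions (the body below is
reconstructed from the removed local decorator and the new test signatures,
not copied from the actual utils code, and it omits the `all_gather` side
effects the local version installed):

```python
from functools import wraps
from unittest.mock import MagicMock, patch


def patch_distributed_groups(dcp_size=1, dcp_rank=0, pcp_size=1, pcp_rank=0,
                             needs_mocks=True):
    """Configure mocked PCP/DCP communication groups once, for reuse."""

    def decorator(func):

        @wraps(func)
        @patch('torch.distributed.all_to_all_single')
        @patch('vllm.distributed.parallel_state._DCP')
        @patch('vllm.distributed.parallel_state._PCP')
        def wrapper(self, mock_pcp, mock_dcp, mock_all_to_all_single, *args,
                    **kwargs):
            # The bottom-most @patch supplies the first mock argument.
            for group, size, rank in ((mock_pcp, pcp_size, pcp_rank),
                                      (mock_dcp, dcp_size, dcp_rank)):
                group.world_size = size
                group.rank_in_group = rank
                group.device_group = MagicMock()

            # On a mocked group, all_to_all_single degenerates to a copy.
            mock_all_to_all_single.side_effect = (
                lambda output, input_, *a, **kw: output.copy_(input_))

            if needs_mocks:
                # Argument order matches the refactored tests above.
                return func(self, mock_all_to_all_single, mock_dcp, mock_pcp,
                            *args, **kwargs)
            return func(self, *args, **kwargs)  # e.g. a mock-free setUp()

        return wrapper

    return decorator
```

Under this sketch, `@patch_distributed_groups(dcp_size=2, pcp_size=3)` would
call the test as `func(self, mock_all_to_all_single, mock_dcp, mock_pcp)`,
matching the `test_update_chunk_attn_out_lse_*` signatures above, while
`needs_mocks=False` keeps `setUp(self)` parameter-free.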