[refactor](UT,PCP,DCP) refactor pcp&dcp patches in UTs (#5505)
### What this PR does / why we need it?
Refactor PCP & DCP patches in UTs: Merge and reuse communication groups
and communication function patches to reduce code duplication.
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
64
tests/ut/attention/utils.py
Normal file
64
tests/ut/attention/utils.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from functools import wraps
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from vllm.distributed.parallel_state import all_gather_fake
|
||||
|
||||
|
||||
def patch_distributed_groups(dcp_size=1,
|
||||
dcp_rank=0,
|
||||
pcp_size=1,
|
||||
pcp_rank=0,
|
||||
needs_mocks=True):
|
||||
"""
|
||||
Decorator to patch common distributed group mocks with configuration
|
||||
|
||||
Args:
|
||||
dcp_size: DCP world size (default: 1)
|
||||
dcp_rank: DCP rank (default: 0)
|
||||
pcp_size: PCP world size (default: 1)
|
||||
pcp_rank: PCP rank (default: 0)
|
||||
needs_mocks: Whether to pass mock objects as the first arguments
|
||||
after 'self' to the decorated function.
|
||||
If True, the decorated function receives:
|
||||
func(self, mock_all_to_all_single, mock_dcp, mock_pcp, *args, **kwargs)
|
||||
If False, mocks are not passed and function receives:
|
||||
func(self, *args, **kwargs)
|
||||
(default: True)
|
||||
"""
|
||||
|
||||
def decorator(func):
|
||||
|
||||
@wraps(func)
|
||||
@patch('torch.distributed.all_to_all_single')
|
||||
@patch('vllm.distributed.parallel_state._PCP')
|
||||
@patch('vllm.distributed.parallel_state._DCP')
|
||||
def wrapper(self, mock_dcp, mock_pcp, mock_all_to_all_single, *args,
|
||||
**kwargs):
|
||||
mock_dcp.rank_in_group = dcp_rank
|
||||
mock_dcp.world_size = dcp_size
|
||||
mock_dcp.device_group = MagicMock()
|
||||
|
||||
mock_dcp.all_gather = MagicMock()
|
||||
mock_dcp.all_gather.side_effect = lambda input_, dim: all_gather_fake(
|
||||
input_, dim, mock_dcp.world_size, "mock_dcp_group")
|
||||
|
||||
mock_pcp.rank_in_group = pcp_rank
|
||||
mock_pcp.world_size = pcp_size
|
||||
mock_pcp.device_group = MagicMock()
|
||||
|
||||
mock_pcp.all_gather = MagicMock()
|
||||
mock_pcp.all_gather.side_effect = lambda input_, dim: all_gather_fake(
|
||||
input_, dim, mock_pcp.world_size, "mock_pcp_group")
|
||||
|
||||
mock_all_to_all_single.side_effect = lambda output, input, *a, **kw: output.copy_(
|
||||
input)
|
||||
|
||||
if needs_mocks:
|
||||
return func(self, mock_all_to_all_single, mock_dcp, mock_pcp,
|
||||
*args, **kwargs)
|
||||
else:
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
Reference in New Issue
Block a user