### What this PR does / why we need it?
Refactor the PCP and DCP patches in the unit tests: merge the communication-group
and communication-function patches into one reusable decorator to reduce code duplication.
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: v0.13.0
- vLLM main: 45c1ca1ca1
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
65 lines · 2.4 KiB · Python
from functools import wraps
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from vllm.distributed.parallel_state import all_gather_fake
|
|
|
|
|
|
def _configure_group_mock(mock_group, rank, world_size, group_name):
    """Populate a patched parallel-group mock with a rank, a size and a fake all_gather.

    Args:
        mock_group: The mock object standing in for a vLLM parallel group.
        rank: Value assigned to ``mock_group.rank_in_group``.
        world_size: Value assigned to ``mock_group.world_size``.
        group_name: Group name forwarded to ``all_gather_fake``.
    """
    mock_group.rank_in_group = rank
    mock_group.world_size = world_size
    mock_group.device_group = MagicMock()

    mock_group.all_gather = MagicMock()
    # world_size is read when all_gather is *called*, not captured here, so a
    # test may change mock_group.world_size after the decorator has run.
    mock_group.all_gather.side_effect = lambda input_, dim: all_gather_fake(
        input_, dim, mock_group.world_size, group_name)


def patch_distributed_groups(dcp_size=1,
                             dcp_rank=0,
                             pcp_size=1,
                             pcp_rank=0,
                             needs_mocks=True):
    """
    Decorator that patches the DCP/PCP groups and all_to_all_single for a test.

    While the decorated test runs:
      - ``vllm.distributed.parallel_state._DCP`` and
        ``vllm.distributed.parallel_state._PCP`` are replaced with mocks.
        Their ``rank_in_group`` / ``world_size`` come from the arguments
        below. Their ``all_gather(input_, dim)`` delegates to vLLM's
        ``all_gather_fake`` using the group's current ``world_size``.
      - ``torch.distributed.all_to_all_single`` is replaced with a mock that
        copies ``input`` into ``output`` in place.

    Args:
        dcp_size: DCP world size (default: 1)
        dcp_rank: DCP rank (default: 0)
        pcp_size: PCP world size (default: 1)
        pcp_rank: PCP rank (default: 0)
        needs_mocks: Whether to pass mock objects as the first arguments
            after 'self' to the decorated function.
            If True, the decorated function receives:
                func(self, mock_all_to_all_single, mock_dcp, mock_pcp, *args, **kwargs)
            If False, mocks are not passed and function receives:
                func(self, *args, **kwargs)
            (default: True)

    The decorated function may itself carry ``unittest.mock.patch``
    decorators; the mocks those produce are appended after the arguments
    listed above.
    """

    def decorator(func):

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            # The patches are entered as context managers, not stacked as
            # @patch decorators underneath @wraps(func). wraps() copies
            # func.__dict__ onto the wrapper. If func is itself
            # @patch-decorated, that copy includes its ``patchings`` list,
            # which used to replace these three patches and break the call.
            dcp_patcher = patch('vllm.distributed.parallel_state._DCP')
            pcp_patcher = patch('vllm.distributed.parallel_state._PCP')
            a2a_patcher = patch('torch.distributed.all_to_all_single')
            with dcp_patcher as mock_dcp, \
                    pcp_patcher as mock_pcp, \
                    a2a_patcher as mock_all_to_all_single:
                _configure_group_mock(mock_dcp, dcp_rank, dcp_size,
                                      "mock_dcp_group")
                _configure_group_mock(mock_pcp, pcp_rank, pcp_size,
                                      "mock_pcp_group")

                # Identity all-to-all: copy input into output in place. The
                # lambda parameter is deliberately named ``input`` so callers
                # passing it by keyword (as torch's signature allows) work.
                mock_all_to_all_single.side_effect = (
                    lambda output, input, *a, **kw: output.copy_(input))

                if needs_mocks:
                    return func(self, mock_all_to_all_single, mock_dcp,
                                mock_pcp, *args, **kwargs)
                return func(self, *args, **kwargs)

        return wrapper

    return decorator
|