[Bugfix] Fix bug with establishing the flashcomm2 and pp communication domains. (#4458)
### What this PR does / why we need it?

The previous implementation of the flashcomm2 communication domain did not account for pipeline parallelism (pp), which caused problems when pp and flashcomm2 were enabled together. This PR fixes the issue.

- vLLM version: v0.11.2
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2

---------

Signed-off-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
@@ -24,11 +24,13 @@ def mock_distributed():
          patch('torch.distributed.get_backend', return_value='nccl'), \
          patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
          patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \
-         patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group:
+         patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group, \
+         patch('vllm_ascend.distributed.parallel_state.get_pp_group') as mock_pp_group:
         mock_group.return_value.local_rank = 0
         mock_group.return_value.device_group = MagicMock()
         mock_tp_group.return_value.world_size = 4
         mock_dp_group.return_value.world_size = 2
+        mock_pp_group.return_value.world_size = 2
         yield
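To illustrate why the pp size matters when building communication domains, here is a minimal sketch that reuses the sizes from the mocks (TP=4, DP=2, PP=2). It is not the vllm-ascend implementation: the function `split_into_subgroups`, the `sub_size` parameter, and the assumed rank layout (TP as the innermost dimension) are all hypothetical. It only shows one way a group builder can go wrong if pp is left out of the rank count.

```python
# Hypothetical sketch; names and layout are assumptions, not vllm-ascend code.
def split_into_subgroups(tp_size, pp_size, dp_size, sub_size):
    """Return rank groups of `sub_size` ranks, each inside a single TP group.

    Assumes TP is the innermost dimension of the rank layout, so every
    contiguous block of `tp_size` ranks forms one TP group.
    """
    if tp_size % sub_size != 0:
        raise ValueError("sub_size must divide tp_size")

    # Pipeline stages contribute ranks too. Leaving pp_size out here would
    # cover only tp_size * dp_size ranks and leave the rest without a group.
    world_size = tp_size * pp_size * dp_size

    groups = []
    for tp_start in range(0, world_size, tp_size):
        for offset in range(0, tp_size, sub_size):
            first = tp_start + offset
            groups.append(list(range(first, first + sub_size)))
    return groups


groups = split_into_subgroups(tp_size=4, pp_size=2, dp_size=2, sub_size=2)
covered = sorted(rank for group in groups for rank in group)
```

With these sizes the sketch yields 8 groups covering ranks 0 through 15 exactly once. Dropping `pp_size` from the product would yield only 4 groups and leave ranks 8 through 15 unassigned, which is the kind of mismatch a pp-aware test fixture can catch.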