[Bugfix] Fix bug with establishing the flashcomm2 and pp communication domains. (#4458)
### What this PR does / why we need it? The previous implementation of the flashcomm2 communication domain did not account for pp (pipeline parallelism), which caused problems when pp and flashcomm2 were enabled together. This PR fixes that issue. - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 --------- Signed-off-by: zzhx1 <zzh_201018@outlook.com> Co-authored-by: Levi-JQ <yujinqi2@huawei.com>
This commit is contained in:
@@ -24,11 +24,13 @@ def mock_distributed():
|
|||||||
patch('torch.distributed.get_backend', return_value='nccl'), \
|
patch('torch.distributed.get_backend', return_value='nccl'), \
|
||||||
patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
|
patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group, \
|
||||||
patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \
|
patch('vllm_ascend.distributed.parallel_state.get_tp_group') as mock_tp_group, \
|
||||||
patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group:
|
patch('vllm_ascend.distributed.parallel_state.get_dp_group') as mock_dp_group, \
|
||||||
|
patch('vllm_ascend.distributed.parallel_state.get_pp_group') as mock_pp_group:
|
||||||
mock_group.return_value.local_rank = 0
|
mock_group.return_value.local_rank = 0
|
||||||
mock_group.return_value.device_group = MagicMock()
|
mock_group.return_value.device_group = MagicMock()
|
||||||
mock_tp_group.return_value.world_size = 4
|
mock_tp_group.return_value.world_size = 4
|
||||||
mock_dp_group.return_value.world_size = 2
|
mock_dp_group.return_value.world_size = 2
|
||||||
|
mock_pp_group.return_value.world_size = 2
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
from vllm.config import ParallelConfig, get_current_vllm_config
|
from vllm.config import ParallelConfig, get_current_vllm_config
|
||||||
from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
|
from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
|
||||||
get_tp_group, get_world_group,
|
get_pp_group, get_tp_group,
|
||||||
|
get_world_group,
|
||||||
init_model_parallel_group)
|
init_model_parallel_group)
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
@@ -185,6 +186,7 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
|||||||
).flashcomm2_oproj_tensor_parallel_size
|
).flashcomm2_oproj_tensor_parallel_size
|
||||||
global_tp_size = get_tp_group().world_size
|
global_tp_size = get_tp_group().world_size
|
||||||
global_dp_size = get_dp_group().world_size
|
global_dp_size = get_dp_group().world_size
|
||||||
|
global_pp_size = get_pp_group().world_size
|
||||||
num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
|
num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
|
||||||
flashcomm2_otp_size)
|
flashcomm2_otp_size)
|
||||||
|
|
||||||
@@ -197,18 +199,27 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
|
|||||||
if flashcomm2_otp_size > 1:
|
if flashcomm2_otp_size > 1:
|
||||||
otp_group_ranks = []
|
otp_group_ranks = []
|
||||||
odp_group_ranks: list[list[int]] = [
|
odp_group_ranks: list[list[int]] = [
|
||||||
[] for _ in range(flashcomm2_otp_size * global_dp_size)
|
[] for _ in range(flashcomm2_otp_size * global_dp_size *
|
||||||
|
global_pp_size)
|
||||||
]
|
]
|
||||||
|
|
||||||
for dp_group_index in range(global_dp_size):
|
for dp_group_index in range(global_dp_size):
|
||||||
for i in range(num_fc2_oproj_tensor_parallel_groups):
|
for pp_group_index in range(global_pp_size):
|
||||||
ranks = []
|
dp_pp_serial_index = dp_group_index * global_pp_size + pp_group_index
|
||||||
for j in range(flashcomm2_otp_size):
|
tp_base_rank = dp_pp_serial_index * global_tp_size
|
||||||
rank_idx = dp_group_index * global_tp_size + i + j * num_fc2_oproj_tensor_parallel_groups
|
odp_base_index = dp_pp_serial_index * flashcomm2_otp_size
|
||||||
ranks.append(rank_idx)
|
|
||||||
odp_group_index = dp_group_index * flashcomm2_otp_size + j
|
for i in range(num_fc2_oproj_tensor_parallel_groups):
|
||||||
odp_group_ranks[odp_group_index].append(rank_idx)
|
ranks = []
|
||||||
otp_group_ranks.append(ranks)
|
for j in range(flashcomm2_otp_size):
|
||||||
|
tp_local_rank = i + j * num_fc2_oproj_tensor_parallel_groups
|
||||||
|
assert tp_local_rank < global_tp_size
|
||||||
|
global_rank = tp_base_rank + tp_local_rank
|
||||||
|
ranks.append(global_rank)
|
||||||
|
|
||||||
|
odp_group_index = odp_base_index + j
|
||||||
|
odp_group_ranks[odp_group_index].append(
|
||||||
|
global_rank)
|
||||||
|
otp_group_ranks.append(ranks)
|
||||||
|
|
||||||
_FLASHCOMM2_OTP = init_model_parallel_group(
|
_FLASHCOMM2_OTP = init_model_parallel_group(
|
||||||
otp_group_ranks,
|
otp_group_ranks,
|
||||||
|
|||||||
Reference in New Issue
Block a user