bugfix(MC2): refactor the comm group of MC2 to be compatible with PP (#7291)
### What this PR does / why we need it?
This PR refactors the communication group of MC2 to keep it consistent
with vllm's EP group, making it compatible with PP.
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -38,19 +38,16 @@ def init_ascend_model_parallel(
|
||||
global_tp_size = parallel_config.tensor_parallel_size
|
||||
global_dp_size = parallel_config.data_parallel_size
|
||||
global_pp_size = parallel_config.pipeline_parallel_size
|
||||
global_pcp_size = parallel_config.prefill_context_parallel_size
|
||||
|
||||
# The layout of all ranks: ExternalDP * EP
|
||||
# ExternalDP is the data parallel group that is not part of the model,
|
||||
# every dp rank can generate independently (in verl integration).
|
||||
all_ranks = torch.arange(world_size).reshape(
|
||||
-1, global_dp_size * parallel_config.prefill_context_parallel_size * global_tp_size
|
||||
)
|
||||
# TODO: all_ranks should be the same as vllm_all_ranks, all_ranks needs to be removed in the future.
|
||||
vllm_all_ranks = torch.arange(world_size).reshape(
|
||||
-1,
|
||||
global_dp_size,
|
||||
global_pp_size,
|
||||
parallel_config.prefill_context_parallel_size,
|
||||
global_pcp_size,
|
||||
global_tp_size,
|
||||
)
|
||||
|
||||
@@ -59,7 +56,6 @@ def init_ascend_model_parallel(
|
||||
global _P_TP
|
||||
assert _P_TP is None, "distributed prefill tensor parallel group is already initialized"
|
||||
prefill_tensor_model_parallel_size = pd_tp_ratio
|
||||
pcp_size = parallel_config.prefill_context_parallel_size
|
||||
# divide alltoall groups
|
||||
if pd_head_ratio > 1 and get_current_vllm_config().kv_transfer_config.is_kv_producer:
|
||||
num_head_replica = get_ascend_config().num_head_replica
|
||||
@@ -68,13 +64,16 @@ def init_ascend_model_parallel(
|
||||
group_ranks = all_ranks.view(-1, prefill_tensor_model_parallel_size).unbind(0)
|
||||
else:
|
||||
group_ranks = all_ranks.clone().view(
|
||||
global_dp_size * pcp_size, -1, num_head_replica
|
||||
global_dp_size * global_pp_size * global_pcp_size, -1, num_head_replica
|
||||
) # [DP_size, num_head, num_head_replica]
|
||||
group_ranks = group_ranks.permute(0, 2, 1)
|
||||
group_ranks = group_ranks.reshape(-1, group_ranks.size(-1)) # [DP_size * num_head_replica, num_head]
|
||||
alltoall_group_size = group_ranks.size(-1) // remote_tp_size
|
||||
group_ranks = group_ranks.unsqueeze(-1).view(
|
||||
global_dp_size * pcp_size, num_head_replica, -1, alltoall_group_size
|
||||
global_dp_size * global_pp_size * global_pcp_size,
|
||||
num_head_replica,
|
||||
-1,
|
||||
alltoall_group_size,
|
||||
) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
|
||||
group_ranks = group_ranks.reshape(-1, alltoall_group_size).unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
@@ -82,10 +81,18 @@ def init_ascend_model_parallel(
|
||||
num = next((i for i, ranks in enumerate(group_ranks) if local_rank in ranks), None)
|
||||
_P_TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name=f"p_tp_{num}")
|
||||
|
||||
global _MC2
|
||||
group_ranks = all_ranks.unbind(0)
|
||||
# EP like group ranks
|
||||
group_ranks = (
|
||||
all_ranks.transpose(1, 2)
|
||||
.reshape(
|
||||
-1,
|
||||
global_dp_size * global_pcp_size * global_tp_size,
|
||||
)
|
||||
.unbind(0)
|
||||
)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
|
||||
global _MC2
|
||||
_MC2 = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name="mc2")
|
||||
|
||||
if get_ascend_config().eplb_config.dynamic_eplb:
|
||||
@@ -94,6 +101,12 @@ def init_ascend_model_parallel(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="dynamic_eplb"
|
||||
)
|
||||
|
||||
if get_ascend_config().multistream_overlap_gate:
|
||||
global _FC3_QUANT_X
|
||||
_FC3_QUANT_X = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="fc3_quant_x"
|
||||
)
|
||||
|
||||
# Initialize fine-grained TP process groups on Ascend for four components:
|
||||
# 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`)
|
||||
# 2. O Proj: attention output projection (`oproj_tensor_parallel_size`)
|
||||
@@ -182,7 +195,7 @@ def init_ascend_model_parallel(
|
||||
# 2. If it is not None, and the module tp_group is same as the global tp_group.
|
||||
# 3. If it is not None, and the module tp_group is different from the global tp_group.(eg. flashcomm2_otp)
|
||||
group_ranks = []
|
||||
pp_group_ranks = vllm_all_ranks.transpose(2, 4).reshape(-1, global_pp_size)
|
||||
pp_group_ranks = all_ranks.transpose(2, 4).reshape(-1, global_pp_size)
|
||||
if module_tp_group_ranks is None:
|
||||
# If it is None, then the TP_size of this shard weight is 1.
|
||||
shard_weight_group_ranks = pp_group_ranks.transpose(0, 1).unbind(0)
|
||||
@@ -209,17 +222,9 @@ def init_ascend_model_parallel(
|
||||
_SHARD_WEIGHT = create_shard_weight_group(None)
|
||||
else:
|
||||
# For standard tp, use global tp group_ranks
|
||||
tp_group_ranks = vllm_all_ranks.view(-1, global_tp_size)
|
||||
tp_group_ranks = all_ranks.view(-1, global_tp_size)
|
||||
_SHARD_WEIGHT = create_shard_weight_group(tp_group_ranks)
|
||||
|
||||
if get_ascend_config().multistream_overlap_gate:
|
||||
global _FC3_QUANT_X
|
||||
group_ranks = all_ranks.unbind(0)
|
||||
group_ranks = [x.tolist() for x in group_ranks]
|
||||
_FC3_QUANT_X = init_model_parallel_group(
|
||||
group_ranks, get_world_group().local_rank, backend, group_name="fc3_quant_x"
|
||||
)
|
||||
|
||||
|
||||
def model_parallel_initialized():
|
||||
return _MC2 is not None
|
||||
|
||||
Reference in New Issue
Block a user