bugfix(MC2): refactor the comm group of MC2 to be compatible with PP (#7291)

### What this PR does / why we need it?
This PR refactors the MC2 communication group to keep it consistent with
vLLM's EP group, making MC2 compatible with pipeline parallelism (PP).
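
To make "consistent with vLLM's EP group" concrete: in vLLM's rank layout, TP varies fastest, then PCP, PP, and DP, and the EP group of a pipeline stage collects every (DP, PCP, TP) slot sharing that stage. A minimal sketch, with sizes assumed purely for illustration (DP=2, PP=2, PCP=1, TP=2; none of these values come from this PR):

```python
# Hypothetical sizes for illustration only.
dp, pp, pcp, tp = 2, 2, 1, 2

def ep_group(pp_rank: int) -> list[int]:
    # vLLM assigns rank = ((d * pp + p) * pcp + c) * tp + t, so the EP
    # group of one PP stage is all (d, c, t) combinations for fixed p.
    return [
        ((d * pp + pp_rank) * pcp + c) * tp + t
        for d in range(dp) for c in range(pcp) for t in range(tp)
    ]

print(ep_group(0), ep_group(1))  # [0, 1, 4, 5] [2, 3, 6, 7]
```

MC2 must build its groups over exactly these rank sets so that each pipeline stage communicates only within itself.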

- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Authored by Qiu on 2026-03-23 15:44:21 +08:00, committed by GitHub
Parent: 8527b49764
Commit: 71df17f4e6
5 changed files with 571 additions and 89 deletions


```diff
@@ -38,19 +38,16 @@ def init_ascend_model_parallel(
     global_tp_size = parallel_config.tensor_parallel_size
     global_dp_size = parallel_config.data_parallel_size
     global_pp_size = parallel_config.pipeline_parallel_size
+    global_pcp_size = parallel_config.prefill_context_parallel_size
     # The layout of all ranks: ExternalDP * EP
     # ExternalDP is the data parallel group that is not part of the model,
     # every dp rank can generate independently (in verl integration).
-    all_ranks = torch.arange(world_size).reshape(
-        -1, global_dp_size * parallel_config.prefill_context_parallel_size * global_tp_size
-    )
-    # TODO: all_ranks should be the same as vllm_all_ranks, all_ranks needs to be removed in the future.
-    vllm_all_ranks = torch.arange(world_size).reshape(
+    all_ranks = torch.arange(world_size).reshape(
         -1,
         global_dp_size,
         global_pp_size,
-        parallel_config.prefill_context_parallel_size,
+        global_pcp_size,
         global_tp_size,
     )
```
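
A worked example of the unified layout (sizes are assumptions chosen for illustration): the trailing axes vary fastest, so TP ranks are consecutive, and each leading index selects one (ExternalDP, DP, PP, PCP) slice.

```python
# Minimal sketch of the (ExternalDP, DP, PP, PCP, TP) layout above,
# with assumed sizes DP=2, PP=2, PCP=1, TP=2 (8 ranks, ExternalDP=1).
import torch

world_size = 8
global_dp_size, global_pp_size, global_pcp_size, global_tp_size = 2, 2, 1, 2
all_ranks = torch.arange(world_size).reshape(
    -1, global_dp_size, global_pp_size, global_pcp_size, global_tp_size
)
print(all_ranks.shape)  # torch.Size([1, 2, 2, 1, 2])
# all_ranks[0, d, p, 0, t] == 4 * d + 2 * p + t: TP is the fastest-varying
# axis, so e.g. ranks 0-1 are (DP 0, PP 0) and ranks 2-3 are (DP 0, PP 1).
```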
```diff
@@ -59,7 +56,6 @@ def init_ascend_model_parallel(
     global _P_TP
     assert _P_TP is None, "distributed prefill tensor parallel group is already initialized"
     prefill_tensor_model_parallel_size = pd_tp_ratio
-    pcp_size = parallel_config.prefill_context_parallel_size
     # divide alltoall groups
     if pd_head_ratio > 1 and get_current_vllm_config().kv_transfer_config.is_kv_producer:
         num_head_replica = get_ascend_config().num_head_replica
```
```diff
@@ -68,13 +64,16 @@ def init_ascend_model_parallel(
             group_ranks = all_ranks.view(-1, prefill_tensor_model_parallel_size).unbind(0)
         else:
             group_ranks = all_ranks.clone().view(
-                global_dp_size * pcp_size, -1, num_head_replica
+                global_dp_size * global_pp_size * global_pcp_size, -1, num_head_replica
             )  # [DP_size, num_head, num_head_replica]
             group_ranks = group_ranks.permute(0, 2, 1)
             group_ranks = group_ranks.reshape(-1, group_ranks.size(-1))  # [DP_size * num_head_replica, num_head]
             alltoall_group_size = group_ranks.size(-1) // remote_tp_size
             group_ranks = group_ranks.unsqueeze(-1).view(
-                global_dp_size * pcp_size, num_head_replica, -1, alltoall_group_size
+                global_dp_size * global_pp_size * global_pcp_size,
+                num_head_replica,
+                -1,
+                alltoall_group_size,
             )  # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
             group_ranks = group_ranks.reshape(-1, alltoall_group_size).unbind(0)
         group_ranks = [x.tolist() for x in group_ranks]
```
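
The permute/reshape pipeline above is easier to follow on a concrete tensor. A minimal sketch under assumed sizes (8 ranks, a single DP×PP×PCP slice, num_head_replica=2, remote_tp_size=2; none of these values are fixed by the PR):

```python
# Sketch of the alltoall grouping, mirroring the reshapes in the diff.
import torch

D, num_head_replica, remote_tp_size = 1, 2, 2  # D = DP * PP * PCP
all_ranks = torch.arange(8)

g = all_ranks.view(D, -1, num_head_replica)  # [D, num_head, replica]
g = g.permute(0, 2, 1)                       # [D, replica, num_head]
g = g.reshape(-1, g.size(-1))                # [D * replica, num_head]
alltoall_group_size = g.size(-1) // remote_tp_size
g = g.view(D, num_head_replica, -1, alltoall_group_size)
groups = g.reshape(-1, alltoall_group_size).unbind(0)
print([x.tolist() for x in groups])  # [[0, 2], [4, 6], [1, 3], [5, 7]]
```

Folding global_pp_size into the leading dimension keeps these rows from straddling pipeline stages once PP > 1.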
```diff
@@ -82,10 +81,18 @@ def init_ascend_model_parallel(
         num = next((i for i, ranks in enumerate(group_ranks) if local_rank in ranks), None)
         _P_TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name=f"p_tp_{num}")
-    global _MC2
-    group_ranks = all_ranks.unbind(0)
+    # EP like group ranks
+    group_ranks = (
+        all_ranks.transpose(1, 2)
+        .reshape(
+            -1,
+            global_dp_size * global_pcp_size * global_tp_size,
+        )
+        .unbind(0)
+    )
     group_ranks = [x.tolist() for x in group_ranks]
+    global _MC2
     _MC2 = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name="mc2")
     if get_ascend_config().eplb_config.dynamic_eplb:
```
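
The transpose is the heart of the fix: under PP > 1, the old flat unbind grouped consecutive ranks, mixing pipeline stages, whereas the EP-like grouping matches vLLM's rank assignment. A sketch with assumed sizes (DP=2, PP=2, PCP=1, TP=2; not values from the PR):

```python
# Contrast of old vs. new MC2 grouping on an 8-rank example.
import torch

dp, pp, pcp, tp = 2, 2, 1, 2
world = dp * pp * pcp * tp
all_ranks = torch.arange(world).reshape(-1, dp, pp, pcp, tp)

# Old behaviour: a flat (-1, dp * pcp * tp) layout unbound row by row
# puts both PP stages of DP replica 0 into one group.
old = torch.arange(world).reshape(-1, dp * pcp * tp).unbind(0)
print([x.tolist() for x in old])  # [[0, 1, 2, 3], [4, 5, 6, 7]]

# New behaviour: swap the DP and PP axes first, so each row is one PP
# stage's EP group, identical to vLLM's EP layout.
new = all_ranks.transpose(1, 2).reshape(-1, dp * pcp * tp).unbind(0)
print([x.tolist() for x in new])  # [[0, 1, 4, 5], [2, 3, 6, 7]]
```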
```diff
@@ -94,6 +101,12 @@ def init_ascend_model_parallel(
             group_ranks, get_world_group().local_rank, backend, group_name="dynamic_eplb"
         )
+    if get_ascend_config().multistream_overlap_gate:
+        global _FC3_QUANT_X
+        _FC3_QUANT_X = init_model_parallel_group(
+            group_ranks, get_world_group().local_rank, backend, group_name="fc3_quant_x"
+        )
     # Initialize fine-grained TP process groups on Ascend for four components:
     # 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`)
     # 2. O Proj: attention output projection (`oproj_tensor_parallel_size`)
```
```diff
@@ -182,7 +195,7 @@ def init_ascend_model_parallel(
         # 2. If it is not None, and the module tp_group is same as the global tp_group.
         # 3. If it is not None, and the module tp_group is different from the global tp_group.(eg. flashcomm2_otp)
         group_ranks = []
-        pp_group_ranks = vllm_all_ranks.transpose(2, 4).reshape(-1, global_pp_size)
+        pp_group_ranks = all_ranks.transpose(2, 4).reshape(-1, global_pp_size)
         if module_tp_group_ranks is None:
             # If it is None, then the TP_size of this shard weight is 1.
             shard_weight_group_ranks = pp_group_ranks.transpose(0, 1).unbind(0)
```
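
For completeness, the PP grouping built here pairs the same (DP, TP, PCP) slot across all pipeline stages: moving the PP axis last via transpose(2, 4) makes each reshape row one PP group. A sketch under the same assumed sizes as above (DP=2, PP=2, PCP=1, TP=2):

```python
# Sketch of pp_group_ranks; sizes are illustrative assumptions.
import torch

dp, pp, pcp, tp = 2, 2, 1, 2
all_ranks = torch.arange(dp * pp * pcp * tp).reshape(-1, dp, pp, pcp, tp)

pp_group_ranks = all_ranks.transpose(2, 4).reshape(-1, pp)
print(pp_group_ranks.tolist())  # [[0, 2], [1, 3], [4, 6], [5, 7]]
```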
```diff
@@ -209,17 +222,9 @@ def init_ascend_model_parallel(
         _SHARD_WEIGHT = create_shard_weight_group(None)
     else:
         # For standard tp, use global tp group_ranks
-        tp_group_ranks = vllm_all_ranks.view(-1, global_tp_size)
+        tp_group_ranks = all_ranks.view(-1, global_tp_size)
         _SHARD_WEIGHT = create_shard_weight_group(tp_group_ranks)
-    if get_ascend_config().multistream_overlap_gate:
-        global _FC3_QUANT_X
-        group_ranks = all_ranks.unbind(0)
-        group_ranks = [x.tolist() for x in group_ranks]
-        _FC3_QUANT_X = init_model_parallel_group(
-            group_ranks, get_world_group().local_rank, backend, group_name="fc3_quant_x"
-        )


 def model_parallel_initialized():
     return _MC2 is not None
```