from typing import Optional

import torch
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
                                             get_pp_group, get_tp_group,
                                             get_world_group,
                                             init_model_parallel_group)

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
                               flashcomm2_o_shared_enabled)

# Currently, mc2 ops need their own group coordinator.
_MC2: Optional[GroupCoordinator] = None

# Module specific tensor parallel groups
_MLP_TP: Optional[GroupCoordinator] = None
_OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
_EMBED_TP: Optional[GroupCoordinator] = None

# flashcomm specific groups
_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
_FC3_QUANT_X: Optional[GroupCoordinator] = None

# shared_weight across rank groups
_SHARED_WEIGHT: Optional[GroupCoordinator] = None

_P_TP: Optional[GroupCoordinator] = None


def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
    if model_parallel_initialized():
        return

    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    backend = torch.distributed.get_backend(get_world_group().device_group)
    vllm_config = get_current_vllm_config()
    global_tp_size = parallel_config.tensor_parallel_size
    global_dp_size = parallel_config.data_parallel_size
    global_pp_size = parallel_config.pipeline_parallel_size

    # The layout of all ranks: ExternalDP * EP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    all_ranks = torch.arange(world_size).reshape(
        -1, global_dp_size * parallel_config.prefill_context_parallel_size *
        global_tp_size)

    pd_tp_ratio = get_ascend_config().pd_tp_ratio
    pd_head_ratio = get_ascend_config().pd_head_ratio

    global _P_TP
    assert _P_TP is None, (
        "distributed prefill tensor parallel group is already initialized")
    prefill_tensor_model_parallel_size = pd_tp_ratio
    # divide alltoall groups
    if pd_head_ratio > 1 and get_current_vllm_config(
    ).kv_transfer_config.is_kv_producer:
        num_head_replica = get_ascend_config().num_head_replica
        remote_tp_size = global_tp_size // pd_tp_ratio
        if num_head_replica <= 1:
            group_ranks = all_ranks.view(
                -1, prefill_tensor_model_parallel_size).unbind(0)
        else:
            group_ranks = all_ranks.clone().view(
                global_dp_size, -1,
                num_head_replica)  # [DP_size, num_head, num_head_replica]
            group_ranks = group_ranks.permute(0, 2, 1)
            group_ranks = group_ranks.reshape(
                -1,
                group_ranks.size(-1))  # [DP_size * num_head_replica, num_head]
            alltoall_group_size = group_ranks.size(-1) // remote_tp_size
            group_ranks = group_ranks.unsqueeze(-1).view(
                global_dp_size, num_head_replica, -1, alltoall_group_size
            )  # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
            group_ranks = group_ranks.reshape(-1,
                                              alltoall_group_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
        local_rank = get_world_group().local_rank
        num = next(
            (i for i, ranks in enumerate(group_ranks) if local_rank in ranks),
            None)
        _P_TP = init_model_parallel_group(group_ranks,
                                          get_world_group().local_rank,
                                          backend,
                                          group_name=f"p_tp_{num}")

    global _MC2
    group_ranks = all_ranks.unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]

    _MC2 = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="mc2")
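    # Illustrative example (hypothetical sizes, for readability only): with
    # world_size=8, data_parallel_size=2, prefill_context_parallel_size=1 and
    # tensor_parallel_size=2, `all_ranks` is reshaped to
    #   [[0, 1, 2, 3],
    #    [4, 5, 6, 7]]
    # so the mc2 groups created above are [0, 1, 2, 3] and [4, 5, 6, 7], i.e.
    # one group per external-DP replica covering all of its DP * PCP * TP ranks.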
    # Initialize specialized tensor parallel (TP) process groups for
    # fine-grained model parallelism on Ascend hardware. This enables
    # independent TP configurations for four components:
    # 1. LM Head:
    #    The final linear layer that maps hidden states to vocabulary logits.
    #    Controlled by `lmhead_tensor_parallel_size`.
    # 2. o_proj:
    #    The output projection in attention blocks (e.g., in Multi-Head Attention).
    #    Controlled by `oproj_tensor_parallel_size`.
    # 3. Embedding:
    #    The token embedding table at the input and/or output of the model.
    #    Controlled by `embedding_tensor_parallel_size`.
    # 4. MLP:
    #    The feed-forward network layers within transformer blocks.
    #    Controlled by `mlp_tensor_parallel_size`.
    _group_cache = {}

    def _create_or_get_group(group_size: int,
                             group_name: str) -> GroupCoordinator:
        if group_size is None:
            return None
        if group_size not in _group_cache:
            rank_grid = torch.arange(world_size).reshape(
                global_pp_size, global_dp_size, global_tp_size)
            num_chunks = global_dp_size // group_size
            group_ranks = []
            for pp_idx in range(global_pp_size):
                stage_ranks = rank_grid[pp_idx]  # (dp, tp)
                for chunk in range(num_chunks):
                    for tp_idx in range(global_tp_size):
                        group = stage_ranks[chunk * group_size:(chunk + 1) *
                                            group_size, tp_idx].tolist()
                        group_ranks.append(group)
            pg = init_model_parallel_group(group_ranks,
                                           get_world_group().local_rank,
                                           backend,
                                           group_name=group_name)
            _group_cache[group_size] = pg
        return _group_cache[group_size]

    otp_size = get_ascend_config(
    ).finegrained_tp_config.oproj_tensor_parallel_size
    lmhead_tp_size = get_ascend_config(
    ).finegrained_tp_config.lmhead_tensor_parallel_size
    embedding_tp_size = get_ascend_config(
    ).finegrained_tp_config.embedding_tensor_parallel_size
    mlp_tp_size = get_ascend_config(
    ).finegrained_tp_config.mlp_tensor_parallel_size

    global _MLP_TP, _OTP, _LMTP, _EMBED_TP
    if otp_size > 0:
        _OTP = _create_or_get_group(otp_size, "otp")
    if lmhead_tp_size > 0:
        _LMTP = _create_or_get_group(lmhead_tp_size, "lmheadtp")
    if embedding_tp_size > 0:
        _EMBED_TP = _create_or_get_group(embedding_tp_size, "emtp")
    if mlp_tp_size > 0:
        _MLP_TP = _create_or_get_group(mlp_tp_size, "mlptp")

    def _create_shared_weight_group(group_name: str) -> GroupCoordinator:
        # This communication domain is used for asynchronous broadcasting, so
        # we create a new communication group to avoid interference.
        group_ranks = []
        for pp_idx in range(global_pp_size):
            group = []
            for dp_idx in range(global_dp_size):
                base = (dp_idx * global_pp_size + pp_idx) * global_tp_size
                for i in range(global_tp_size):
                    global_rank = base + i
                    group.append(global_rank)
            group_ranks.append(group)
        return init_model_parallel_group(group_ranks,
                                         get_world_group().local_rank,
                                         backend,
                                         group_name=group_name)

    global _SHARED_WEIGHT
    # TODO: Check whether the model is DeepSeek V3.2 with SFA CP enabled and
    # shared weights activated; it will then be normalized within the PCP
    # parameters. -- clrs97
    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
    if enable_sp() and is_ds_v32 and _SHARED_WEIGHT is None:
        _SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")

    # TODO: Extract and unify the logic across the different communication
    # groups.
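    # Illustrative example (hypothetical sizes, for readability only): with
    # tensor_parallel_size=4 and flashcomm2_oproj_tensor_parallel_size=2 on a
    # single DP/PP replica, the block below builds the o_proj TP groups
    # [0, 2] and [1, 3] (strided across the TP group) and the complementary
    # o_proj DP groups [0, 1] and [2, 3].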
    if flashcomm2_enable():
        flashcomm2_otp_size = get_ascend_config(
        ).flashcomm2_oproj_tensor_parallel_size
        global_tp_size = get_tp_group().world_size
        global_dp_size = get_dp_group().world_size
        global_pp_size = get_pp_group().world_size
        num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
                                                     flashcomm2_otp_size)

        global _FLASHCOMM2_OTP
        global _FLASHCOMM2_ODP
        _FLASHCOMM2_OTP = None
        _FLASHCOMM2_ODP = get_tp_group()

        if flashcomm2_otp_size > 1:
            otp_group_ranks = []
            odp_group_ranks: list[list[int]] = [
                [] for _ in range(flashcomm2_otp_size * global_dp_size *
                                  global_pp_size)
            ]

            for dp_group_index in range(global_dp_size):
                for pp_group_index in range(global_pp_size):
                    dp_pp_serial_index = (dp_group_index * global_pp_size +
                                          pp_group_index)
                    tp_base_rank = dp_pp_serial_index * global_tp_size
                    odp_base_index = dp_pp_serial_index * flashcomm2_otp_size

                    for i in range(num_fc2_oproj_tensor_parallel_groups):
                        ranks = []
                        for j in range(flashcomm2_otp_size):
                            tp_local_rank = (
                                i + j * num_fc2_oproj_tensor_parallel_groups)
                            assert tp_local_rank < global_tp_size
                            global_rank = tp_base_rank + tp_local_rank
                            ranks.append(global_rank)

                            odp_group_index = odp_base_index + j
                            odp_group_ranks[odp_group_index].append(
                                global_rank)
                        otp_group_ranks.append(ranks)

            _FLASHCOMM2_OTP = init_model_parallel_group(
                otp_group_ranks,
                get_world_group().local_rank,
                backend,
                group_name="flashcomm2_otp")
            _FLASHCOMM2_ODP = init_model_parallel_group(
                odp_group_ranks,
                get_world_group().local_rank,
                backend,
                group_name="flashcomm2_odp")

        # Create shared weight group for flashcomm2 oproj
        if flashcomm2_o_shared_enabled():
            assert flashcomm2_otp_size == 1, (
                "flashcomm2_o_shared is only supported when "
                "flashcomm2_otp_size is 1")
            if _SHARED_WEIGHT is None:
                _SHARED_WEIGHT = _create_shared_weight_group(
                    "flashcomm2_o_shared")

    if get_ascend_config().multistream_overlap_gate:
        global _FC3_QUANT_X
        group_ranks = all_ranks.unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
        _FC3_QUANT_X = init_model_parallel_group(group_ranks,
                                                 get_world_group().local_rank,
                                                 backend,
                                                 group_name="fc3_quant_x")


def model_parallel_initialized():
    return (_MC2 is not None)


def get_mc2_group() -> GroupCoordinator:
    assert _MC2 is not None, ("mc2 group is not initialized")
    return _MC2


def get_mlp_tp_group() -> GroupCoordinator:
    assert _MLP_TP is not None, ("mlp group is not initialized")
    return _MLP_TP


def get_otp_group() -> GroupCoordinator:
    assert _OTP is not None, (
        "output tensor parallel group is not initialized")
    return _OTP


def get_lmhead_tp_group() -> GroupCoordinator:
    assert _LMTP is not None, (
        "lm head tensor parallel group is not initialized")
    return _LMTP


def get_embed_tp_group() -> GroupCoordinator:
    assert _EMBED_TP is not None, ("emtp group is not initialized")
    return _EMBED_TP


def get_flashcomm2_otp_group() -> GroupCoordinator:
    return _FLASHCOMM2_OTP


def get_flashcomm2_odp_group() -> GroupCoordinator:
    assert _FLASHCOMM2_ODP is not None, (
        "output data parallel group for flashcomm2 is not initialized")
    return _FLASHCOMM2_ODP


def get_shared_weight_group() -> GroupCoordinator:
    assert _SHARED_WEIGHT is not None, (
        "output shared weight parallel group for flashcomm2 is not initialized"
    )
    return _SHARED_WEIGHT


def get_p_tp_group() -> GroupCoordinator:
    assert _P_TP is not None, (
        "distributed prefill tensor parallel group is not initialized")
    return _P_TP


def get_fc3_quant_x_group() -> GroupCoordinator:
    assert _FC3_QUANT_X is not None, ("fc3 quant x group is not initialized")
    return _FC3_QUANT_X
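# Typical lifecycle (sketch, assuming the standard vLLM worker flow): the
# default distributed groups (world/TP/DP/PP) are initialized first, then
# init_ascend_model_parallel(parallel_config) is called once per process;
# afterwards, layers fetch their coordinators via get_mc2_group(),
# get_otp_group(), get_lmhead_tp_group(), etc., and
# destroy_ascend_model_parallel() is called on shutdown.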
def destroy_ascend_model_parallel():
    global _MC2
    if _MC2:
        _MC2.destroy()
    _MC2 = None

    global _MLP_TP
    if _MLP_TP:
        _MLP_TP.destroy()
    _MLP_TP = None

    global _LMTP
    if _LMTP:
        _LMTP.destroy()
    _LMTP = None

    global _EMBED_TP
    if _EMBED_TP:
        _EMBED_TP.destroy()
    _EMBED_TP = None

    global _OTP
    if _OTP:
        _OTP.destroy()
    _OTP = None

    global _P_TP
    if _P_TP:
        _P_TP.destroy()
    _P_TP = None

    global _FLASHCOMM2_OTP
    if _FLASHCOMM2_OTP and get_ascend_config(
    ).flashcomm2_oproj_tensor_parallel_size != 1:
        _FLASHCOMM2_OTP.destroy()
    _FLASHCOMM2_OTP = None

    global _FLASHCOMM2_ODP
    if _FLASHCOMM2_ODP and get_ascend_config(
    ).flashcomm2_oproj_tensor_parallel_size != 1:
        _FLASHCOMM2_ODP.destroy()
    _FLASHCOMM2_ODP = None

    global _SHARED_WEIGHT
    if _SHARED_WEIGHT:
        _SHARED_WEIGHT.destroy()
    _SHARED_WEIGHT = None

    global _FC3_QUANT_X
    if _FC3_QUANT_X:
        _FC3_QUANT_X.destroy()
    _FC3_QUANT_X = None