from typing import Optional

import torch
from vllm.config import ParallelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (GroupCoordinator, get_dp_group,
                                              get_pp_group, get_tp_group,
                                              get_world_group,
                                              init_model_parallel_group)

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import (enable_sp, flashcomm2_enable,
                               flashcomm2_o_shared_enabled)
2025-07-28 14:06:20 +08:00
# Currently, mc2 op need their own group coordinator.
_MC2 : Optional [ GroupCoordinator ] = None
2025-12-12 14:41:20 +08:00
# Module specific tensor parallel groups
2025-08-21 09:22:07 +08:00
_MLP_TP : Optional [ GroupCoordinator ] = None
2025-09-07 10:31:32 +08:00
_OTP : Optional [ GroupCoordinator ] = None
2025-08-29 11:41:21 +08:00
_LMTP : Optional [ GroupCoordinator ] = None
2025-12-12 14:41:20 +08:00
_EMBED_TP : Optional [ GroupCoordinator ] = None
2025-12-14 09:34:13 +08:00
# flashcomm specific groups
2025-11-10 11:01:45 +08:00
_FLASHCOMM2_OTP : Optional [ GroupCoordinator ] = None
_FLASHCOMM2_ODP : Optional [ GroupCoordinator ] = None
2025-12-14 09:34:13 +08:00
_FC3_QUANT_X : Optional [ GroupCoordinator ] = None
2025-09-30 15:10:29 +08:00
2025-12-12 14:41:20 +08:00
# shared_weight across rank groups
_SHARED_WEIGHT : Optional [ GroupCoordinator ] = None
2025-09-30 15:10:29 +08:00
2025-12-12 14:41:20 +08:00
_P_TP : Optional [ GroupCoordinator ] = None
2025-07-28 14:06:20 +08:00
def init_ascend_model_parallel(parallel_config: ParallelConfig):
    if model_parallel_initialized():
        return
    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    backend = torch.distributed.get_backend(get_world_group().device_group)

    vllm_config = get_current_vllm_config()

    global_tp_size = parallel_config.tensor_parallel_size
    global_dp_size = parallel_config.data_parallel_size
    global_pp_size = parallel_config.pipeline_parallel_size

    # The layout of all ranks: ExternalDP * EP
    # ExternalDP is the data parallel group that is not part of the model;
    # every dp rank can generate independently (in verl integration).
    all_ranks = torch.arange(world_size).reshape(
        -1, global_dp_size * parallel_config.prefill_context_parallel_size *
        global_tp_size)
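    # Shape sketch (hypothetical sizes, not taken from the source): with
    # world_size=16, dp=2, prefill_context_parallel_size=1 and tp=8, the
    # reshape above yields a (1, 16) tensor [[0, 1, ..., 15]], i.e. each row
    # is one ExternalDP replica covering all of its dp * pcp * tp ranks.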

    pd_tp_ratio = get_ascend_config().pd_tp_ratio
    pd_head_ratio = get_ascend_config().pd_head_ratio

    global _P_TP
    assert _P_TP is None, (
        "distributed prefill tensor parallel group is already initialized")
    prefill_tensor_model_parallel_size = pd_tp_ratio
    # divide alltoall groups
    if pd_head_ratio > 1 and get_current_vllm_config(
    ).kv_transfer_config.is_kv_producer:
        num_head_replica = get_ascend_config().num_head_replica
        remote_tp_size = global_tp_size // pd_tp_ratio
        if num_head_replica <= 1:
            group_ranks = all_ranks.view(
                -1, prefill_tensor_model_parallel_size).unbind(0)
        else:
            group_ranks = all_ranks.clone().view(
                global_dp_size, -1,
                num_head_replica)  # [DP_size, num_head, num_head_replica]
            group_ranks = group_ranks.permute(0, 2, 1)
            group_ranks = group_ranks.reshape(
                -1,
                group_ranks.size(-1))  # [DP_size * num_head_replica, num_head]
            alltoall_group_size = group_ranks.size(-1) // remote_tp_size
            group_ranks = group_ranks.unsqueeze(-1).view(
                global_dp_size, num_head_replica, -1, alltoall_group_size
            )  # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
            group_ranks = group_ranks.reshape(-1,
                                              alltoall_group_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
        local_rank = get_world_group().local_rank
        num = next(
            (i for i, ranks in enumerate(group_ranks) if local_rank in ranks),
            None)
        _P_TP = init_model_parallel_group(group_ranks,
                                          get_world_group().local_rank,
                                          backend,
                                          group_name=f"p_tp_{num}")
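        # Shape walk-through (hypothetical sizes, not taken from the source):
        # with dp=1, tp=8, num_head_replica=2 and pd_tp_ratio=4 (so
        # remote_tp_size = 8 // 4 = 2):
        #   view            -> (1, 4, 2)  [DP_size, num_head, num_head_replica]
        #   permute/reshape -> (2, 4)     [[0, 2, 4, 6], [1, 3, 5, 7]]
        #   alltoall_group_size = 4 // 2 = 2
        #   final groups    -> [0, 2], [4, 6], [1, 3], [5, 7]
        # and each rank joins the "p_tp_{num}" group that contains it.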

    global _MC2
    group_ranks = all_ranks.unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _MC2 = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
                                     backend,
                                     group_name="mc2")

    # Initialize fine-grained TP process groups on Ascend for four components:
    # 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`)
    # 2. O Proj: attention output projection (`oproj_tensor_parallel_size`)
    # 3. Embedding: the token embedding table at the input of the model
    #    (`embedding_tensor_parallel_size`)
    # 4. MLP: feed-forward network in transformer blocks
    #    (`mlp_tensor_parallel_size`)
    _group_cache = {}

    def _create_or_get_group(group_size: int,
                             group_name: str) -> GroupCoordinator:
        if group_size is None:
            return None
        if group_size not in _group_cache:
            rank_grid = torch.arange(world_size).reshape(
                global_pp_size, global_dp_size, global_tp_size)
            num_chunks = global_dp_size // group_size
            group_ranks = []
            for pp_idx in range(global_pp_size):
                stage_ranks = rank_grid[pp_idx]  # (dp, tp)
                for chunk in range(num_chunks):
                    for tp_idx in range(global_tp_size):
                        group = stage_ranks[chunk * group_size:(chunk + 1) *
                                            group_size, tp_idx].tolist()
                        group_ranks.append(group)
            pg = init_model_parallel_group(group_ranks,
                                           get_world_group().local_rank,
                                           backend,
                                           group_name=group_name)
            _group_cache[group_size] = pg
        return _group_cache[group_size]
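    # Sketch of the grouping above (hypothetical sizes, not taken from the
    # source): with pp=1, dp=4, tp=2 and group_size=2, rank_grid[0] is
    #   [[0, 1],
    #    [2, 3],
    #    [4, 5],
    #    [6, 7]]   # rows = dp ranks, columns = tp ranks
    # and the helper slices dp chunks per tp column, producing the groups
    # [0, 2], [1, 3], [4, 6], [5, 7]: each group shares one tp index and spans
    # `group_size` consecutive dp ranks.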

    otp_size = get_ascend_config(
    ).finegrained_tp_config.oproj_tensor_parallel_size
    lmhead_tp_size = get_ascend_config(
    ).finegrained_tp_config.lmhead_tensor_parallel_size
    embedding_tp_size = get_ascend_config(
    ).finegrained_tp_config.embedding_tensor_parallel_size
    mlp_tp_size = get_ascend_config(
    ).finegrained_tp_config.mlp_tensor_parallel_size

    global _OTP, _LMTP, _EMBED_TP, _MLP_TP
    if otp_size > 0:
        _OTP = _create_or_get_group(otp_size, "otp")
    if lmhead_tp_size > 0:
        _LMTP = _create_or_get_group(lmhead_tp_size, "lmheadtp")
    if embedding_tp_size > 0:
        _EMBED_TP = _create_or_get_group(embedding_tp_size, "emtp")
    if mlp_tp_size > 0:
        _MLP_TP = _create_or_get_group(mlp_tp_size, "mlptp")

    def _create_shared_weight_group(group_name: str) -> GroupCoordinator:
        # This communication domain is used for asynchronous broadcasting, so
        # we create a new communication group to avoid interference.
        group_ranks = []
        for pp_idx in range(global_pp_size):
            group = []
            for dp_idx in range(global_dp_size):
                base = (dp_idx * global_pp_size + pp_idx) * global_tp_size
                for i in range(global_tp_size):
                    global_rank = base + i
                    group.append(global_rank)
            group_ranks.append(group)
        return init_model_parallel_group(group_ranks,
                                         get_world_group().local_rank,
                                         backend,
                                         group_name=group_name)
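    # Illustrative layout (hypothetical sizes, not taken from the source):
    # with pp=2, dp=2, tp=2 the helper above builds one group per pipeline
    # stage, each spanning every dp/tp rank of that stage:
    #   pp_idx 0 -> [0, 1, 4, 5]
    #   pp_idx 1 -> [2, 3, 6, 7]
    # where global_rank = (dp_idx * pp + pp_idx) * tp + tp_idx.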

    global _SHARED_WEIGHT
    # TODO: Check whether the model is DeepSeek V3.2 with SFA CP enabled and
    # shared weights activated; it will then be normalized within the PCP
    # parameters. -- clrs97
    is_ds_v32 = hasattr(vllm_config.model_config.hf_config, "index_topk")
    if enable_sp() and is_ds_v32 and _SHARED_WEIGHT is None:
        _SHARED_WEIGHT = _create_shared_weight_group("CP_shared_weight")

    # TODO: Extract and unify the logic across different communication groups.
    if flashcomm2_enable():
        flashcomm2_otp_size = get_ascend_config(
        ).flashcomm2_oproj_tensor_parallel_size
        global_tp_size = get_tp_group().world_size
        global_dp_size = get_dp_group().world_size
        global_pp_size = get_pp_group().world_size
        num_fc2_oproj_tensor_parallel_groups: int = (global_tp_size //
                                                     flashcomm2_otp_size)
        global _FLASHCOMM2_OTP
        global _FLASHCOMM2_ODP
        _FLASHCOMM2_OTP = None
        _FLASHCOMM2_ODP = get_tp_group()
        if flashcomm2_otp_size > 1:
            otp_group_ranks = []
            odp_group_ranks: list[list[int]] = [
                [] for _ in range(flashcomm2_otp_size * global_dp_size *
                                  global_pp_size)
            ]
            for dp_group_index in range(global_dp_size):
                for pp_group_index in range(global_pp_size):
                    dp_pp_serial_index = dp_group_index * global_pp_size + pp_group_index
                    tp_base_rank = dp_pp_serial_index * global_tp_size
                    odp_base_index = dp_pp_serial_index * flashcomm2_otp_size
                    for i in range(num_fc2_oproj_tensor_parallel_groups):
                        ranks = []
                        for j in range(flashcomm2_otp_size):
                            tp_local_rank = i + j * num_fc2_oproj_tensor_parallel_groups
                            assert tp_local_rank < global_tp_size
                            global_rank = tp_base_rank + tp_local_rank
                            ranks.append(global_rank)
                            odp_group_index = odp_base_index + j
                            odp_group_ranks[odp_group_index].append(
                                global_rank)
                        otp_group_ranks.append(ranks)
            _FLASHCOMM2_OTP = init_model_parallel_group(
                otp_group_ranks,
                get_world_group().local_rank,
                backend,
                group_name="flashcomm2_otp")
            _FLASHCOMM2_ODP = init_model_parallel_group(
                odp_group_ranks,
                get_world_group().local_rank,
                backend,
                group_name="flashcomm2_odp")
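            # Worked example (hypothetical sizes, not taken from the source):
            # with dp=1, pp=1, tp=4 and flashcomm2_otp_size=2 there are
            # 4 // 2 = 2 oproj TP groups; tp_local_rank = i + j * 2 gives
            #   otp groups: [0, 2] and [1, 3]
            #   odp groups: [0, 1] (j = 0) and [2, 3] (j = 1)
            # so every rank belongs to exactly one otp and one odp group.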

        # Create shared weight group for flashcomm2 oproj
        if flashcomm2_o_shared_enabled():
            assert flashcomm2_otp_size == 1, (
                "flashcomm2_o_shared is only supported when "
                "flashcomm2_otp_size is 1")
            if _SHARED_WEIGHT is None:
                _SHARED_WEIGHT = _create_shared_weight_group(
                    "flashcomm2_o_shared")

    if get_ascend_config().multistream_overlap_gate:
        global _FC3_QUANT_X
        group_ranks = all_ranks.unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
        _FC3_QUANT_X = init_model_parallel_group(group_ranks,
                                                 get_world_group().local_rank,
                                                 backend,
                                                 group_name="fc3_quant_x")


def model_parallel_initialized():
    return (_MC2 is not None)


def get_mc2_group() -> GroupCoordinator:
    assert _MC2 is not None, ("mc2 group is not initialized")
    return _MC2


def get_mlp_tp_group() -> GroupCoordinator:
    assert _MLP_TP is not None, ("mlp group is not initialized")
    return _MLP_TP


def get_otp_group() -> GroupCoordinator:
    assert _OTP is not None, (
        "output tensor parallel group is not initialized")
    return _OTP


def get_lmhead_tp_group() -> GroupCoordinator:
    assert _LMTP is not None, (
        "lm head tensor parallel group is not initialized")
    return _LMTP


def get_embed_tp_group() -> GroupCoordinator:
    assert _EMBED_TP is not None, ("emtp group is not initialized")
    return _EMBED_TP


def get_flashcomm2_otp_group() -> GroupCoordinator:
    return _FLASHCOMM2_OTP


def get_flashcomm2_odp_group() -> GroupCoordinator:
    assert _FLASHCOMM2_ODP is not None, (
        "output data parallel group for flashcomm2 is not initialized")
    return _FLASHCOMM2_ODP


def get_shared_weight_group() -> GroupCoordinator:
    assert _SHARED_WEIGHT is not None, (
        "shared weight group is not initialized")
    return _SHARED_WEIGHT


def get_p_tp_group() -> GroupCoordinator:
    assert _P_TP is not None, (
        "distributed prefill tensor parallel group is not initialized")
    return _P_TP


def get_fc3_quant_x_group() -> GroupCoordinator:
    assert _FC3_QUANT_X is not None, ("fc3 quant x group is not initialized")
    return _FC3_QUANT_X
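

# Usage sketch (illustrative, not taken from the source): once
# init_ascend_model_parallel() has run, callers typically fetch a coordinator
# and communicate over its process group, e.g.
#   mc2 = get_mc2_group()
#   torch.distributed.all_reduce(tensor, group=mc2.device_group)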


def destroy_ascend_model_parallel():
    global _MC2
    if _MC2:
        _MC2.destroy()
    _MC2 = None

    global _MLP_TP
    if _MLP_TP:
        _MLP_TP.destroy()
    _MLP_TP = None

    global _LMTP
    if _LMTP:
        _LMTP.destroy()
    _LMTP = None

    global _EMBED_TP
    if _EMBED_TP:
        _EMBED_TP.destroy()
    _EMBED_TP = None

    global _OTP
    if _OTP:
        _OTP.destroy()
    _OTP = None

    global _P_TP
    if _P_TP:
        _P_TP.destroy()
    _P_TP = None

    global _FLASHCOMM2_OTP
    if _FLASHCOMM2_OTP and get_ascend_config(
    ).flashcomm2_oproj_tensor_parallel_size != 1:
        _FLASHCOMM2_OTP.destroy()
    _FLASHCOMM2_OTP = None

    global _FLASHCOMM2_ODP
    if _FLASHCOMM2_ODP and get_ascend_config(
    ).flashcomm2_oproj_tensor_parallel_size != 1:
        _FLASHCOMM2_ODP.destroy()
    _FLASHCOMM2_ODP = None

    global _SHARED_WEIGHT
    if _SHARED_WEIGHT:
        _SHARED_WEIGHT.destroy()
    _SHARED_WEIGHT = None

    global _FC3_QUANT_X
    if _FC3_QUANT_X:
        _FC3_QUANT_X.destroy()
    _FC3_QUANT_X = None