[PP] Add pipeline parallelism (#5724)

This commit is contained in:
Ying Sheng
2025-04-30 18:18:07 -07:00
committed by GitHub
parent e97e57e699
commit 11383cec3c
25 changed files with 1150 additions and 308 deletions

View File

@@ -43,6 +43,7 @@ def initialize_dp_attention(
tp_rank: int,
tp_size: int,
dp_size: int,
pp_size: int,
):
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
@@ -53,17 +54,19 @@ def initialize_dp_attention(
)
if enable_dp_attention:
local_rank = tp_rank % (tp_size // dp_size)
_DP_SIZE = dp_size
else:
local_rank = tp_rank
_DP_SIZE = 1
tp_group = get_tp_group()
_ATTN_TP_GROUP = GroupCoordinator(
[
list(range(head, head + _ATTN_TP_SIZE))
for head in range(0, tp_size, _ATTN_TP_SIZE)
for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
],
tp_group.local_rank,
local_rank,
torch.distributed.get_backend(tp_group.device_group),
SYNC_TOKEN_IDS_ACROSS_TP,
False,