[PP] Add pipeline parallelism (#5724)
@@ -43,6 +43,7 @@ def initialize_dp_attention(
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    pp_size: int,
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE

@@ -53,17 +54,19 @@ def initialize_dp_attention(
     )

     if enable_dp_attention:
+        local_rank = tp_rank % (tp_size // dp_size)
         _DP_SIZE = dp_size
     else:
+        local_rank = tp_rank
         _DP_SIZE = 1

     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
         [
             list(range(head, head + _ATTN_TP_SIZE))
-            for head in range(0, tp_size, _ATTN_TP_SIZE)
+            for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
-        tp_group.local_rank,
+        local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
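To make the effect of the new `pp_size * tp_size` bound concrete, here is a minimal, self-contained sketch (plain Python, no torch) of the rank-group list that this diff hands to GroupCoordinator. The helper name `attn_tp_groups` and the concrete sizes are illustrative, not from the commit, and it assumes `_ATTN_TP_SIZE == tp_size // dp_size`, consistent with the `local_rank = tp_rank % (tp_size // dp_size)` computation shown above.

# Illustrative sketch only; the helper name and the example sizes are
# hypothetical and not part of this commit.

def attn_tp_groups(tp_size: int, dp_size: int, pp_size: int) -> list[list[int]]:
    attn_tp_size = tp_size // dp_size  # stand-in for _ATTN_TP_SIZE (assumption)
    # Before this commit the range stopped at tp_size, so only the ranks of
    # the first pipeline stage were grouped; multiplying by pp_size emits one
    # set of attention-TP groups per pipeline stage.
    return [
        list(range(head, head + attn_tp_size))
        for head in range(0, pp_size * tp_size, attn_tp_size)
    ]

# tp_size=4, dp_size=2, pp_size=2 covers 8 global ranks in groups of 2:
print(attn_tp_groups(4, 2, 2))  # [[0, 1], [2, 3], [4, 5], [6, 7]]

With pp_size=1 the range is identical to the old `range(0, tp_size, _ATTN_TP_SIZE)`, so single-stage deployments build exactly the same groups as before.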
||||