[BUG] fix local_rank in initialize_dp_attention (#7584)
This commit is contained in:
@@ -79,14 +79,12 @@ def initialize_dp_attention(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if enable_dp_attention:
|
if enable_dp_attention:
|
||||||
local_rank = tp_rank % (tp_size // dp_size)
|
|
||||||
_ATTN_DP_SIZE = dp_size
|
_ATTN_DP_SIZE = dp_size
|
||||||
if moe_dense_tp_size is None:
|
if moe_dense_tp_size is None:
|
||||||
_LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
|
_LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
|
||||||
else:
|
else:
|
||||||
_LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
|
_LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
|
||||||
else:
|
else:
|
||||||
local_rank = tp_rank
|
|
||||||
_ATTN_DP_SIZE = 1
|
_ATTN_DP_SIZE = 1
|
||||||
_LOCAL_ATTN_DP_SIZE = 1
|
_LOCAL_ATTN_DP_SIZE = 1
|
||||||
|
|
||||||
@@ -96,7 +94,7 @@ def initialize_dp_attention(
|
|||||||
list(range(head, head + _ATTN_TP_SIZE))
|
list(range(head, head + _ATTN_TP_SIZE))
|
||||||
for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
|
for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
|
||||||
],
|
],
|
||||||
local_rank,
|
tp_group.local_rank,
|
||||||
torch.distributed.get_backend(tp_group.device_group),
|
torch.distributed.get_backend(tp_group.device_group),
|
||||||
use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
|
use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
|
||||||
use_pymscclpp=False,
|
use_pymscclpp=False,
|
||||||
|
|||||||
Reference in New Issue
Block a user