[BUG] fix local_rank in initialize_dp_attention (#7584)
@@ -79,14 +79,12 @@ def initialize_dp_attention(
     )
 
     if enable_dp_attention:
-        local_rank = tp_rank % (tp_size // dp_size)
         _ATTN_DP_SIZE = dp_size
         if moe_dense_tp_size is None:
             _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
         else:
             _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
-        local_rank = tp_rank
         _ATTN_DP_SIZE = 1
         _LOCAL_ATTN_DP_SIZE = 1
 
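For context on the `_LOCAL_ATTN_DP_SIZE` arithmetic kept in this hunk: the TP group splits into tp_size // moe_dense_tp_size dense slices, and each slice hosts dp_size divided by that count attention-DP groups. A minimal sketch of the arithmetic with hypothetical sizes (the numbers below are illustrative, not taken from this commit):

tp_size, dp_size, moe_dense_tp_size = 16, 4, 8  # hypothetical sizes

# Mirrors the branch kept in the hunk above.
if moe_dense_tp_size is None:
    local_attn_dp_size = dp_size
else:
    # 16 // 8 == 2 dense-TP slices, so each slice holds 4 // 2 == 2
    # of the attention-DP groups; max(1, ...) guards the degenerate case.
    local_attn_dp_size = max(1, dp_size // (tp_size // moe_dense_tp_size))

print(local_attn_dp_size)  # -> 2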
@@ -96,7 +94,7 @@ def initialize_dp_attention(
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
         ],
-        local_rank,
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
         use_pymscclpp=False,
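This one-line change is the actual fix. Assuming the vLLM-style GroupCoordinator that sglang adapts, its second positional argument is the node-local device index, which is exactly what tp_group.local_rank carries. The deleted tp_rank % (tp_size // dp_size) is instead the rank's position inside its attention-TP sub-group, which only coincides with the device index by luck. A minimal sketch of the mismatch, assuming a hypothetical single node with 8 GPUs, tp_size=8, dp_size=2:

tp_size, dp_size, gpus_per_node = 8, 2, 8  # hypothetical single-node layout

for tp_rank in range(tp_size):
    computed = tp_rank % (tp_size // dp_size)  # the value this commit removes
    node_local = tp_rank % gpus_per_node       # what tp_group.local_rank holds
    flag = "" if computed == node_local else "  <-- wrong device"
    print(f"tp_rank={tp_rank}: computed={computed}, node_local={node_local}{flag}")

# Ranks 4..7 would have bound their communicators to devices 0..3
# while actually running on devices 4..7.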