[BUG] fix local_rank in initialize_dp_attention (#7584)

This commit is contained in:
Sheng Qi
2025-06-28 11:01:01 +08:00
committed by GitHub
parent 2373faa317
commit cfe2edac38

View File

@@ -79,14 +79,12 @@ def initialize_dp_attention(
)
if enable_dp_attention:
local_rank = tp_rank % (tp_size // dp_size)
_ATTN_DP_SIZE = dp_size
if moe_dense_tp_size is None:
_LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
else:
_LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
else:
local_rank = tp_rank
_ATTN_DP_SIZE = 1
_LOCAL_ATTN_DP_SIZE = 1
@@ -96,7 +94,7 @@ def initialize_dp_attention(
list(range(head, head + _ATTN_TP_SIZE))
for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE)
],
local_rank,
tp_group.local_rank,
torch.distributed.get_backend(tp_group.device_group),
use_pynccl=SYNC_TOKEN_IDS_ACROSS_TP,
use_pymscclpp=False,