Fix data parallel + tensor parallel (#4499)

Lianmin Zheng
2025-03-17 05:13:16 -07:00
committed by GitHub
parent f2ab37e500
commit 5493c3343e
6 changed files with 53 additions and 16 deletions

@@ -38,7 +38,12 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
     return attn_tp_rank, attn_tp_size, dp_rank


-def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+def initialize_dp_attention(
+    enable_dp_attention: bool,
+    tp_rank: int,
+    tp_size: int,
+    dp_size: int,
+):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE

     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
@@ -46,7 +51,13 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
-    _DP_SIZE = dp_size
+
+    if enable_dp_attention:
+        local_rank = tp_rank % (tp_size // dp_size)
+        _DP_SIZE = dp_size
+    else:
+        local_rank = tp_rank
+        _DP_SIZE = 1

     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
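The substantive fix is in the hunk above: the old code always handed the global tp_rank to GroupCoordinator as its local rank, which is only correct when DP attention is disabled and a single attention-TP group spans all TP ranks. With DP attention enabled, each rank instead needs its position within its own attention-TP subgroup. A minimal standalone sketch of the new arithmetic (the helper name and example values are assumptions, not part of the patch):

# Hypothetical sketch of the rank arithmetic the new branch introduces.
def split_tp_rank(tp_rank: int, tp_size: int, dp_size: int):
    attn_tp_size = tp_size // dp_size    # TP ranks per attention subgroup
    local_rank = tp_rank % attn_tp_size  # position inside the subgroup
    dp_rank = tp_rank // attn_tp_size    # which DP replica owns this rank
    return local_rank, dp_rank

# tp_size=8, dp_size=2 -> two subgroups of 4 ranks each;
# global TP rank 5 is local rank 1 of DP replica 1.
assert split_tp_rank(5, 8, 2) == (1, 1)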
@@ -54,7 +65,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, tp_size, _ATTN_TP_SIZE)
         ],
-        tp_rank,
+        local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
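For context, the unchanged group_ranks expression in this last hunk partitions the global TP ranks into contiguous attention-TP subgroups, and the corrected local_rank is a position within one of those subgroups rather than within the full TP group. A standalone illustration with assumed example sizes:

# Illustration only (sizes are assumptions): the group_ranks construction
# from the last hunk, for tp_size=8 with attn_tp_size = tp_size // dp_size = 4.
tp_size, attn_tp_size = 8, 4
group_ranks = [
    list(range(head, head + attn_tp_size))
    for head in range(0, tp_size, attn_tp_size)
]
print(group_ranks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]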