Return more information for computing average acceptance length (#3152)

This commit is contained in:
Lianmin Zheng
2025-01-26 04:51:54 -08:00
committed by GitHub
parent 7e0976133c
commit 1dda8c5e4c
10 changed files with 97 additions and 15 deletions

View File

@@ -22,6 +22,8 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
_ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
enable_dp_attention, tp_rank, tp_size, dp_size
)
@@ -35,7 +37,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
],
tp_rank,
torch.distributed.get_backend(tp_group.device_group),
-False,
+SYNC_TOKEN_IDS_ACROSS_TP,
False,
False,
False,