Return more infos for computing average acceptance length (#3152)
This commit is contained in:
@@ -22,6 +22,8 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
|
||||
def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
|
||||
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
|
||||
|
||||
from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
|
||||
|
||||
_ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
|
||||
enable_dp_attention, tp_rank, tp_size, dp_size
|
||||
)
|
||||
@@ -35,7 +37,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
|
||||
],
|
||||
tp_rank,
|
||||
torch.distributed.get_backend(tp_group.device_group),
|
||||
False,
|
||||
SYNC_TOKEN_IDS_ACROSS_TP,
|
||||
False,
|
||||
False,
|
||||
False,
|
||||
|
||||
Reference in New Issue
Block a user