Fix two issues related to --moe-dense-tp-size=1 (#5657)

Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
Co-authored-by: 颉沆 <xiehang.lsy@alibaba-inc.com>
This commit is contained in:
Cheng Wan
2025-05-13 02:51:39 -04:00
committed by GitHub
parent 1ab14c4c5c
commit b2e95f62b4
6 changed files with 119 additions and 45 deletions

View File

@@ -24,8 +24,10 @@ if TYPE_CHECKING:
_ATTN_TP_GROUP = None
_ATTN_TP_RANK = None
_ATTN_TP_SIZE = None
_DP_RANK = None
_DP_SIZE = None
_ATTN_DP_RANK = None
_ATTN_DP_SIZE = None
_LOCAL_ATTN_DP_SIZE = None
_LOCAL_ATTN_DP_RANK = None
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
@@ -33,9 +35,27 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
return tp_rank, tp_size, 0
attn_tp_size = tp_size // dp_size
dp_rank = tp_rank // attn_tp_size
attn_dp_rank = tp_rank // attn_tp_size
attn_tp_rank = tp_rank % attn_tp_size
return attn_tp_rank, attn_tp_size, dp_rank
return attn_tp_rank, attn_tp_size, attn_dp_rank
def compute_dp_attention_local_info(
    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
):
    """Compute this rank's attention-TP rank/size and DP rank inside the
    *local* TP world.

    The local world is the dense-MLP TP group when ``moe_dense_tp_size`` is
    set (e.g. ``--moe-dense-tp-size=1``); otherwise it coincides with the
    full TP world. When DP attention is disabled, the global rank/size pass
    through unchanged and the DP rank is 0.

    Returns:
        (local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank)
    """
    if not enable_dp_attention:
        return tp_rank, tp_size, 0

    # Size of the local TP world and this rank's position within it.
    group_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
    rank_in_group = tp_rank % group_size

    # How many DP replicas fit inside one local group; at least one.
    replicas_in_group = max(1, dp_size // (tp_size // group_size))
    tp_per_replica = group_size // replicas_in_group

    return (
        rank_in_group % tp_per_replica,
        tp_per_replica,
        rank_in_group // tp_per_replica,
    )
def initialize_dp_attention(
@@ -43,22 +63,32 @@ def initialize_dp_attention(
tp_rank: int,
tp_size: int,
dp_size: int,
moe_dense_tp_size: int,
pp_size: int,
):
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
_ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
_ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
enable_dp_attention, tp_rank, tp_size, dp_size
)
_, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
)
if enable_dp_attention:
local_rank = tp_rank % (tp_size // dp_size)
_DP_SIZE = dp_size
_ATTN_DP_SIZE = dp_size
if moe_dense_tp_size is None:
_LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
else:
_LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
else:
local_rank = tp_rank
_DP_SIZE = 1
_ATTN_DP_SIZE = 1
_LOCAL_ATTN_DP_SIZE = 1
tp_group = get_tp_group()
_ATTN_TP_GROUP = GroupCoordinator(
@@ -93,13 +123,33 @@ def get_attention_tp_size():
def get_attention_dp_rank():
assert _DP_RANK is not None, "dp attention not initialized!"
return _DP_RANK
assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
return _ATTN_DP_RANK
def get_attention_dp_size():
assert _DP_SIZE is not None, "dp attention not initialized!"
return _DP_SIZE
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
return _ATTN_DP_SIZE
def get_local_attention_dp_rank():
    """Return this rank's DP rank within the local (dense-TP) attention group.

    Raises:
        AssertionError: if DP attention has not been initialized yet
            (``_LOCAL_ATTN_DP_RANK`` is still ``None``).
    """
    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
    return _LOCAL_ATTN_DP_RANK
def get_local_attention_dp_size():
    """Return the DP world size of the local (dense-TP) attention group.

    Raises:
        AssertionError: if DP attention has not been initialized yet
            (``_LOCAL_ATTN_DP_SIZE`` is still ``None``).
    """
    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
    return _LOCAL_ATTN_DP_SIZE
# NOTE(review): this definition appears twice in this chunk — presumably a
# diff-rendering artifact (the commit touches 6 files); confirm against the
# actual source tree and keep only one copy.
def get_local_attention_dp_rank():
    """Return this rank's DP rank within the local (dense-TP) attention group."""
    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
    return _LOCAL_ATTN_DP_RANK
# NOTE(review): this definition appears twice in this chunk — presumably a
# diff-rendering artifact (the commit touches 6 files); confirm against the
# actual source tree and keep only one copy.
def get_local_attention_dp_size():
    """Return the DP world size of the local (dense-TP) attention group."""
    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
    return _LOCAL_ATTN_DP_SIZE
@contextmanager
@@ -112,19 +162,19 @@ def disable_dp_size():
Args:
tp_group (GroupCoordinator): the tp group coordinator
"""
global _DP_SIZE
assert _DP_SIZE is not None, "dp attention not initialized!"
global _ATTN_DP_SIZE
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
old_dp_size = _DP_SIZE
_DP_SIZE = 1
old_dp_size = _ATTN_DP_SIZE
_ATTN_DP_SIZE = 1
try:
yield
finally:
_DP_SIZE = old_dp_size
_ATTN_DP_SIZE = old_dp_size
def get_dp_local_info(forward_batch: ForwardBatch):
dp_rank = get_attention_dp_rank()
dp_rank = get_local_attention_dp_rank()
if forward_batch.dp_local_start_pos is None:
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)