Fix two issues related to --moe-dense-tp-size=1 (#5657)

Co-authored-by: liusy58 <liusy58@linux.alibaba.com>
Co-authored-by: 颉沆 <xiehang.lsy@alibaba-inc.com>
Cheng Wan
2025-05-13 02:51:39 -04:00
committed by GitHub
parent 1ab14c4c5c
commit b2e95f62b4
6 changed files with 119 additions and 45 deletions
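
Reading the hunks below (this page renders only the Scheduler changes; the other changed files are not shown), the two fixes appear to be: (1) the scheduler's per-attention-group data-parallel rank is stored and used as attn_dp_rank instead of the ambiguous dp_rank, so the point-to-point offset arithmetic between pipeline stages indexes the intended ranks; and (2) when --moe-dense-tp-size=1 is combined with enable_dp_lm_head, each rank records only its own token counts rather than the gathered global lists.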


@@ -207,7 +207,8 @@ class Scheduler(
         self.page_size = server_args.page_size

         # Distributed rank info
-        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
                 self.tp_rank,
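
This hunk stops storing the attention-group rank under the ambiguous name dp_rank. For orientation, here is a minimal sketch of the rank arithmetic such a helper performs; the real compute_dp_attention_world_info in sglang may differ in detail:

    def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
        # Sketch only: without DP attention there is a single group.
        if not enable_dp_attention:
            return tp_rank, tp_size, 0
        attn_tp_size = tp_size // dp_size       # TP degree inside one attention-DP group
        attn_dp_rank = tp_rank // attn_tp_size  # which attention-DP group this rank is in
        attn_tp_rank = tp_rank % attn_tp_size   # rank within that group
        return attn_tp_rank, attn_tp_size, attn_dp_rank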
@@ -768,7 +769,7 @@ class Scheduler(
                 )

            # send out reqs to the next stage
-            dp_offset = self.dp_rank * self.attn_tp_size
+            dp_offset = self.attn_dp_rank * self.attn_tp_size
            if self.attn_tp_rank == 0:
                point_to_point_pyobj(
                    recv_reqs,
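
With hypothetical sizes (tp_size=8, dp_size=4, hence attn_tp_size=2), the corrected offset selects the first global rank of the sender's own attention-DP group:

    attn_tp_size = 2  # assumed: tp_size=8 split across dp_size=4 attention groups
    for attn_dp_rank in range(4):
        dp_offset = attn_dp_rank * attn_tp_size
        print(f"attention-DP group {attn_dp_rank} -> "
              f"global ranks {dp_offset}..{dp_offset + attn_tp_size - 1}")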
@@ -815,7 +816,7 @@ class Scheduler(
                recv_reqs = None
        else:
            if self.attn_tp_rank == 0:
-                dp_offset = self.dp_rank * self.attn_tp_size
+                dp_offset = self.attn_dp_rank * self.attn_tp_size
                recv_reqs = point_to_point_pyobj(
                    [],
                    self.pp_rank * self.tp_size + dp_offset,
@@ -1610,6 +1611,7 @@ class Scheduler(
                local_batch,
                dp_size=self.server_args.dp_size,
                attn_tp_size=self.attn_tp_size,
+                moe_dense_tp_size=self.server_args.moe_dense_tp_size,
                tp_cpu_group=self.tp_cpu_group,
                get_idle_batch=self.get_idle_batch,
                disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1622,6 +1624,7 @@ class Scheduler(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
+        moe_dense_tp_size: Optional[int],
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
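
These two hunks thread the new moe_dense_tp_size argument from the server args through to the batch-preparation helper. A launch line that would exercise this path might look as follows; apart from --moe-dense-tp-size, which appears in the commit title, the flag spellings here are assumptions:

    python -m sglang.launch_server --model-path <model> \
        --tp-size 8 --dp-size 4 --enable-dp-attention \
        --enable-dp-lm-head --moe-dense-tp-size 1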
@@ -1631,15 +1634,15 @@ class Scheduler(
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
-            global_num_tokens_for_logprob = 0
+            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                num_tokens = num_tokens * speculative_num_draft_tokens
-            global_num_tokens_for_logprob = num_tokens
+            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
-            global_num_tokens_for_logprob = sum(
+            num_tokens_for_logprob = sum(
                [
                    # We should have at least 1 token for sample in every case.
                    max(extend_len - logprob_start_len, 1)
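
The rename reflects that this quantity is local, not global. A toy extend batch makes it concrete; the numbers below are invented:

    extend_lens = [12, 5]          # tokens each request extends by (hypothetical)
    logprob_start_lens = [10, 5]   # where logprob computation starts (hypothetical)
    num_tokens_for_logprob = sum(
        max(extend_len - start_len, 1)  # at least 1 token is needed for sampling
        for extend_len, start_len in zip(extend_lens, logprob_start_lens)
    )
    assert num_tokens_for_logprob == 3  # 2 from request one, clamped 1 from request two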
@@ -1666,7 +1669,7 @@ class Scheduler(
            [
                num_tokens,
                can_cuda_graph,
-                global_num_tokens_for_logprob,
+                num_tokens_for_logprob,
                is_extend_in_batch,
            ],
            dtype=torch.int64,
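
The four local integers are then exchanged across ranks. A self-contained sketch of that gather, using the list form of all_gather that the CPU (gloo) backend supports; names mirror the diff, not necessarily the exact upstream code, and an initialized process group is assumed:

    import torch
    import torch.distributed as dist

    def gather_dp_info(local_info, world_size, group):
        # local_info: int64 tensor [num_tokens, can_cuda_graph,
        #                           num_tokens_for_logprob, is_extend_in_batch]
        gathered = [torch.empty_like(local_info) for _ in range(world_size)]
        dist.all_gather(gathered, local_info, group=group)
        return torch.stack(gathered)  # shape: (world_size, 4)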
@@ -1689,8 +1692,15 @@ class Scheduler(
            local_batch = get_idle_batch()

        if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens
-            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
+            # TODO: handle the case when moe_dense_tp_size != 1
+            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+                local_batch.global_num_tokens = [num_tokens]
+                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
+            else:
+                local_batch.global_num_tokens = global_num_tokens
+                local_batch.global_num_tokens_for_logprob = (
+                    global_num_tokens_for_logprob
+                )

            # Check forward mode for cuda graph
            if not disable_cuda_graph:
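
This branch is the core of the second fix. With a data-parallel LM head and moe_dense_tp_size == 1, the dense parts of the model only ever see the local rank's tokens, so, as far as the diff shows, a one-element list replaces the gathered one and buffers are no longer sized to the global maximum. Invented numbers to illustrate:

    # Counts gathered from four attention-DP ranks (hypothetical):
    gathered_num_tokens = [5, 7, 6, 8]
    # else-branch: every rank keeps the whole list (padding tends towards max = 8).
    # new branch on rank 1: only its own count survives.
    local_only = [gathered_num_tokens[1]]  # == [7]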
@@ -2177,8 +2187,8 @@ class Scheduler(
    def get_print_prefix(self):
        prefix = ""
-        if self.dp_rank is not None:
-            prefix += f" DP{self.dp_rank}"
+        if self.attn_dp_rank is not None:
+            prefix += f" DP{self.attn_dp_rank}"
        if self.server_args.tp_size > 1:
            prefix += f" TP{self.tp_rank}"
        if self.pp_size > 1:
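
Under the hypothetical layout used earlier, the resulting log prefix would read:

    tp_size, attn_dp_rank, tp_rank = 8, 1, 3  # hypothetical layout from earlier
    prefix = ""
    if attn_dp_rank is not None:
        prefix += f" DP{attn_dp_rank}"
    if tp_size > 1:
        prefix += f" TP{tp_rank}"
    print(repr(prefix))  # prints ' DP1 TP3'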