[PP] Fix init_memory_pool desync & add PP for mixtral (#6223)

This commit is contained in:
Ying Sheng
2025-05-12 12:38:09 -07:00
committed by GitHub
parent 12319a6787
commit bad7c26fdc
8 changed files with 179 additions and 47 deletions

View File

@@ -719,7 +719,7 @@ class Scheduler(
server_is_idle = False
result = self.run_batch(self.cur_batch)
# send the outputs to the next step
# (last rank) send the outputs to the next step
if self.pp_group.is_last_rank:
if self.cur_batch:
next_token_ids, bids[mb_id] = (
@@ -759,18 +759,18 @@ class Scheduler(
self.process_batch_result(mbs[next_mb_id], output_result)
last_mbs[next_mb_id] = mbs[next_mb_id]
# carry the outputs to the next stage
# (not last rank)
if not self.pp_group.is_last_rank:
if self.cur_batch:
bids[mb_id] = result.bid
# carry the outputs to the next stage
# send the outputs from the last round to let the next stage worker run post processing
if pp_outputs:
# send the outputs from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)
if not self.pp_group.is_last_rank:
# send out reqs to the next stage
dp_offset = self.dp_rank * self.attn_tp_size
if self.attn_tp_rank == 0: