[PP] Fix init_memory_pool desync & add PP for mixtral (#6223)
This commit is contained in:
@@ -719,7 +719,7 @@ class Scheduler(
|
||||
server_is_idle = False
|
||||
result = self.run_batch(self.cur_batch)
|
||||
|
||||
# send the outputs to the next step
|
||||
# (last rank) send the outputs to the next step
|
||||
if self.pp_group.is_last_rank:
|
||||
if self.cur_batch:
|
||||
next_token_ids, bids[mb_id] = (
|
||||
@@ -759,18 +759,18 @@ class Scheduler(
|
||||
self.process_batch_result(mbs[next_mb_id], output_result)
|
||||
last_mbs[next_mb_id] = mbs[next_mb_id]
|
||||
|
||||
# carry the outputs to the next stage
|
||||
# (not last rank)
|
||||
if not self.pp_group.is_last_rank:
|
||||
if self.cur_batch:
|
||||
bids[mb_id] = result.bid
|
||||
# carry the outputs to the next stage
|
||||
# send the outputs from the last round to let the next stage worker run post processing
|
||||
if pp_outputs:
|
||||
# send the outputs from the last round to let the next stage worker run post processing
|
||||
self.pp_group.send_tensor_dict(
|
||||
pp_outputs.tensors,
|
||||
all_gather_group=self.attn_tp_group,
|
||||
)
|
||||
|
||||
if not self.pp_group.is_last_rank:
|
||||
# send out reqs to the next stage
|
||||
dp_offset = self.dp_rank * self.attn_tp_size
|
||||
if self.attn_tp_rank == 0:
|
||||
|
||||
Reference in New Issue
Block a user