[RL] use cpu group to prepare_mlp_sync_batch_raw when the server is offloaded (#10152)
This commit is contained in:
@@ -320,6 +320,7 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
|
|||||||
speculative_num_draft_tokens=None,
|
speculative_num_draft_tokens=None,
|
||||||
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
|
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
|
||||||
disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
|
disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
|
||||||
|
offload_tags=set(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2339,6 +2339,7 @@ class Scheduler(
|
|||||||
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
|
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
|
||||||
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
|
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
|
||||||
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
|
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
|
||||||
|
offload_tags=self.offload_tags,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -2353,6 +2354,7 @@ class Scheduler(
|
|||||||
speculative_num_draft_tokens,
|
speculative_num_draft_tokens,
|
||||||
require_mlp_tp_gather: bool,
|
require_mlp_tp_gather: bool,
|
||||||
disable_overlap_schedule: bool,
|
disable_overlap_schedule: bool,
|
||||||
|
offload_tags: set[str],
|
||||||
):
|
):
|
||||||
# Check if other DP workers have running batches
|
# Check if other DP workers have running batches
|
||||||
if local_batch is None:
|
if local_batch is None:
|
||||||
@@ -2383,7 +2385,7 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
|
|
||||||
tbo_preparer = TboDPAttentionPreparer()
|
tbo_preparer = TboDPAttentionPreparer()
|
||||||
if disable_overlap_schedule:
|
if len(offload_tags) == 0 and disable_overlap_schedule:
|
||||||
group = tp_group.device_group
|
group = tp_group.device_group
|
||||||
device = tp_group.device
|
device = tp_group.device
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user