Use device_group for all_gather when disabling overlap scheduling (#8001)
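When overlap scheduling is disabled (`disable_overlap_schedule`), the DP-attention metadata all-gather in the scheduler now runs on the TP group's `device_group`, with `local_info` and `global_info` allocated on the TP device; otherwise it keeps using the CPU `cpu_group` as before. Call sites accordingly pass the full `tp_group` object instead of only `tp_cpu_group`, and the new flag is threaded through `_maybe_prepare_mlp_sync_batch` and the `Scheduler` batch-preparation path.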
@@ -271,12 +271,13 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
         batch,
         dp_size=model_runner.server_args.dp_size,
         attn_tp_size=1,
-        tp_cpu_group=model_runner.tp_group.cpu_group,
+        tp_group=model_runner.tp_group,
         get_idle_batch=None,
         disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
         spec_algorithm=SpeculativeAlgorithm.NONE,
         speculative_num_draft_tokens=None,
         require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
+        disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
     )


@@ -1945,7 +1945,7 @@ class Scheduler(
             local_batch,
             dp_size=self.server_args.dp_size,
             attn_tp_size=self.attn_tp_size,
-            tp_cpu_group=self.tp_cpu_group,
+            tp_group=self.tp_group,
             get_idle_batch=self.get_idle_batch,
             disable_cuda_graph=self.server_args.disable_cuda_graph,
             spec_algorithm=self.spec_algorithm,
@@ -1954,6 +1954,7 @@ class Scheduler(
             enable_deepep_moe=self.server_args.enable_deepep_moe,
             deepep_mode=DeepEPMode[self.server_args.deepep_mode],
             require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
+            disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

     @staticmethod
@@ -1961,7 +1962,7 @@ class Scheduler(
         local_batch: ScheduleBatch,
         dp_size,
         attn_tp_size: int,
-        tp_cpu_group,
+        tp_group,
         get_idle_batch,
         disable_cuda_graph: bool,
         spec_algorithm,
@@ -1970,6 +1971,7 @@ class Scheduler(
         enable_deepep_moe: bool,
         deepep_mode: DeepEPMode,
         require_mlp_tp_gather: bool,
+        disable_overlap_schedule: bool,
     ):
         # Check if other DP workers have running batches
         if local_batch is None:
@@ -2000,6 +2002,12 @@ class Scheduler(
         )

         tbo_preparer = TboDPAttentionPreparer()
+        if disable_overlap_schedule:
+            group = tp_group.device_group
+            device = tp_group.device
+        else:
+            group = tp_group.cpu_group
+            device = "cpu"

         local_info = torch.tensor(
             [
@@ -2015,15 +2023,17 @@ class Scheduler(
                 ),
             ],
             dtype=torch.int64,
+            device=device,
         )
         global_info = torch.empty(
             (dp_size, attn_tp_size, 6),
             dtype=torch.int64,
+            device=device,
         )
         torch.distributed.all_gather_into_tensor(
             global_info.flatten(),
             local_info,
-            group=tp_cpu_group,
+            group=group,
         )
         global_num_tokens = global_info[:, 0, 0].tolist()
         can_cuda_graph = min(global_info[:, 0, 1].tolist())
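For context, here is a minimal single-rank sketch of the pattern this commit introduces: select the process group (and tensor device) based on `disable_overlap_schedule`, then run `torch.distributed.all_gather_into_tensor` on that group. The helper name `select_gather_group` and the single-process gloo setup are illustrative only, not SGLang APIs; only the `device_group`/`cpu_group` branch mirrors the patch itself.

```python
# Illustrative sketch of the group-selection pattern from this commit.
# `select_gather_group` is a hypothetical helper, not an SGLang API.
import os

import torch
import torch.distributed as dist


def select_gather_group(tp_group, disable_overlap_schedule: bool):
    # Mirrors the branch added in this commit: gather on the device group
    # (e.g. NCCL) when the overlap schedule is disabled, otherwise on the
    # CPU (gloo) group with CPU-resident tensors.
    if disable_overlap_schedule:
        return tp_group.device_group, tp_group.device
    return tp_group.cpu_group, "cpu"


if __name__ == "__main__":
    # Single-rank gloo setup so the example runs without a distributed launcher.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    group, device = dist.group.WORLD, "cpu"  # what the else-branch would pick
    dp_size, attn_tp_size = 1, 1

    # Each rank contributes a fixed-width int64 record, as in the patch.
    local_info = torch.tensor([17, 1, 0, 0, 0, 0], dtype=torch.int64, device=device)
    global_info = torch.empty((dp_size, attn_tp_size, 6), dtype=torch.int64, device=device)

    # The flattened output receives world_size * local_info.numel() elements.
    dist.all_gather_into_tensor(global_info.flatten(), local_info, group=group)

    print(global_info[:, 0, 0].tolist())  # -> [17]
    dist.destroy_process_group()
```

Per the commit title, the point of the branch is to keep the all-gather on the device group when overlap scheduling is off; the CPU gloo path remains the default otherwise.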