Optimize DP attn scheduling for speculative decoding (#7285)
This commit is contained in:
@@ -1399,29 +1399,6 @@ class Scheduler(
|
||||
self.metrics_collector.log_stats(self.stats)
|
||||
self._publish_kv_events()
|
||||
|
||||
def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]):
|
||||
"""Coordinate the DP attention batch."""
|
||||
|
||||
local_info = torch.tensor(
|
||||
[
|
||||
(new_batch is not None),
|
||||
],
|
||||
dtype=torch.int64,
|
||||
)
|
||||
global_info = torch.empty(
|
||||
(self.server_args.dp_size, self.attn_tp_size, 1),
|
||||
dtype=torch.int64,
|
||||
)
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
global_info.flatten(),
|
||||
local_info,
|
||||
group=self.tp_cpu_group,
|
||||
)
|
||||
any_new_batch = any(
|
||||
global_info[:, 0, 0].tolist()
|
||||
) # Any DP worker has forward batch
|
||||
return any_new_batch
|
||||
|
||||
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
|
||||
# Merge the prefill batch into the running batch
|
||||
chunked_req_to_exclude = set()
|
||||
@@ -1456,13 +1433,15 @@ class Scheduler(
|
||||
|
||||
new_batch = self.get_new_batch_prefill()
|
||||
|
||||
# TODO(ch-wan): minor refactor is needed here to improve readability
|
||||
any_new_batch = (
|
||||
self.server_args.enable_dp_attention
|
||||
and not self.spec_algorithm.is_none()
|
||||
and self.coordinate_spec_dp_attn_batch(new_batch)
|
||||
)
|
||||
if new_batch is not None or any_new_batch:
|
||||
need_dp_attn_preparation = require_mlp_sync(self.server_args)
|
||||
|
||||
if need_dp_attn_preparation and not self.spec_algorithm.is_none():
|
||||
# In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
|
||||
# We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
|
||||
new_batch, _ = self.prepare_dp_attn_batch(new_batch)
|
||||
need_dp_attn_preparation = new_batch is None
|
||||
|
||||
if new_batch is not None:
|
||||
# Run prefill first if possible
|
||||
ret = new_batch
|
||||
else:
|
||||
@@ -1473,8 +1452,9 @@ class Scheduler(
|
||||
else:
|
||||
ret = None
|
||||
|
||||
if require_mlp_sync(self.server_args):
|
||||
ret, _ = self.prepare_mlp_sync_batch(ret)
|
||||
# Handle DP attention
|
||||
if need_dp_attn_preparation:
|
||||
ret, _ = self.prepare_dp_attn_batch(ret)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
Reference in New Issue
Block a user