Optimize DP attn scheduling for speculative decoding (#7285)
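
In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group. Previously, the scheduler ran a dedicated blocking all-gather (coordinate_spec_dp_attn_batch) on every scheduling step just to learn whether any DP worker had a new prefill batch. This change removes that helper: the scheduler now prepares the DP attention batch (including idle batches) right after prefill batch construction, and only prepares a decode-side batch afterwards when that first preparation produced nothing, dropping the extra collective from the scheduling hot path.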
@@ -1399,29 +1399,6 @@ class Scheduler(
             self.metrics_collector.log_stats(self.stats)
         self._publish_kv_events()
 
-    def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]):
-        """Coordinate the DP attention batch."""
-
-        local_info = torch.tensor(
-            [
-                (new_batch is not None),
-            ],
-            dtype=torch.int64,
-        )
-        global_info = torch.empty(
-            (self.server_args.dp_size, self.attn_tp_size, 1),
-            dtype=torch.int64,
-        )
-        torch.distributed.all_gather_into_tensor(
-            global_info.flatten(),
-            local_info,
-            group=self.tp_cpu_group,
-        )
-        any_new_batch = any(
-            global_info[:, 0, 0].tolist()
-        )  # Any DP worker has forward batch
-        return any_new_batch
-
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         chunked_req_to_exclude = set()
@@ -1456,13 +1433,15 @@ class Scheduler(
 
         new_batch = self.get_new_batch_prefill()
 
-        # TODO(ch-wan): minor refactor is needed here to improve readability
-        any_new_batch = (
-            self.server_args.enable_dp_attention
-            and not self.spec_algorithm.is_none()
-            and self.coordinate_spec_dp_attn_batch(new_batch)
-        )
-        if new_batch is not None or any_new_batch:
+        need_dp_attn_preparation = require_mlp_sync(self.server_args)
+
+        if need_dp_attn_preparation and not self.spec_algorithm.is_none():
+            # In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
+            # We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
+            new_batch, _ = self.prepare_dp_attn_batch(new_batch)
+            need_dp_attn_preparation = new_batch is None
+
+        if new_batch is not None:
             # Run prefill first if possible
             ret = new_batch
         else:
@@ -1473,8 +1452,9 @@ class Scheduler(
             else:
                 ret = None
 
-        if require_mlp_sync(self.server_args):
-            ret, _ = self.prepare_mlp_sync_batch(ret)
+        # Handle DP attention
+        if need_dp_attn_preparation:
+            ret, _ = self.prepare_dp_attn_batch(ret)
 
         return ret
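
For context, below is a minimal standalone sketch (not part of this commit) of the coordination pattern the removed helper implemented: every rank all-gathers a has-a-new-batch flag over a CPU process group, so each scheduling step pays a blocking collective before any batch preparation happens. The function name any_worker_has_new_batch is illustrative, and a single flat process group stands in for the scheduler's (dp_size, attn_tp_size) layout over tp_cpu_group.

```python
# Illustrative sketch of the removed coordination step, simplified to a
# flat process group. Launch with: torchrun --nproc-per-node=2 sketch.py
import torch
import torch.distributed as dist


def any_worker_has_new_batch(has_new_batch: bool, group=None) -> bool:
    """Return True if any rank in `group` scheduled a new forward batch."""
    world_size = dist.get_world_size(group)
    local_info = torch.tensor([int(has_new_batch)], dtype=torch.int64)
    global_info = torch.empty(world_size, dtype=torch.int64)
    # Blocking collective: every rank waits here on every scheduling step.
    dist.all_gather_into_tensor(global_info, local_info, group=group)
    return any(global_info.tolist())


if __name__ == "__main__":
    dist.init_process_group("gloo")  # CPU backend, like the scheduler's tp_cpu_group
    rank = dist.get_rank()
    # Pretend only rank 0 produced a prefill batch this step.
    print(rank, any_worker_has_new_batch(has_new_batch=(rank == 0)))
    dist.destroy_process_group()
```

After this change, that flag exchange is folded into the DP attention batch preparation itself: judging by the comments in the new code, prepare_dp_attn_batch already coordinates the group to build matching (possibly idle) batches, so a separate all-gather just to detect prefill work is redundant.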