feat: mtp support dp-attention (#6081)
Co-authored-by: austindeng <austindeng@tencent.com> Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com> Co-authored-by: Qiaolin Yu <liin1211@outlook.com> Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
@@ -1350,6 +1350,29 @@ class Scheduler(
|
||||
self.metrics_collector.log_stats(self.stats)
|
||||
self._publish_kv_events()
|
||||
|
||||
def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]):
    """Coordinate the DP attention batch.

    All-gathers a single has-new-batch flag from every rank and reports
    whether any DP worker produced a prefill batch this step, so ranks can
    stay in lockstep (used when DP attention runs with speculative
    decoding — see the caller's `enable_dp_attention` / `spec_algorithm`
    guard).

    Returns:
        True if at least one DP worker has a forward (prefill) batch.
    """
    # One int64 flag per rank: 1 if this rank scheduled a new batch.
    has_batch = torch.tensor([new_batch is not None], dtype=torch.int64)

    # Destination buffer shaped (dp, attn_tp, 1); the gather fills it
    # through a flat view, one element per participating rank.
    gathered = torch.empty(
        (self.server_args.dp_size, self.attn_tp_size, 1),
        dtype=torch.int64,
    )
    torch.distributed.all_gather_into_tensor(
        gathered.view(-1),
        has_batch,
        group=self.tp_cpu_group,
    )

    # Inspect one flag per DP group (attn-TP ranks within a group agree);
    # presumably index 0 is that group's representative — any truthy flag
    # means some DP worker has a forward batch.
    flags = gathered[:, 0, 0].tolist()
    return any(flags)
|
||||
|
||||
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
|
||||
# Merge the prefill batch into the running batch
|
||||
chunked_req_to_exclude = set()
|
||||
@@ -1383,7 +1406,14 @@ class Scheduler(
|
||||
self.running_batch.merge_batch(self.last_batch)
|
||||
|
||||
new_batch = self.get_new_batch_prefill()
|
||||
if new_batch is not None:
|
||||
|
||||
# TODO(ch-wan): minor refactor is needed here to improve readability
|
||||
any_new_batch = (
|
||||
self.server_args.enable_dp_attention
|
||||
and not self.spec_algorithm.is_none()
|
||||
and self.coordinate_spec_dp_attn_batch(new_batch)
|
||||
)
|
||||
if new_batch is not None or any_new_batch:
|
||||
# Run prefill first if possible
|
||||
ret = new_batch
|
||||
else:
|
||||
@@ -1732,8 +1762,6 @@ class Scheduler(
|
||||
num_tokens_for_logprob = 0
|
||||
elif local_batch.forward_mode.is_decode():
|
||||
num_tokens = local_batch.batch_size()
|
||||
if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
|
||||
num_tokens = num_tokens * speculative_num_draft_tokens
|
||||
num_tokens_for_logprob = num_tokens
|
||||
else:
|
||||
num_tokens = local_batch.extend_num_tokens
|
||||
@@ -1809,6 +1837,7 @@ class Scheduler(
|
||||
local_batch.global_num_tokens_for_logprob = (
|
||||
global_num_tokens_for_logprob
|
||||
)
|
||||
local_batch.is_extend_in_batch = any(is_extend_in_batch)
|
||||
local_batch.tbo_split_seq_index = tbo_split_seq_index
|
||||
local_batch.global_forward_mode = global_forward_mode
|
||||
|
||||
@@ -1816,6 +1845,7 @@ class Scheduler(
|
||||
if not disable_cuda_graph:
|
||||
local_batch.can_run_dp_cuda_graph = can_cuda_graph
|
||||
|
||||
# TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
|
||||
return local_batch, any(is_extend_in_batch)
|
||||
|
||||
def get_idle_batch(self):
|
||||
|
||||
Reference in New Issue
Block a user