DP Attention with Auto DeepEP Dispatch (#7222)

This commit is contained in:
Cheng Wan
2025-07-05 01:54:24 -07:00
committed by GitHub
parent 75354d9ae9
commit 8fc910db03
13 changed files with 136 additions and 90 deletions

View File

@@ -840,6 +840,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# For DP attention
global_num_tokens: Optional[List[int]] = None
global_num_tokens_for_logprob: Optional[List[int]] = None
+is_extend_in_batch: bool = False
can_run_dp_cuda_graph: bool = False
is_extend_in_batch: bool = False
tbo_split_seq_index: Optional[int] = None
@@ -1714,6 +1715,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
token_ids_logprobs=self.token_ids_logprobs,
global_num_tokens=self.global_num_tokens,
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
+is_extend_in_batch=self.is_extend_in_batch,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
tbo_split_seq_index=self.tbo_split_seq_index,
global_forward_mode=self.global_forward_mode,
@@ -1798,6 +1800,7 @@ class ModelWorkerBatch:
# For DP attention
global_num_tokens: Optional[List[int]]
global_num_tokens_for_logprob: Optional[List[int]]
+is_extend_in_batch: bool
can_run_dp_cuda_graph: bool
tbo_split_seq_index: Optional[int]
global_forward_mode: Optional[ForwardMode]

View File

@@ -1490,7 +1490,7 @@ class Scheduler(
if need_dp_attn_preparation and not self.spec_algorithm.is_none():
# In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
# We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
-new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
+new_batch = self.prepare_mlp_sync_batch(new_batch)
need_dp_attn_preparation = new_batch is None
if new_batch is not None:
@@ -1506,7 +1506,7 @@ class Scheduler(
# Handle DP attention
if need_dp_attn_preparation:
-ret, _ = self.prepare_mlp_sync_batch(ret)
+ret = self.prepare_mlp_sync_batch(ret)
return ret
@@ -1923,8 +1923,7 @@ class Scheduler(
if not disable_cuda_graph:
local_batch.can_run_dp_cuda_graph = can_cuda_graph
# TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
-return local_batch, any(is_extend_in_batch)
+return local_batch
def get_idle_batch(self):
idle_batch = ScheduleBatch.init_new(