DP Attention with Auto DeepEP Dispatch (#7222)
This commit is contained in:
@@ -840,6 +840,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# For DP attention
|
||||
global_num_tokens: Optional[List[int]] = None
|
||||
global_num_tokens_for_logprob: Optional[List[int]] = None
|
||||
is_extend_in_batch: bool = False
|
||||
can_run_dp_cuda_graph: bool = False
|
||||
is_extend_in_batch: bool = False
|
||||
tbo_split_seq_index: Optional[int] = None
|
||||
@@ -1714,6 +1715,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
token_ids_logprobs=self.token_ids_logprobs,
|
||||
global_num_tokens=self.global_num_tokens,
|
||||
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
|
||||
is_extend_in_batch=self.is_extend_in_batch,
|
||||
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
|
||||
tbo_split_seq_index=self.tbo_split_seq_index,
|
||||
global_forward_mode=self.global_forward_mode,
|
||||
@@ -1798,6 +1800,7 @@ class ModelWorkerBatch:
|
||||
# For DP attention
|
||||
global_num_tokens: Optional[List[int]]
|
||||
global_num_tokens_for_logprob: Optional[List[int]]
|
||||
is_extend_in_batch: bool
|
||||
can_run_dp_cuda_graph: bool
|
||||
tbo_split_seq_index: Optional[int]
|
||||
global_forward_mode: Optional[ForwardMode]
|
||||
|
||||
@@ -1490,7 +1490,7 @@ class Scheduler(
|
||||
if need_dp_attn_preparation and not self.spec_algorithm.is_none():
|
||||
# In speculative decoding, prefill batches and decode batches cannot be processed in the same DP attention group.
|
||||
# We prepare idle batches in advance to skip preparing decode batches when there are prefill batches in the group.
|
||||
new_batch, _ = self.prepare_mlp_sync_batch(new_batch)
|
||||
new_batch = self.prepare_mlp_sync_batch(new_batch)
|
||||
need_dp_attn_preparation = new_batch is None
|
||||
|
||||
if new_batch is not None:
|
||||
@@ -1506,7 +1506,7 @@ class Scheduler(
|
||||
|
||||
# Handle DP attention
|
||||
if need_dp_attn_preparation:
|
||||
ret, _ = self.prepare_mlp_sync_batch(ret)
|
||||
ret = self.prepare_mlp_sync_batch(ret)
|
||||
|
||||
return ret
|
||||
|
||||
@@ -1923,8 +1923,7 @@ class Scheduler(
|
||||
if not disable_cuda_graph:
|
||||
local_batch.can_run_dp_cuda_graph = can_cuda_graph
|
||||
|
||||
# TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
|
||||
return local_batch, any(is_extend_in_batch)
|
||||
return local_batch
|
||||
|
||||
def get_idle_batch(self):
|
||||
idle_batch = ScheduleBatch.init_new(
|
||||
|
||||
Reference in New Issue
Block a user