Support multi-node DP attention (#2925)

Co-authored-by: dhou-xai <dhou@x.ai>
This commit is contained in:
Lianmin Zheng
2025-01-16 11:15:00 -08:00
committed by GitHub
parent 58f3f2b840
commit 8b6ce52e92
16 changed files with 287 additions and 137 deletions

View File

@@ -1003,6 +1003,11 @@ class ScheduleBatch:
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
self.seq_lens_sum = 0
self.extend_num_tokens = 0
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
self,
self.model_config.vocab_size,
enable_overlap_schedule=self.enable_overlap,
)
def prepare_for_decode(self):
self.forward_mode = ForwardMode.DECODE
@@ -1117,7 +1122,7 @@ class ScheduleBatch:
self.spec_info.merge_batch(other.spec_info)
def get_model_worker_batch(self):
if self.forward_mode.is_decode() or self.forward_mode.is_idle():
if self.forward_mode.is_decode_or_idle():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:
extend_seq_lens = self.extend_lens