Support multi-node DP attention (#2925)
Co-authored-by: dhou-xai <dhou@x.ai>
This commit is contained in:
@@ -1003,6 +1003,11 @@ class ScheduleBatch:
|
||||
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.seq_lens_sum = 0
|
||||
self.extend_num_tokens = 0
|
||||
+ self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||
+ self,
|
||||
+ self.model_config.vocab_size,
|
||||
+ enable_overlap_schedule=self.enable_overlap,
|
||||
+ )
|
||||
|
||||
def prepare_for_decode(self):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
@@ -1117,7 +1122,7 @@ class ScheduleBatch:
|
||||
self.spec_info.merge_batch(other.spec_info)
|
||||
|
||||
def get_model_worker_batch(self):
|
||||
- if self.forward_mode.is_decode() or self.forward_mode.is_idle():
|
||||
+ if self.forward_mode.is_decode_or_idle():
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||
else:
|
||||
extend_seq_lens = self.extend_lens
|
||||
|
||||
Reference in New Issue
Block a user