[DP] fix: engine crash when decode batch is padded (#8995)

This commit is contained in:
Cheng Wan
2025-08-09 01:29:29 -07:00
committed by GitHub
parent 326a901df4
commit 5018809222
2 changed files with 33 additions and 18 deletions

View File

@@ -653,12 +653,30 @@ class ForwardBatch:
else:
num_tokens = global_num_tokens[0]
if self.forward_mode.is_decode():
setattr(self, "raw_bs", self.batch_size)
self.batch_size = num_tokens
bs = self.batch_size
if self.forward_mode.is_decode():
if self.is_extend_in_batch and dp_padding_mode.is_max_len():
setattr(self, "_original_forward_mode", self.forward_mode)
self.forward_mode = ForwardMode.EXTEND
self.extend_num_tokens = bs
self.extend_seq_lens = torch.full_like(self.seq_lens, 1)
self.extend_prefix_lens = self.seq_lens - 1
self.extend_start_loc = torch.arange(
bs, dtype=torch.int32, device=self.seq_lens.device
)
self.extend_prefix_lens_cpu = self.extend_prefix_lens.cpu()
self.extend_seq_lens_cpu = self.extend_seq_lens.cpu()
self.extend_logprob_start_lens_cpu = self.extend_prefix_lens_cpu
else:
setattr(self, "_original_batch_size", self.batch_size)
if self.spec_info is not None:
bs = self.batch_size = (
num_tokens // self.spec_info.num_tokens_per_batch
)
else:
bs = self.batch_size = num_tokens
# padding
self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
@@ -689,6 +707,7 @@ class ForwardBatch:
if self.mrope_positions is not None:
self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
# TODO: check if we need to pad other tensors
if self.extend_seq_lens is not None:
self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
@@ -712,7 +731,9 @@ class ForwardBatch:
def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
bs = getattr(self, "raw_bs", self.batch_size)
self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
self.batch_size = getattr(self, "_original_batch_size", self.batch_size)
bs = self.batch_size
if self.spec_info is not None:
if self.forward_mode.is_decode(): # draft