[DP] fix: engine crash when decode batch is padded (#8995)
This commit is contained in:
@@ -653,12 +653,30 @@ class ForwardBatch:
|
||||
else:
|
||||
num_tokens = global_num_tokens[0]
|
||||
|
||||
if self.forward_mode.is_decode():
|
||||
setattr(self, "raw_bs", self.batch_size)
|
||||
self.batch_size = num_tokens
|
||||
|
||||
bs = self.batch_size
|
||||
|
||||
if self.forward_mode.is_decode():
|
||||
if self.is_extend_in_batch and dp_padding_mode.is_max_len():
|
||||
setattr(self, "_original_forward_mode", self.forward_mode)
|
||||
self.forward_mode = ForwardMode.EXTEND
|
||||
self.extend_num_tokens = bs
|
||||
self.extend_seq_lens = torch.full_like(self.seq_lens, 1)
|
||||
self.extend_prefix_lens = self.seq_lens - 1
|
||||
self.extend_start_loc = torch.arange(
|
||||
bs, dtype=torch.int32, device=self.seq_lens.device
|
||||
)
|
||||
self.extend_prefix_lens_cpu = self.extend_prefix_lens.cpu()
|
||||
self.extend_seq_lens_cpu = self.extend_seq_lens.cpu()
|
||||
self.extend_logprob_start_lens_cpu = self.extend_prefix_lens_cpu
|
||||
else:
|
||||
setattr(self, "_original_batch_size", self.batch_size)
|
||||
if self.spec_info is not None:
|
||||
bs = self.batch_size = (
|
||||
num_tokens // self.spec_info.num_tokens_per_batch
|
||||
)
|
||||
else:
|
||||
bs = self.batch_size = num_tokens
|
||||
|
||||
# padding
|
||||
self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
|
||||
self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
|
||||
@@ -689,6 +707,7 @@ class ForwardBatch:
|
||||
if self.mrope_positions is not None:
|
||||
self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
|
||||
|
||||
# TODO: check if we need to pad other tensors
|
||||
if self.extend_seq_lens is not None:
|
||||
self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
|
||||
|
||||
@@ -712,7 +731,9 @@ class ForwardBatch:
|
||||
|
||||
def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
|
||||
|
||||
bs = getattr(self, "raw_bs", self.batch_size)
|
||||
self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
|
||||
self.batch_size = getattr(self, "_original_batch_size", self.batch_size)
|
||||
bs = self.batch_size
|
||||
|
||||
if self.spec_info is not None:
|
||||
if self.forward_mode.is_decode(): # draft
|
||||
|
||||
Reference in New Issue
Block a user