[DP] fix: engine crash when decode batch is padded (#8995)
This commit is contained in:
@@ -408,9 +408,9 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
):
|
||||
if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
|
||||
residual, local_residual = (
|
||||
forward_batch.gathered_buffer[
|
||||
: forward_batch.input_ids.shape[0]
|
||||
].clone(),
|
||||
torch.empty_like(
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
|
||||
),
|
||||
residual,
|
||||
)
|
||||
attn_tp_all_gather_into_tensor(residual, local_residual)
|
||||
@@ -420,13 +420,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
|
||||
# Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
|
||||
use_layer_norm_before_gather = context.attn_tp_size == 1
|
||||
if use_layer_norm_before_gather:
|
||||
residual.copy_(hidden_states)
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states = layernorm(hidden_states)
|
||||
|
||||
if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
|
||||
residual = hidden_states
|
||||
hidden_states = layernorm(hidden_states)
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer,
|
||||
torch.empty_like(forward_batch.gathered_buffer),
|
||||
hidden_states,
|
||||
)
|
||||
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
||||
@@ -552,10 +550,6 @@ class CommunicateSummableTensorPairFn:
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
|
||||
if hidden_states.data_ptr() is global_hidden_states.data_ptr():
|
||||
hidden_states = torch.empty_like(hidden_states)
|
||||
|
||||
if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
|
||||
# When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
|
||||
dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
|
||||
|
||||
@@ -653,12 +653,30 @@ class ForwardBatch:
|
||||
else:
|
||||
num_tokens = global_num_tokens[0]
|
||||
|
||||
if self.forward_mode.is_decode():
|
||||
setattr(self, "raw_bs", self.batch_size)
|
||||
self.batch_size = num_tokens
|
||||
|
||||
bs = self.batch_size
|
||||
|
||||
if self.forward_mode.is_decode():
|
||||
if self.is_extend_in_batch and dp_padding_mode.is_max_len():
|
||||
setattr(self, "_original_forward_mode", self.forward_mode)
|
||||
self.forward_mode = ForwardMode.EXTEND
|
||||
self.extend_num_tokens = bs
|
||||
self.extend_seq_lens = torch.full_like(self.seq_lens, 1)
|
||||
self.extend_prefix_lens = self.seq_lens - 1
|
||||
self.extend_start_loc = torch.arange(
|
||||
bs, dtype=torch.int32, device=self.seq_lens.device
|
||||
)
|
||||
self.extend_prefix_lens_cpu = self.extend_prefix_lens.cpu()
|
||||
self.extend_seq_lens_cpu = self.extend_seq_lens.cpu()
|
||||
self.extend_logprob_start_lens_cpu = self.extend_prefix_lens_cpu
|
||||
else:
|
||||
setattr(self, "_original_batch_size", self.batch_size)
|
||||
if self.spec_info is not None:
|
||||
bs = self.batch_size = (
|
||||
num_tokens // self.spec_info.num_tokens_per_batch
|
||||
)
|
||||
else:
|
||||
bs = self.batch_size = num_tokens
|
||||
|
||||
# padding
|
||||
self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
|
||||
self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
|
||||
@@ -689,6 +707,7 @@ class ForwardBatch:
|
||||
if self.mrope_positions is not None:
|
||||
self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
|
||||
|
||||
# TODO: check if we need to pad other tensors
|
||||
if self.extend_seq_lens is not None:
|
||||
self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
|
||||
|
||||
@@ -712,7 +731,9 @@ class ForwardBatch:
|
||||
|
||||
def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
|
||||
|
||||
bs = getattr(self, "raw_bs", self.batch_size)
|
||||
self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
|
||||
self.batch_size = getattr(self, "_original_batch_size", self.batch_size)
|
||||
bs = self.batch_size
|
||||
|
||||
if self.spec_info is not None:
|
||||
if self.forward_mode.is_decode(): # draft
|
||||
|
||||
Reference in New Issue
Block a user