diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index 95bf1514c..44c2ff132 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -408,9 +408,9 @@ class CommunicateWithAllReduceAndLayerNormFn:
     ):
         if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
             residual, local_residual = (
-                forward_batch.gathered_buffer[
-                    : forward_batch.input_ids.shape[0]
-                ].clone(),
+                torch.empty_like(
+                    forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]]
+                ),
                 residual,
             )
             attn_tp_all_gather_into_tensor(residual, local_residual)
@@ -420,13 +420,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
 
         # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
         use_layer_norm_before_gather = context.attn_tp_size == 1
-        if use_layer_norm_before_gather:
-            residual.copy_(hidden_states)
-            if hidden_states.shape[0] != 0:
-                hidden_states = layernorm(hidden_states)
-
+        if use_layer_norm_before_gather and hidden_states.shape[0] != 0:
+            residual = hidden_states
+            hidden_states = layernorm(hidden_states)
         hidden_states, local_hidden_states = (
-            forward_batch.gathered_buffer,
+            torch.empty_like(forward_batch.gathered_buffer),
             hidden_states,
         )
         dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
@@ -552,10 +550,6 @@ class CommunicateSummableTensorPairFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
-
-        if hidden_states.data_ptr() is global_hidden_states.data_ptr():
-            hidden_states = torch.empty_like(hidden_states)
-
         if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
             # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
             dp_reduce_scatter_tensor(hidden_states, global_hidden_states)
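# ---------------------------------------------------------------------------
# A minimal runnable sketch (not part of the patch): the hunks above stop
# gathering into views of the shared forward_batch.gathered_buffer and
# instead allocate private outputs with torch.empty_like, which is why the
# defensive data_ptr()/empty_like check in the last hunk could be deleted.
# A toy single-rank illustration of the aliasing hazard, with the collective
# gather simulated by a plain copy_:
import torch

gathered_buffer = torch.zeros(8, 4)      # shared scratch buffer, reused across layers
local_hidden_states = torch.randn(8, 4)  # this rank's contribution

# Old pattern: hidden_states aliases the shared buffer, so any later writer
# of gathered_buffer silently corrupts it.
hidden_states = gathered_buffer
hidden_states.copy_(local_hidden_states)  # stands in for dp_gather_partial(...)
assert hidden_states.data_ptr() == gathered_buffer.data_ptr()

# New pattern: gather into a freshly allocated tensor of the same shape;
# no aliasing, and no clone() of the buffer slice is needed.
hidden_states = torch.empty_like(gathered_buffer)
hidden_states.copy_(local_hidden_states)
assert hidden_states.data_ptr() != gathered_buffer.data_ptr()
# ---------------------------------------------------------------------------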
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index e5793a269..c019d7e3f 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -653,12 +653,30 @@ class ForwardBatch:
         else:
             num_tokens = global_num_tokens[0]
 
-        if self.forward_mode.is_decode():
-            setattr(self, "raw_bs", self.batch_size)
-            self.batch_size = num_tokens
-
         bs = self.batch_size
 
+        if self.forward_mode.is_decode():
+            if self.is_extend_in_batch and dp_padding_mode.is_max_len():
+                setattr(self, "_original_forward_mode", self.forward_mode)
+                self.forward_mode = ForwardMode.EXTEND
+                self.extend_num_tokens = bs
+                self.extend_seq_lens = torch.full_like(self.seq_lens, 1)
+                self.extend_prefix_lens = self.seq_lens - 1
+                self.extend_start_loc = torch.arange(
+                    bs, dtype=torch.int32, device=self.seq_lens.device
+                )
+                self.extend_prefix_lens_cpu = self.extend_prefix_lens.cpu()
+                self.extend_seq_lens_cpu = self.extend_seq_lens.cpu()
+                self.extend_logprob_start_lens_cpu = self.extend_prefix_lens_cpu
+            else:
+                setattr(self, "_original_batch_size", self.batch_size)
+                if self.spec_info is not None:
+                    bs = self.batch_size = (
+                        num_tokens // self.spec_info.num_tokens_per_batch
+                    )
+                else:
+                    bs = self.batch_size = num_tokens
+
         # padding
         self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
         self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
@@ -689,6 +707,7 @@ class ForwardBatch:
 
         if self.mrope_positions is not None:
             self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
 
+        # TODO: check if we need to pad other tensors
         if self.extend_seq_lens is not None:
             self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
@@ -712,7 +731,9 @@ class ForwardBatch:
 
     def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):
-        bs = getattr(self, "raw_bs", self.batch_size)
+        self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
+        self.batch_size = getattr(self, "_original_batch_size", self.batch_size)
+        bs = self.batch_size
 
         if self.spec_info is not None:
             if self.forward_mode.is_decode():
                 # draft
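# ---------------------------------------------------------------------------
# A minimal runnable sketch (not part of the patch): when max-len DP padding
# is active and the batch mixes extend requests, the first hunk above
# re-describes each decode request as an extend of length 1: every request
# adds exactly one token, its prefix is everything already in the KV cache
# (seq_lens - 1), and the start locations are simply 0..bs-1. Toy values
# below; tensor names mirror the diff:
import torch

seq_lens = torch.tensor([5, 9, 3], dtype=torch.int32)  # total length per request
bs = seq_lens.shape[0]

extend_seq_lens = torch.full_like(seq_lens, 1)          # one new token each
extend_prefix_lens = seq_lens - 1                       # cached prefix per request
extend_start_loc = torch.arange(bs, dtype=torch.int32)  # 1 token each => 0, 1, 2

# The re-described batch must still account for every token.
assert (extend_prefix_lens + extend_seq_lens == seq_lens).all()
print(extend_seq_lens.tolist())     # [1, 1, 1]
print(extend_prefix_lens.tolist())  # [4, 8, 2]
print(extend_start_loc.tolist())    # [0, 1, 2]
# ---------------------------------------------------------------------------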