diff --git a/python/sglang/srt/layers/communicator.py b/python/sglang/srt/layers/communicator.py
index 4ef752d75..a497db464 100644
--- a/python/sglang/srt/layers/communicator.py
+++ b/python/sglang/srt/layers/communicator.py
@@ -553,6 +553,13 @@ class CommunicateSummableTensorPairFn:
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
             hidden_states,
         )
+
+        # If both tensors start at the same address, give hidden_states its own
+        # buffer before the scatter below (presumably to avoid aliasing; confirm).
+        # data_ptr() returns a Python int: compare with ==, `is` tests identity.
+        if hidden_states.data_ptr() == global_hidden_states.data_ptr():
+            hidden_states = torch.empty_like(hidden_states)
+
         if allow_reduce_scatter and forward_batch.dp_padding_mode.is_max_len():
             # When using padding, all_reduce is skipped after MLP and MOE and reduce scatter is used here instead.
             dp_reduce_scatter_tensor(hidden_states, global_hidden_states)