From fd7e15b76d4c655ceaa8f8c89c1bd0e4a81a561e Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Fri, 8 Aug 2025 21:34:17 -0700 Subject: [PATCH] Revert "[bug fix] Ensure local token and global token buffers are pointing to different storage " (#8993) --- python/sglang/srt/layers/dp_attention.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) mode change 100755 => 100644 python/sglang/srt/layers/dp_attention.py diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py old mode 100755 new mode 100644 index 2e53befc6..79397cce5 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -264,10 +264,9 @@ def _dp_gather_via_all_reduce( assert global_tokens.is_contiguous() if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0): - if local_tokens.untyped_storage() is global_tokens.untyped_storage(): - # dp_gather is an in-place operation and requires input and output tensors to not be aliased. - # so we create a separate buffer if they share the same storage. - global_tokens = torch.empty_like(global_tokens) + assert ( + local_tokens.untyped_storage() is not global_tokens.untyped_storage() + ), "aliasing between global_tokens and local_tokens not allowed" memcpy_triton( global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False @@ -348,10 +347,9 @@ def dp_scatter( assert local_tokens.is_contiguous() assert global_tokens.is_contiguous() if local_tokens.shape[0] > 0: - if local_tokens.untyped_storage() is global_tokens.untyped_storage(): - # dp_scatter is an in-place operation and requires input and output tensors to not be aliased. - # so we create a separate buffer if they share the same storage. - local_tokens = torch.empty_like(local_tokens) + assert ( + local_tokens.untyped_storage() is not global_tokens.untyped_storage() + ), "aliasing between local_tokens and global_tokens not allowed" memcpy_triton( local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True