Revert "[bug fix] Ensure local token and global token buffers are pointing to different storage " (#8993)
This commit is contained in:
14
python/sglang/srt/layers/dp_attention.py
Executable file → Normal file
14
python/sglang/srt/layers/dp_attention.py
Executable file → Normal file
@@ -264,10 +264,9 @@ def _dp_gather_via_all_reduce(
|
|||||||
assert global_tokens.is_contiguous()
|
assert global_tokens.is_contiguous()
|
||||||
|
|
||||||
if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
|
if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
|
||||||
if local_tokens.untyped_storage() is global_tokens.untyped_storage():
|
assert (
|
||||||
# dp_gather is an in-place operation and requires input and output tensors to not be aliased.
|
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
|
||||||
# so we create a separate buffer if they share the same storage.
|
), "aliasing between global_tokens and local_tokens not allowed"
|
||||||
global_tokens = torch.empty_like(global_tokens)
|
|
||||||
|
|
||||||
memcpy_triton(
|
memcpy_triton(
|
||||||
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
|
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
|
||||||
@@ -348,10 +347,9 @@ def dp_scatter(
|
|||||||
assert local_tokens.is_contiguous()
|
assert local_tokens.is_contiguous()
|
||||||
assert global_tokens.is_contiguous()
|
assert global_tokens.is_contiguous()
|
||||||
if local_tokens.shape[0] > 0:
|
if local_tokens.shape[0] > 0:
|
||||||
if local_tokens.untyped_storage() is global_tokens.untyped_storage():
|
assert (
|
||||||
# dp_scatter is an in-place operation and requires input and output tensors to not be aliased.
|
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
|
||||||
# so we create a separate buffer if they share the same storage.
|
), "aliasing between local_tokens and global_tokens not allowed"
|
||||||
local_tokens = torch.empty_like(local_tokens)
|
|
||||||
|
|
||||||
memcpy_triton(
|
memcpy_triton(
|
||||||
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
|
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
|
||||||
|
|||||||
Reference in New Issue
Block a user