Reduce computation and communication in DP attention (#4521)

This commit is contained in:
Cheng Wan
2025-03-18 16:41:36 -04:00
committed by GitHub
parent 9e0186f352
commit 3196999f63
5 changed files with 70 additions and 80 deletions

View File

@@ -53,10 +53,8 @@ def initialize_dp_attention(
)
if enable_dp_attention:
local_rank = tp_rank % (tp_size // dp_size)
_DP_SIZE = dp_size
else:
local_rank = tp_rank
_DP_SIZE = 1
tp_group = get_tp_group()
@@ -65,7 +63,7 @@ def initialize_dp_attention(
list(range(head, head + _ATTN_TP_SIZE))
for head in range(0, tp_size, _ATTN_TP_SIZE)
],
local_rank,
tp_group.local_rank,
torch.distributed.get_backend(tp_group.device_group),
SYNC_TOKEN_IDS_ACROSS_TP,
False,
@@ -180,20 +178,19 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
def dp_gather(
def _dp_gather(
global_tokens: torch.Tensor,
local_tokens: torch.Tensor,
forward_batch: ForwardBatch,
layer_id: Union[str, int],
is_partial: bool,
):
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
global_tokens.fill_(0)
assert local_tokens.is_contiguous()
assert global_tokens.is_contiguous()
if local_tokens.shape[0] > 0 and (
layer_id != "embedding" or get_attention_tp_rank() == 0
):
if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
assert (
global_tokens.untyped_storage().data_ptr()
!= local_tokens.untyped_storage().data_ptr()
@@ -216,6 +213,22 @@ def dp_gather(
global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
def dp_gather_partial(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
):
    """Gather ``local_tokens`` into ``global_tokens`` in "partial" mode.

    Thin wrapper that forwards to ``_dp_gather`` with ``is_partial=True``.
    With that flag the attention-TP-rank check in ``_dp_gather`` is
    short-circuited, so every rank whose ``local_tokens`` is non-empty
    takes the copy branch instead of only attention-TP rank 0.

    NOTE(review): presumably intended for inputs that are still partial
    results across the attention-TP group -- confirm against callers.

    Args:
        global_tokens: Destination buffer, written in place.
        local_tokens: This rank's tokens to contribute.
        forward_batch: Batch metadata handed through to ``_dp_gather``.
    """
    _dp_gather(
        global_tokens=global_tokens,
        local_tokens=local_tokens,
        forward_batch=forward_batch,
        is_partial=True,
    )
def dp_gather_replicate(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
):
    """Gather ``local_tokens`` into ``global_tokens`` in "replicate" mode.

    Thin wrapper that forwards to ``_dp_gather`` with ``is_partial=False``.
    With that flag ``_dp_gather`` only takes the copy branch on the rank
    where ``get_attention_tp_rank() == 0`` (and only when ``local_tokens``
    is non-empty).

    NOTE(review): presumably intended for inputs that are already
    identical on every attention-TP rank, so a single contributor per
    group is enough -- confirm against callers.

    Args:
        global_tokens: Destination buffer, written in place.
        local_tokens: This rank's tokens to contribute.
        forward_batch: Batch metadata handed through to ``_dp_gather``.
    """
    _dp_gather(
        global_tokens=global_tokens,
        local_tokens=local_tokens,
        forward_batch=forward_batch,
        is_partial=False,
    )
def dp_scatter(
local_tokens: torch.Tensor, # output
global_tokens: torch.Tensor, # input
@@ -236,16 +249,3 @@ def dp_scatter(
memcpy_triton(
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
)
def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
    """Return a callable that runs ``dp_scatter`` on a logits tensor.

    The returned closure captures ``forward_batch`` so it can later be
    invoked with the logits tensor alone.
    """

    def do_logits_dp_scatter(logits: torch.Tensor):
        """Allocate a local buffer, fill it via ``dp_scatter``, return it.

        The buffer has ``forward_batch.input_ids.shape[0]`` rows and the
        trailing dimensions, dtype and device of ``logits``.
        """
        num_local_rows = forward_batch.input_ids.shape[0]
        out_shape = (num_local_rows,) + tuple(logits.shape[1:])
        # Uninitialized on purpose: dp_scatter receives it as the output
        # argument and is expected to populate it.
        scattered = torch.empty(
            out_shape,
            dtype=logits.dtype,
            device=logits.device,
        )
        dp_scatter(scattered, logits, forward_batch)
        return scattered

    return do_logits_dp_scatter