Reduce computation and communication in DP attention (#4521)
This commit is contained in:
@@ -53,10 +53,8 @@ def initialize_dp_attention(
|
||||
)
|
||||
|
||||
if enable_dp_attention:
|
||||
local_rank = tp_rank % (tp_size // dp_size)
|
||||
_DP_SIZE = dp_size
|
||||
else:
|
||||
local_rank = tp_rank
|
||||
_DP_SIZE = 1
|
||||
|
||||
tp_group = get_tp_group()
|
||||
@@ -65,7 +63,7 @@ def initialize_dp_attention(
|
||||
list(range(head, head + _ATTN_TP_SIZE))
|
||||
for head in range(0, tp_size, _ATTN_TP_SIZE)
|
||||
],
|
||||
local_rank,
|
||||
tp_group.local_rank,
|
||||
torch.distributed.get_backend(tp_group.device_group),
|
||||
SYNC_TOKEN_IDS_ACROSS_TP,
|
||||
False,
|
||||
@@ -180,20 +178,19 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
|
||||
memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
|
||||
|
||||
|
||||
def dp_gather(
|
||||
def _dp_gather(
|
||||
global_tokens: torch.Tensor,
|
||||
local_tokens: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
layer_id: Union[str, int],
|
||||
is_partial: bool,
|
||||
):
|
||||
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
|
||||
|
||||
global_tokens.fill_(0)
|
||||
assert local_tokens.is_contiguous()
|
||||
assert global_tokens.is_contiguous()
|
||||
if local_tokens.shape[0] > 0 and (
|
||||
layer_id != "embedding" or get_attention_tp_rank() == 0
|
||||
):
|
||||
|
||||
if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
|
||||
assert (
|
||||
global_tokens.untyped_storage().data_ptr()
|
||||
!= local_tokens.untyped_storage().data_ptr()
|
||||
@@ -216,6 +213,22 @@ def dp_gather(
|
||||
global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
|
||||
|
||||
|
||||
def dp_gather_partial(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
):
    """Gather this rank's partial tokens into ``global_tokens``.

    Thin convenience wrapper around ``_dp_gather`` with
    ``is_partial=True``, so every attention-TP rank (not just rank 0)
    contributes its local tokens to the gather.
    """
    _dp_gather(
        global_tokens=global_tokens,
        local_tokens=local_tokens,
        forward_batch=forward_batch,
        is_partial=True,
    )
|
||||
|
||||
|
||||
def dp_gather_replicate(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
):
    """Gather replicated tokens into ``global_tokens``.

    Thin convenience wrapper around ``_dp_gather`` with
    ``is_partial=False``; per ``_dp_gather``'s gating, only
    attention-TP rank 0 contributes its (replicated) local tokens.
    """
    _dp_gather(
        global_tokens=global_tokens,
        local_tokens=local_tokens,
        forward_batch=forward_batch,
        is_partial=False,
    )
|
||||
|
||||
|
||||
def dp_scatter(
|
||||
local_tokens: torch.Tensor, # output
|
||||
global_tokens: torch.Tensor, # input
|
||||
@@ -236,16 +249,3 @@ def dp_scatter(
|
||||
memcpy_triton(
|
||||
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
|
||||
)
|
||||
|
||||
|
||||
def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
    """Return a closure that scatters global logits back to this DP rank.

    The returned callable allocates a fresh rank-local buffer whose first
    dimension matches ``forward_batch.input_ids`` (the local token count)
    and whose trailing dimensions, dtype, and device mirror the incoming
    ``logits`` tensor, then fills it via ``dp_scatter``.
    """

    def do_logits_dp_scatter(logits: torch.Tensor):
        # Local row count comes from this rank's batch; keep all other
        # dims identical to the global logits tensor.
        num_local_tokens = forward_batch.input_ids.shape[0]
        out_shape = (num_local_tokens,) + tuple(logits.shape[1:])
        scattered = torch.empty(
            out_shape,
            dtype=logits.dtype,
            device=logits.device,
        )
        dp_scatter(scattered, logits, forward_batch)
        return scattered

    return do_logits_dp_scatter
|
||||
|
||||
Reference in New Issue
Block a user