DP Enhancement (#8280)
This commit is contained in:
@@ -3,7 +3,8 @@ from __future__ import annotations
|
||||
import functools
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING, List
|
||||
from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -30,6 +31,34 @@ _LOCAL_ATTN_DP_SIZE = None
|
||||
_LOCAL_ATTN_DP_RANK = None
|
||||
|
||||
|
||||
class DPPaddingMode(IntEnum):
    """How local token tensors are padded before being gathered across DP ranks."""

    # Pad each rank to the max length and gather with `all_gather_into_tensor`.
    MAX_LEN = auto()
    # Pad to the summed length and combine with `all_reduce`.
    SUM_LEN = auto()

    def is_max_len(self) -> bool:
        """True when this is the max-length / all-gather mode."""
        return self is DPPaddingMode.MAX_LEN

    def is_sum_len(self) -> bool:
        """True when this is the sum-length / all-reduce mode."""
        return self is DPPaddingMode.SUM_LEN

    @classmethod
    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
        """Pick the mode that minimizes the communication cost for this batch.

        MAX_LEN moves `max_len * dp_size` elements, SUM_LEN roughly `2 * sum_len`;
        choose whichever is smaller.
        """
        longest = max(global_num_tokens)
        total = sum(global_num_tokens)
        return (
            cls.MAX_LEN
            if 2 * total > longest * get_attention_dp_size()
            else cls.SUM_LEN
        )

    @classmethod
    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
        """Fixed padding mode used when running under CUDA graph."""
        return cls.MAX_LEN
|
||||
|
||||
|
||||
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
|
||||
if not enable_dp_attention:
|
||||
return tp_rank, tp_size, 0
|
||||
@@ -162,7 +191,7 @@ def disable_dp_size():
|
||||
_ATTN_DP_SIZE = old_dp_size
|
||||
|
||||
|
||||
def get_dp_local_info(forward_batch: ForwardBatch):
|
||||
def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
|
||||
dp_rank = get_attention_dp_rank()
|
||||
|
||||
@@ -221,7 +250,7 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
|
||||
memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
|
||||
|
||||
|
||||
def _dp_gather(
|
||||
def _dp_gather_via_all_reduce(
|
||||
global_tokens: torch.Tensor,
|
||||
local_tokens: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
@@ -238,13 +267,6 @@ def _dp_gather(
|
||||
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
|
||||
), "aliasing between global_tokens and local_tokens not allowed"
|
||||
|
||||
# NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
|
||||
# But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
|
||||
# actual size of the accepted tokens.
|
||||
if forward_batch.forward_mode.is_draft_extend():
|
||||
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
|
||||
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
|
||||
|
||||
memcpy_triton(
|
||||
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
|
||||
)
|
||||
@@ -263,6 +285,38 @@ def _dp_gather(
|
||||
global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
|
||||
|
||||
|
||||
def _dp_gather_via_all_gather(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
    is_partial: bool,
):
    """Gather MAX_LEN-padded tokens: reduce-scatter inside the attention-TP
    group, then all-gather the local shard across the full TP group into
    `global_tokens`."""
    attn_rank = get_attention_tp_rank()
    # For a non-partial gather every attention-TP rank holds the same replica;
    # zero all but rank 0 so the reduce-scatter leaves exactly one copy.
    if not is_partial and attn_rank != 0:
        local_tokens.fill_(0)
    local_shard = local_tokens.tensor_split(get_attention_tp_size())[attn_rank]
    get_attention_tp_group().reduce_scatter_tensor(local_shard, local_tokens)
    get_tp_group().all_gather_into_tensor(global_tokens, local_shard)
|
||||
|
||||
|
||||
def _dp_gather(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
    is_partial: bool,
):
    """Dispatch the DP gather to the implementation matching the batch's
    padding mode (MAX_LEN -> all-gather path, SUM_LEN -> all-reduce path)."""
    gather_impl = (
        _dp_gather_via_all_gather
        if forward_batch.dp_padding_mode.is_max_len()
        else _dp_gather_via_all_reduce
    )
    gather_impl(global_tokens, local_tokens, forward_batch, is_partial)
|
||||
|
||||
|
||||
def dp_gather_partial(
|
||||
global_tokens: torch.Tensor,
|
||||
local_tokens: torch.Tensor,
|
||||
@@ -296,24 +350,18 @@ def dp_scatter(
|
||||
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
|
||||
), "aliasing between local_tokens and global_tokens not allowed"
|
||||
|
||||
# NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
|
||||
# But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
|
||||
# actual size of the accepted tokens.
|
||||
if forward_batch.forward_mode.is_draft_extend():
|
||||
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
|
||||
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
|
||||
|
||||
memcpy_triton(
|
||||
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
|
||||
)
|
||||
|
||||
|
||||
def attn_tp_reduce_scatter(
    output: torch.Tensor,
    input_list: List[torch.Tensor],
):
    """Reduce-scatter `input_list` across the attention-TP group into `output`."""
    group = get_attention_tp_group()
    return group.reduce_scatter(output, input_list)
|
||||
def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
    """Tensor-variant reduce-scatter over the attention-TP group.

    NOTE(review): the parameter name `input` shadows the builtin; kept as-is
    because renaming would break keyword-argument callers.
    """
    group = get_attention_tp_group()
    return group.reduce_scatter_tensor(output, input)
|
||||
|
||||
|
||||
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
    """All-gather `input_` from every attention-TP rank into `output_list`."""
    group = get_attention_tp_group()
    return group.all_gather(input_, output_tensor_list=output_list)
|
||||
def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):
    """Tensor-variant all-gather over the attention-TP group into `output`.

    NOTE(review): `input` shadows the builtin; kept to preserve keyword callers.
    """
    group = get_attention_tp_group()
    return group.all_gather_into_tensor(output, input)
|
||||
|
||||
|
||||
def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor):
    """All-gather `input` from every attention-TP rank into `output_list`.

    NOTE(review): `input` shadows the builtin; kept to preserve keyword callers.
    """
    group = get_attention_tp_group()
    return group.all_gather(input, output_tensor_list=output_list)
|
||||
|
||||
Reference in New Issue
Block a user