DP Enhancement (#8280)

This commit is contained in:
Cheng Wan
2025-07-24 21:36:21 -07:00
committed by GitHub
parent 28d4d47280
commit c0fb25e949
20 changed files with 665 additions and 1116 deletions

View File

@@ -3,7 +3,8 @@ from __future__ import annotations
import functools
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, List
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Tuple
import torch
import triton
@@ -30,6 +31,34 @@ _LOCAL_ATTN_DP_SIZE = None
_LOCAL_ATTN_DP_RANK = None
class DPPaddingMode(IntEnum):
    """How per-rank token tensors are padded before a global DP gather."""

    # Pad every rank's tokens to the max length, then gather them with
    # `all_gather_into_tensor`.
    MAX_LEN = auto()
    # Pad to the summed length, then gather with `all_reduce`.
    SUM_LEN = auto()

    def is_max_len(self) -> bool:
        """Return True iff this mode is MAX_LEN."""
        return self is DPPaddingMode.MAX_LEN

    def is_sum_len(self) -> bool:
        """Return True iff this mode is SUM_LEN."""
        return self is DPPaddingMode.SUM_LEN

    @classmethod
    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
        """Pick the padding mode that minimizes the communication cost.

        Compares the data volume of the two gather strategies and returns
        the cheaper one for the given per-rank token counts.
        """
        longest = max(global_num_tokens)
        total = sum(global_num_tokens)
        return cls.MAX_LEN if total * 2 > longest * get_attention_dp_size() else cls.SUM_LEN

    @classmethod
    def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
        """CUDA-graph capture uses the fixed-shape MAX_LEN mode."""
        return cls.MAX_LEN
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
if not enable_dp_attention:
return tp_rank, tp_size, 0
@@ -162,7 +191,7 @@ def disable_dp_size():
_ATTN_DP_SIZE = old_dp_size
def get_dp_local_info(forward_batch: ForwardBatch):
def get_dp_local_info(forward_batch: ForwardBatch) -> Tuple[torch.Tensor, torch.Tensor]:
# `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
dp_rank = get_attention_dp_rank()
@@ -221,7 +250,7 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
def _dp_gather(
def _dp_gather_via_all_reduce(
global_tokens: torch.Tensor,
local_tokens: torch.Tensor,
forward_batch: ForwardBatch,
@@ -238,13 +267,6 @@ def _dp_gather(
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
), "aliasing between global_tokens and local_tokens not allowed"
# NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
# But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
# actual size of the accepted tokens.
if forward_batch.forward_mode.is_draft_extend():
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
memcpy_triton(
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
)
@@ -263,6 +285,38 @@ def _dp_gather(
global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
def _dp_gather_via_all_gather(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
    is_partial: bool,
):
    """Gather each DP rank's tokens into `global_tokens` (MAX_LEN padding mode).

    Implementation: reduce-scatter `local_tokens` within the attention-TP
    group so each attention-TP rank holds one shard, then all-gather the
    shards over the full TP group into `global_tokens`.

    NOTE(review): `forward_batch` is accepted but unused here — presumably
    kept for signature parity with `_dp_gather_via_all_reduce`; confirm.
    """
    if not is_partial:
        # Every attention-TP rank holds a full copy of the tokens; zero all
        # but rank 0 so the reduce-scatter sums to exactly one copy instead
        # of attn_tp_size duplicates.
        if get_attention_tp_rank() != 0:
            local_tokens.fill_(0)
    # This rank's shard of the local tokens (shard index == attention-TP rank).
    # NOTE(review): reduce_scatter_tensor typically requires equal-sized
    # shards, i.e. local_tokens' first dim divisible by attn_tp_size — the
    # MAX_LEN padding presumably guarantees this; confirm against callers.
    scattered_local_tokens = local_tokens.tensor_split(get_attention_tp_size())[
        get_attention_tp_rank()
    ]
    get_attention_tp_group().reduce_scatter_tensor(scattered_local_tokens, local_tokens)
    get_tp_group().all_gather_into_tensor(global_tokens, scattered_local_tokens)
def _dp_gather(
    global_tokens: torch.Tensor,
    local_tokens: torch.Tensor,
    forward_batch: ForwardBatch,
    is_partial: bool,
):
    """Gather local tokens into `global_tokens`, dispatching on the batch's
    DP padding mode (MAX_LEN -> all-gather path, SUM_LEN -> all-reduce path)."""
    gather_fn = (
        _dp_gather_via_all_gather
        if forward_batch.dp_padding_mode.is_max_len()
        else _dp_gather_via_all_reduce
    )
    gather_fn(global_tokens, local_tokens, forward_batch, is_partial)
def dp_gather_partial(
global_tokens: torch.Tensor,
local_tokens: torch.Tensor,
@@ -296,24 +350,18 @@ def dp_scatter(
local_tokens.untyped_storage() is not global_tokens.untyped_storage()
), "aliasing between local_tokens and global_tokens not allowed"
# NOTE: During draft extend, the gathered_buffer is padded to num_tokens * (speculative_num_steps + 1).
# But the size of local_tokens is total accepted tokens. We need to reduce the local_num_tokens to the
# actual size of the accepted tokens.
if forward_batch.forward_mode.is_draft_extend():
shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
memcpy_triton(
local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
)
def attn_tp_reduce_scatter(
    output: torch.Tensor,
    input_list: List[torch.Tensor],
):
    """Reduce-scatter `input_list` across the attention-TP group into `output`."""
    group = get_attention_tp_group()
    return group.reduce_scatter(output, input_list)
def attn_tp_reduce_scatter_tensor(output: torch.Tensor, input: torch.Tensor):
    """Reduce-scatter the single tensor `input` across the attention-TP group
    into `output`."""
    group = get_attention_tp_group()
    return group.reduce_scatter_tensor(output, input)
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
    """All-gather `input_` across the attention-TP group into `output_list`."""
    group = get_attention_tp_group()
    return group.all_gather(input_, output_tensor_list=output_list)
def attn_tp_all_gather_into_tensor(output: torch.Tensor, input: torch.Tensor):
    """All-gather the single tensor `input` across the attention-TP group
    into the pre-allocated `output`."""
    group = get_attention_tp_group()
    return group.all_gather_into_tensor(output, input)
def attn_tp_all_gather(output_list: List[torch.Tensor], input: torch.Tensor):
    """All-gather `input` across the attention-TP group into `output_list`."""
    group = get_attention_tp_group()
    return group.all_gather(input, output_tensor_list=output_list)