[DP Attention] Refactor: adding some utility functions (#9136)

This commit is contained in:
Cheng Wan
2025-08-13 21:08:06 -07:00
committed by GitHub
parent b3363cc1aa
commit b87aacb5c5
21 changed files with 216 additions and 159 deletions

View File

@@ -4,7 +4,7 @@ import functools
import logging
from contextlib import contextmanager
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
import triton
@@ -18,21 +18,26 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_reduce,
)
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.server_args import ServerArgs
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
# Process-wide DP-attention state. All of these are populated by
# initialize_dp_attention(); the accessor functions below assert on None to
# catch use-before-init.
_ATTN_TP_GROUP: Optional[GroupCoordinator] = None
_ATTN_TP_RANK: Optional[int] = None
_ATTN_TP_SIZE: Optional[int] = None
_ATTN_DP_RANK: Optional[int] = None
_ATTN_DP_SIZE: Optional[int] = None
_LOCAL_ATTN_DP_SIZE: Optional[int] = None
_LOCAL_ATTN_DP_RANK: Optional[int] = None
# Whether DP attention was requested at initialization time (server_args).
_ENABLE_DP_ATTENTION_FLAG: bool = False
class DPPaddingMode(IntEnum):
class DpPaddingMode(IntEnum):
# Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
MAX_LEN = auto()
@@ -40,13 +45,13 @@ class DPPaddingMode(IntEnum):
SUM_LEN = auto()
def is_max_len(self):
return self == DPPaddingMode.MAX_LEN
return self == DpPaddingMode.MAX_LEN
def is_sum_len(self):
return self == DPPaddingMode.SUM_LEN
return self == DpPaddingMode.SUM_LEN
@classmethod
def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
# we choose the mode that minimizes the communication cost
max_len = max(global_num_tokens)
sum_len = sum(global_num_tokens)
@@ -56,10 +61,76 @@ class DPPaddingMode(IntEnum):
return cls.SUM_LEN
@classmethod
def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
return cls.MAX_LEN
class _DpGatheredBufferWrapper:
_hidden_size: int
_dtype: torch.dtype
_device: torch.device
_global_dp_buffer_len: int
_local_dp_buffer_len: int
@classmethod
def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
cls._hidden_size = hidden_size
cls._dtype = dtype
cls._device = device
@classmethod
def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
cls._global_dp_buffer_len = global_dp_buffer_len
cls._local_dp_buffer_len = local_dp_buffer_len
@classmethod
def get_global_dp_buffer(cls) -> torch.Tensor:
return torch.empty(
(cls._global_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
device=cls._device,
)
@classmethod
def get_local_dp_buffer(cls) -> torch.Tensor:
return torch.empty(
(cls._local_dp_buffer_len, cls._hidden_size),
dtype=cls._dtype,
device=cls._device,
)
@classmethod
def get_global_dp_buffer_len(cls) -> int:
return cls._global_dp_buffer_len
@classmethod
def get_local_dp_buffer_len(cls) -> int:
return cls._local_dp_buffer_len
def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
    """Record the current global/local DP token counts on the shared wrapper."""
    _DpGatheredBufferWrapper.set_dp_buffer_len(global_dp_buffer_len, local_dp_buffer_len)
def get_global_dp_buffer() -> torch.Tensor:
    """Allocate and return a buffer sized for tokens from all DP ranks."""
    return _DpGatheredBufferWrapper.get_global_dp_buffer()
def get_local_dp_buffer() -> torch.Tensor:
    """Allocate and return a buffer sized for this rank's tokens only."""
    return _DpGatheredBufferWrapper.get_local_dp_buffer()
def get_global_dp_buffer_len() -> int:
    """Return the currently recorded global DP token count."""
    return _DpGatheredBufferWrapper.get_global_dp_buffer_len()
def get_local_dp_buffer_len() -> int:
    """Return the currently recorded local DP token count."""
    return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
if not enable_dp_attention:
return tp_rank, tp_size, 0
@@ -89,18 +160,24 @@ def compute_dp_attention_local_info(
def initialize_dp_attention(
enable_dp_attention: bool,
tp_rank: int,
tp_size: int,
dp_size: int,
moe_dense_tp_size: int,
pp_size: int,
server_args: ServerArgs,
model_config: ModelConfig,
):
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG
from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
enable_dp_attention = server_args.enable_dp_attention
tp_size = server_args.tp_size
dp_size = server_args.dp_size
moe_dense_tp_size = server_args.moe_dense_tp_size
pp_size = server_args.pp_size
tp_rank = get_tensor_model_parallel_rank()
_ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
_ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
enable_dp_attention, tp_rank, tp_size, dp_size
)
@@ -135,38 +212,48 @@ def initialize_dp_attention(
group_name="attention_tp",
)
_DpGatheredBufferWrapper.set_metadata(
hidden_size=model_config.hidden_size,
dtype=model_config.dtype,
device=torch.device("cuda"),
)
def is_dp_attention_enabled() -> bool:
    """Whether DP attention was enabled at initialization (defaults to False)."""
    return _ENABLE_DP_ATTENTION_FLAG


def get_attention_tp_group() -> GroupCoordinator:
    """Return the attention TP process group; requires prior initialization."""
    assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
    return _ATTN_TP_GROUP


def get_attention_tp_rank() -> int:
    """Return this process's rank within the attention TP group."""
    assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
    return _ATTN_TP_RANK


def get_attention_tp_size() -> int:
    """Return the attention TP group size."""
    assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
    return _ATTN_TP_SIZE


def get_attention_dp_rank() -> int:
    """Return this process's attention DP rank."""
    assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
    return _ATTN_DP_RANK


def get_attention_dp_size() -> int:
    """Return the attention DP world size."""
    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
    return _ATTN_DP_SIZE


def get_local_attention_dp_rank() -> int:
    """Return the node-local attention DP rank."""
    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
    return _LOCAL_ATTN_DP_RANK


def get_local_attention_dp_size() -> int:
    """Return the node-local attention DP size."""
    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
    return _LOCAL_ATTN_DP_SIZE