[DP Attention] Refactor: adding some utility functions (#9136)
This commit is contained in:
@@ -4,7 +4,7 @@ import functools
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from enum import IntEnum, auto
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -18,21 +18,26 @@ from sglang.srt.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
_ATTN_TP_GROUP = None
|
||||
_ATTN_TP_RANK = None
|
||||
_ATTN_TP_SIZE = None
|
||||
_ATTN_DP_RANK = None
|
||||
_ATTN_DP_SIZE = None
|
||||
_LOCAL_ATTN_DP_SIZE = None
|
||||
_LOCAL_ATTN_DP_RANK = None
|
||||
_ATTN_TP_GROUP: Optional[GroupCoordinator] = None
|
||||
_ATTN_TP_RANK: Optional[int] = None
|
||||
_ATTN_TP_SIZE: Optional[int] = None
|
||||
_ATTN_DP_RANK: Optional[int] = None
|
||||
_ATTN_DP_SIZE: Optional[int] = None
|
||||
_LOCAL_ATTN_DP_SIZE: Optional[int] = None
|
||||
_LOCAL_ATTN_DP_RANK: Optional[int] = None
|
||||
_ENABLE_DP_ATTENTION_FLAG: bool = False
|
||||
|
||||
|
||||
class DPPaddingMode(IntEnum):
|
||||
class DpPaddingMode(IntEnum):
|
||||
|
||||
# Padding tokens to max length and then gather tokens using `all_gather_into_tensor`
|
||||
MAX_LEN = auto()
|
||||
@@ -40,13 +45,13 @@ class DPPaddingMode(IntEnum):
|
||||
SUM_LEN = auto()
|
||||
|
||||
def is_max_len(self):
|
||||
return self == DPPaddingMode.MAX_LEN
|
||||
return self == DpPaddingMode.MAX_LEN
|
||||
|
||||
def is_sum_len(self):
|
||||
return self == DPPaddingMode.SUM_LEN
|
||||
return self == DpPaddingMode.SUM_LEN
|
||||
|
||||
@classmethod
|
||||
def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DPPaddingMode:
|
||||
def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
|
||||
# we choose the mode that minimizes the communication cost
|
||||
max_len = max(global_num_tokens)
|
||||
sum_len = sum(global_num_tokens)
|
||||
@@ -56,10 +61,76 @@ class DPPaddingMode(IntEnum):
|
||||
return cls.SUM_LEN
|
||||
|
||||
@classmethod
|
||||
def get_default_mode_in_cuda_graph(cls) -> DPPaddingMode:
|
||||
def get_default_mode_in_cuda_graph(cls) -> DpPaddingMode:
|
||||
return cls.MAX_LEN
|
||||
|
||||
|
||||
class _DpGatheredBufferWrapper:
|
||||
|
||||
_hidden_size: int
|
||||
_dtype: torch.dtype
|
||||
_device: torch.device
|
||||
_global_dp_buffer_len: int
|
||||
_local_dp_buffer_len: int
|
||||
|
||||
@classmethod
|
||||
def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
|
||||
cls._hidden_size = hidden_size
|
||||
cls._dtype = dtype
|
||||
cls._device = device
|
||||
|
||||
@classmethod
|
||||
def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
|
||||
cls._global_dp_buffer_len = global_dp_buffer_len
|
||||
cls._local_dp_buffer_len = local_dp_buffer_len
|
||||
|
||||
@classmethod
|
||||
def get_global_dp_buffer(cls) -> torch.Tensor:
|
||||
return torch.empty(
|
||||
(cls._global_dp_buffer_len, cls._hidden_size),
|
||||
dtype=cls._dtype,
|
||||
device=cls._device,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_local_dp_buffer(cls) -> torch.Tensor:
|
||||
return torch.empty(
|
||||
(cls._local_dp_buffer_len, cls._hidden_size),
|
||||
dtype=cls._dtype,
|
||||
device=cls._device,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_global_dp_buffer_len(cls) -> int:
|
||||
return cls._global_dp_buffer_len
|
||||
|
||||
@classmethod
|
||||
def get_local_dp_buffer_len(cls) -> int:
|
||||
return cls._local_dp_buffer_len
|
||||
|
||||
|
||||
def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
|
||||
_DpGatheredBufferWrapper.set_dp_buffer_len(
|
||||
global_dp_buffer_len, local_dp_buffer_len
|
||||
)
|
||||
|
||||
|
||||
def get_global_dp_buffer() -> torch.Tensor:
|
||||
return _DpGatheredBufferWrapper.get_global_dp_buffer()
|
||||
|
||||
|
||||
def get_local_dp_buffer() -> torch.Tensor:
|
||||
return _DpGatheredBufferWrapper.get_local_dp_buffer()
|
||||
|
||||
|
||||
def get_global_dp_buffer_len() -> int:
|
||||
return _DpGatheredBufferWrapper.get_global_dp_buffer_len()
|
||||
|
||||
|
||||
def get_local_dp_buffer_len() -> int:
|
||||
return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
|
||||
|
||||
|
||||
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
|
||||
if not enable_dp_attention:
|
||||
return tp_rank, tp_size, 0
|
||||
@@ -89,18 +160,24 @@ def compute_dp_attention_local_info(
|
||||
|
||||
|
||||
def initialize_dp_attention(
|
||||
enable_dp_attention: bool,
|
||||
tp_rank: int,
|
||||
tp_size: int,
|
||||
dp_size: int,
|
||||
moe_dense_tp_size: int,
|
||||
pp_size: int,
|
||||
server_args: ServerArgs,
|
||||
model_config: ModelConfig,
|
||||
):
|
||||
global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
|
||||
global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
|
||||
global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK, _ENABLE_DP_ATTENTION_FLAG
|
||||
|
||||
from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
|
||||
|
||||
enable_dp_attention = server_args.enable_dp_attention
|
||||
tp_size = server_args.tp_size
|
||||
dp_size = server_args.dp_size
|
||||
moe_dense_tp_size = server_args.moe_dense_tp_size
|
||||
pp_size = server_args.pp_size
|
||||
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
_ENABLE_DP_ATTENTION_FLAG = enable_dp_attention
|
||||
|
||||
_ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
|
||||
enable_dp_attention, tp_rank, tp_size, dp_size
|
||||
)
|
||||
@@ -135,38 +212,48 @@ def initialize_dp_attention(
|
||||
group_name="attention_tp",
|
||||
)
|
||||
|
||||
_DpGatheredBufferWrapper.set_metadata(
|
||||
hidden_size=model_config.hidden_size,
|
||||
dtype=model_config.dtype,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
def get_attention_tp_group():
|
||||
|
||||
def is_dp_attention_enabled() -> bool:
|
||||
return _ENABLE_DP_ATTENTION_FLAG
|
||||
|
||||
|
||||
def get_attention_tp_group() -> GroupCoordinator:
|
||||
assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
|
||||
return _ATTN_TP_GROUP
|
||||
|
||||
|
||||
def get_attention_tp_rank():
|
||||
def get_attention_tp_rank() -> int:
|
||||
assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
|
||||
return _ATTN_TP_RANK
|
||||
|
||||
|
||||
def get_attention_tp_size():
|
||||
def get_attention_tp_size() -> int:
|
||||
assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
|
||||
return _ATTN_TP_SIZE
|
||||
|
||||
|
||||
def get_attention_dp_rank():
|
||||
def get_attention_dp_rank() -> int:
|
||||
assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
|
||||
return _ATTN_DP_RANK
|
||||
|
||||
|
||||
def get_attention_dp_size():
|
||||
def get_attention_dp_size() -> int:
|
||||
assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
|
||||
return _ATTN_DP_SIZE
|
||||
|
||||
|
||||
def get_local_attention_dp_rank():
|
||||
def get_local_attention_dp_rank() -> int:
|
||||
assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
|
||||
return _LOCAL_ATTN_DP_RANK
|
||||
|
||||
|
||||
def get_local_attention_dp_size():
|
||||
def get_local_attention_dp_size() -> int:
|
||||
assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
|
||||
return _LOCAL_ATTN_DP_SIZE
|
||||
|
||||
|
||||
Reference in New Issue
Block a user