Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (#7667)
This commit is contained in:
@@ -72,6 +72,7 @@ class _DpGatheredBufferWrapper:
|
||||
_device: torch.device
|
||||
_global_dp_buffer_len: int
|
||||
_local_dp_buffer_len: int
|
||||
_global_num_tokens: Optional[List[int]]
|
||||
|
||||
@classmethod
|
||||
def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
|
||||
@@ -80,9 +81,15 @@ class _DpGatheredBufferWrapper:
|
||||
cls._device = device
|
||||
|
||||
@classmethod
def set_dp_buffer_len(
    cls,
    global_dp_buffer_len: int,
    local_dp_buffer_len: int,
    global_num_tokens: Optional[List[int]] = None,
):
    """Cache DP gather-buffer sizing on the wrapper class.

    Args:
        global_dp_buffer_len: buffer length for the gathered (all-ranks)
            DP buffer, stored in ``cls._global_dp_buffer_len``.
        local_dp_buffer_len: buffer length for this rank's local slice,
            stored in ``cls._local_dp_buffer_len``.
        global_num_tokens: optional per-rank token counts (presumably one
            entry per DP rank — confirm against callers); ``None`` when the
            caller does not track them.
    """
    cls._global_dp_buffer_len = global_dp_buffer_len
    cls._local_dp_buffer_len = local_dp_buffer_len
    cls._global_num_tokens = global_num_tokens
|
||||
|
||||
@classmethod
|
||||
def get_global_dp_buffer(cls) -> torch.Tensor:
|
||||
@@ -108,10 +115,18 @@ class _DpGatheredBufferWrapper:
|
||||
def get_local_dp_buffer_len(cls) -> int:
    """Return the cached buffer length for this rank's local DP slice."""
    local_len = cls._local_dp_buffer_len
    return local_len
|
||||
|
||||
@classmethod
def get_dp_global_num_tokens(cls) -> Optional[List[int]]:
    """Return the cached per-rank token counts for the current DP batch.

    Returns:
        The value last stored by ``set_dp_buffer_len``; ``None`` when the
        caller did not supply ``global_num_tokens`` (its default).
        NOTE: annotation widened to ``Optional[List[int]]`` to match the
        declared attribute type ``_global_num_tokens: Optional[List[int]]``.
    """
    return cls._global_num_tokens
|
||||
|
||||
def set_dp_buffer_len(
    global_dp_buffer_len: int,
    local_dp_buffer_len: int,
    global_num_tokens: Optional[List[int]] = None,
):
    """Module-level convenience wrapper for DP buffer-length bookkeeping.

    Forwards all arguments unchanged to
    ``_DpGatheredBufferWrapper.set_dp_buffer_len``; see that classmethod
    for the meaning of each value. ``global_num_tokens`` defaults to
    ``None`` so existing two-argument callers keep working.
    """
    _DpGatheredBufferWrapper.set_dp_buffer_len(
        global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
    )
|
||||
|
||||
|
||||
@@ -131,6 +146,10 @@ def get_local_dp_buffer_len() -> int:
|
||||
return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
|
||||
|
||||
|
||||
def get_dp_global_num_tokens() -> List[int]:
    """Return the per-rank token counts cached on ``_DpGatheredBufferWrapper``."""
    return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
|
||||
|
||||
|
||||
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
|
||||
if not enable_dp_attention:
|
||||
return tp_rank, tp_size, 0
|
||||
|
||||
Reference in New Issue
Block a user