Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (#7667)

Author: Trevor Morris
Date: 2025-08-15 22:08:11 -07:00
Committed by: GitHub
Parent: 87dab54824
Commit: eff4eb3fdd

16 changed files with 360 additions and 52 deletions
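
In broad strokes, the win from quantizing before the gather is bandwidth: each rank converts its local activations to fp4 first, so the DP all-gather moves roughly a quarter of the bf16 payload plus block scale factors. A back-of-envelope sketch with made-up shapes (not taken from this commit), assuming an nvfp4-style layout with one 8-bit scale per 16 elements:

# Illustrative shapes only; real sizes depend on batch and model.
tokens_per_rank, hidden = 4096, 7168

bf16_bytes = tokens_per_rank * hidden * 2     # 2 bytes per element
fp4_bytes = tokens_per_rank * hidden // 2     # 2 fp4 values packed per byte
scale_bytes = tokens_per_rank * hidden // 16  # one 8-bit scale per 16 elems

print(bf16_bytes / (fp4_bytes + scale_bytes))  # ~3.6x less gather traffic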


@@ -72,6 +72,7 @@ class _DpGatheredBufferWrapper:
     _device: torch.device
     _global_dp_buffer_len: int
     _local_dp_buffer_len: int
+    _global_num_tokens: Optional[List[int]]
 
     @classmethod
     def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
@@ -80,9 +81,15 @@ class _DpGatheredBufferWrapper:
         cls._device = device
 
     @classmethod
-    def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
+    def set_dp_buffer_len(
+        cls,
+        global_dp_buffer_len: int,
+        local_dp_buffer_len: int,
+        global_num_tokens: Optional[List[int]] = None,
+    ):
         cls._global_dp_buffer_len = global_dp_buffer_len
         cls._local_dp_buffer_len = local_dp_buffer_len
+        cls._global_num_tokens = global_num_tokens
 
     @classmethod
     def get_global_dp_buffer(cls) -> torch.Tensor:
@@ -108,10 +115,18 @@ class _DpGatheredBufferWrapper:
     def get_local_dp_buffer_len(cls) -> int:
         return cls._local_dp_buffer_len
 
+    @classmethod
+    def get_dp_global_num_tokens(cls) -> List[int]:
+        return cls._global_num_tokens
+
 
-def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
+def set_dp_buffer_len(
+    global_dp_buffer_len: int,
+    local_dp_buffer_len: int,
+    global_num_tokens: Optional[List[int]] = None,
+):
     _DpGatheredBufferWrapper.set_dp_buffer_len(
-        global_dp_buffer_len, local_dp_buffer_len
+        global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
     )
 
 
@@ -131,6 +146,10 @@ def get_local_dp_buffer_len() -> int:
     return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
 
 
+def get_dp_global_num_tokens() -> List[int]:
+    return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
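
Taken together, the new plumbing lets the MoE dispatch path recover each DP rank's token count after buffer metadata is set. A minimal, single-process sketch of the call pattern; the import path and the size_fp4_gather_buffers helper are assumptions for illustration, not part of this commit:

from typing import List

# Import path assumed; adjust to wherever _DpGatheredBufferWrapper lives.
from sglang.srt.layers.dp_attention import (
    set_dp_buffer_len,
    get_dp_global_num_tokens,
)

def size_fp4_gather_buffers(hidden_size: int) -> List[int]:
    # Hypothetical helper: with per-rank token counts recorded, each
    # rank's slice of the gathered fp4 payload can be sized exactly,
    # packing two 4-bit values per byte along the hidden dimension.
    return [n * hidden_size // 2 for n in get_dp_global_num_tokens()]

# Example: 4 DP ranks with uneven local batches; this rank holds 7 tokens.
set_dp_buffer_len(
    global_dp_buffer_len=7 + 3 + 5 + 1,
    local_dp_buffer_len=7,
    global_num_tokens=[7, 3, 5, 1],
)
print(size_fp4_gather_buffers(hidden_size=4096))  # [14336, 6144, 10240, 2048]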