Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (#7667)

This commit is contained in:
Trevor Morris
2025-08-15 22:08:11 -07:00
committed by GitHub
parent 87dab54824
commit eff4eb3fdd
16 changed files with 360 additions and 52 deletions

View File

@@ -148,7 +148,11 @@ class PyNcclCommunicator:
)
def all_gather(
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
self,
output_tensor: torch.Tensor,
input_tensor: torch.Tensor,
stream=None,
sizes: Optional[list[int]] = None,
):
if self.disabled:
return
@@ -161,14 +165,33 @@ class PyNcclCommunicator:
)
if stream is None:
stream = self.stream
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
self.comm,
cudaStream_t(stream.cuda_stream),
)
if sizes is not None:
split_offset = 0
self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
dst_slice = output_tensor[split_offset : split_offset + split_size]
self.nccl.ncclBroadcast(
buffer_type(input_tensor.data_ptr()),
buffer_type(dst_slice.data_ptr()),
dst_slice.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
root,
self.comm,
cudaStream_t(stream.cuda_stream),
)
split_offset += split_size
self.nccl.ncclGroupEnd()
else:
self.nccl.ncclAllGather(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
input_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
self.comm,
cudaStream_t(stream.cuda_stream),
)
def reduce_scatter(
self,
@@ -176,6 +199,7 @@ class PyNcclCommunicator:
input_tensor: torch.Tensor,
op: ReduceOp = ReduceOp.SUM,
stream=None,
sizes: Optional[list[int]] = None,
):
if self.disabled:
return
@@ -188,15 +212,35 @@ class PyNcclCommunicator:
)
if stream is None:
stream = self.stream
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
output_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
self.comm,
cudaStream_t(stream.cuda_stream),
)
if sizes is not None:
split_offset = 0
self.nccl.ncclGroupStart()
for root, split_size in enumerate(sizes):
chunk = input_tensor[split_offset : split_offset + split_size, ...]
self.nccl.ncclReduce(
buffer_type(chunk.data_ptr()),
buffer_type(output_tensor.data_ptr()),
chunk.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
root,
self.comm,
cudaStream_t(stream.cuda_stream),
)
split_offset += split_size
self.nccl.ncclGroupEnd()
else:
self.nccl.ncclReduceScatter(
buffer_type(input_tensor.data_ptr()),
buffer_type(output_tensor.data_ptr()),
output_tensor.numel(),
ncclDataTypeEnum.from_torch(input_tensor.dtype),
ncclRedOpTypeEnum.from_torch(op),
self.comm,
cudaStream_t(stream.cuda_stream),
)
def send(self, tensor: torch.Tensor, dst: int, stream=None):
if self.disabled:
@@ -266,6 +310,12 @@ class PyNcclCommunicator:
def deregister_comm_window(self, window):
return self.nccl.ncclCommWindowDeregister(self.comm, window)
def group_start(self):
self.nccl.ncclGroupStart()
def group_end(self):
self.nccl.ncclGroupEnd()
@contextmanager
def change_state(
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None

View File

@@ -206,6 +206,26 @@ class NCCLLibrary:
cudaStream_t,
],
),
# ncclResult_t ncclReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, int root,
# ncclComm_t comm, cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function(
"ncclReduce",
ncclResult_t,
[
buffer_type,
buffer_type,
ctypes.c_size_t,
ncclDataType_t,
ncclRedOp_t,
ctypes.c_int,
ncclComm_t,
cudaStream_t,
],
),
# ncclResult_t ncclReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
@@ -278,6 +298,10 @@ class NCCLLibrary:
# it is better not to call it at all.
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
# ncclResult_t ncclGroupStart();
Function("ncclGroupStart", ncclResult_t, []),
# ncclResult_t ncclGroupEnd();
Function("ncclGroupEnd", ncclResult_t, []),
]
exported_functions_symm_mem = [
@@ -400,6 +424,28 @@ class NCCLLibrary:
)
)
def ncclReduce(
self,
sendbuff: buffer_type,
recvbuff: buffer_type,
count: int,
datatype: int,
op: int,
root: int,
comm: ncclComm_t,
stream: cudaStream_t,
) -> None:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self.NCCL_CHECK(
self._funcs["ncclReduce"](
sendbuff, recvbuff, count, datatype, op, root, comm, stream
)
)
def ncclReduceScatter(
self,
sendbuff: buffer_type,
@@ -499,6 +545,12 @@ class NCCLLibrary:
def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
def ncclGroupStart(self) -> None:
self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
def ncclGroupEnd(self) -> None:
self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
__all__ = [
"NCCLLibrary",

View File

@@ -583,6 +583,39 @@ class GroupCoordinator:
torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
return output
def reduce_scatterv(
self,
input_: torch.Tensor,
output: Optional[torch.Tensor] = None,
sizes: Optional[List[int]] = None,
) -> torch.Tensor:
world_size = self.world_size
pynccl_comm = self.pynccl_comm
with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
assert (
pynccl_comm is not None and not pynccl_comm.disabled
), "pynccl is required for reduce_scatterv"
if sizes is not None:
assert len(sizes) == world_size
assert input_.shape[0] == sum(sizes)
chunk_size = sizes[self.rank_in_group]
else:
assert input_.shape[0] % world_size == 0
chunk_size = input_.shape[0] // world_size
output_shape = (chunk_size,) + input_.shape[1:]
if output is None:
output = torch.empty(
output_shape, dtype=input_.dtype, device=input_.device
)
else:
assert output.shape == output_shape
pynccl_comm.reduce_scatter(output, input_, sizes=sizes)
return output
def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
pynccl_comm = self.pynccl_comm
if pynccl_comm is not None and not pynccl_comm.disabled:
@@ -673,6 +706,54 @@ class GroupCoordinator:
)
return output_tensor
def all_gatherv(
self,
input_: Union[torch.Tensor, List[torch.Tensor]],
sizes: Optional[List[int]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Supports varying sizes per rank and input tensor list.
`sizes`: a list of len(world_size) with the number of items per rank to gather.
"""
world_size = self.world_size
pynccl_comm = self.pynccl_comm
with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
assert (
pynccl_comm is not None and not pynccl_comm.disabled
), "pynccl is required for all_gatherv"
def _all_gather_single(
input_: torch.Tensor, sizes: Optional[List[int]] = None
):
input_size = input_.size()
if sizes is not None:
assert len(sizes) == world_size
assert input_.shape[0] == sizes[self.rank_in_group]
output_size = (sum(sizes),) + input_size[1:]
# 'sizes' is not needed if all inputs in the same group have the same shape
if all(s == sizes[0] for s in sizes):
sizes = None
else:
output_size = (input_size[0] * world_size,) + input_size[1:]
# Allocate output tensor.
output_tensor = torch.empty(
output_size, dtype=input_.dtype, device=input_.device
)
pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
return output_tensor
if isinstance(input_, torch.Tensor):
return _all_gather_single(input_, sizes)
output_list = []
pynccl_comm.group_start()
for inp in input_:
output_list.append(_all_gather_single(inp, sizes=sizes))
pynccl_comm.group_end()
return output_list
def gather(
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]: