Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (#7667)
This commit is contained in:
@@ -148,7 +148,11 @@ class PyNcclCommunicator:
|
||||
)
|
||||
|
||||
def all_gather(
|
||||
self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
|
||||
self,
|
||||
output_tensor: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
stream=None,
|
||||
sizes: Optional[list[int]] = None,
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
@@ -161,14 +165,33 @@ class PyNcclCommunicator:
|
||||
)
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
self.nccl.ncclAllGather(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
input_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
if sizes is not None:
|
||||
split_offset = 0
|
||||
|
||||
self.nccl.ncclGroupStart()
|
||||
for root, split_size in enumerate(sizes):
|
||||
dst_slice = output_tensor[split_offset : split_offset + split_size]
|
||||
self.nccl.ncclBroadcast(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(dst_slice.data_ptr()),
|
||||
dst_slice.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
root,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
split_offset += split_size
|
||||
self.nccl.ncclGroupEnd()
|
||||
else:
|
||||
self.nccl.ncclAllGather(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
input_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def reduce_scatter(
|
||||
self,
|
||||
@@ -176,6 +199,7 @@ class PyNcclCommunicator:
|
||||
input_tensor: torch.Tensor,
|
||||
op: ReduceOp = ReduceOp.SUM,
|
||||
stream=None,
|
||||
sizes: Optional[list[int]] = None,
|
||||
):
|
||||
if self.disabled:
|
||||
return
|
||||
@@ -188,15 +212,35 @@ class PyNcclCommunicator:
|
||||
)
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
self.nccl.ncclReduceScatter(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
output_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
if sizes is not None:
|
||||
split_offset = 0
|
||||
self.nccl.ncclGroupStart()
|
||||
for root, split_size in enumerate(sizes):
|
||||
chunk = input_tensor[split_offset : split_offset + split_size, ...]
|
||||
|
||||
self.nccl.ncclReduce(
|
||||
buffer_type(chunk.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
chunk.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
root,
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
split_offset += split_size
|
||||
self.nccl.ncclGroupEnd()
|
||||
else:
|
||||
self.nccl.ncclReduceScatter(
|
||||
buffer_type(input_tensor.data_ptr()),
|
||||
buffer_type(output_tensor.data_ptr()),
|
||||
output_tensor.numel(),
|
||||
ncclDataTypeEnum.from_torch(input_tensor.dtype),
|
||||
ncclRedOpTypeEnum.from_torch(op),
|
||||
self.comm,
|
||||
cudaStream_t(stream.cuda_stream),
|
||||
)
|
||||
|
||||
def send(self, tensor: torch.Tensor, dst: int, stream=None):
|
||||
if self.disabled:
|
||||
@@ -266,6 +310,12 @@ class PyNcclCommunicator:
|
||||
def deregister_comm_window(self, window):
|
||||
return self.nccl.ncclCommWindowDeregister(self.comm, window)
|
||||
|
||||
def group_start(self):
|
||||
self.nccl.ncclGroupStart()
|
||||
|
||||
def group_end(self):
|
||||
self.nccl.ncclGroupEnd()
|
||||
|
||||
@contextmanager
|
||||
def change_state(
|
||||
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
|
||||
|
||||
@@ -206,6 +206,26 @@ class NCCLLibrary:
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclReduce(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, int root,
|
||||
# ncclComm_t comm, cudaStream_t stream);
|
||||
# note that cudaStream_t is a pointer type, so the last argument
|
||||
# is a pointer
|
||||
Function(
|
||||
"ncclReduce",
|
||||
ncclResult_t,
|
||||
[
|
||||
buffer_type,
|
||||
buffer_type,
|
||||
ctypes.c_size_t,
|
||||
ncclDataType_t,
|
||||
ncclRedOp_t,
|
||||
ctypes.c_int,
|
||||
ncclComm_t,
|
||||
cudaStream_t,
|
||||
],
|
||||
),
|
||||
# ncclResult_t ncclReduceScatter(
|
||||
# const void* sendbuff, void* recvbuff, size_t count,
|
||||
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
|
||||
@@ -278,6 +298,10 @@ class NCCLLibrary:
|
||||
# it is better not to call it at all.
|
||||
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
|
||||
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
|
||||
# ncclResult_t ncclGroupStart();
|
||||
Function("ncclGroupStart", ncclResult_t, []),
|
||||
# ncclResult_t ncclGroupEnd();
|
||||
Function("ncclGroupEnd", ncclResult_t, []),
|
||||
]
|
||||
|
||||
exported_functions_symm_mem = [
|
||||
@@ -400,6 +424,28 @@ class NCCLLibrary:
|
||||
)
|
||||
)
|
||||
|
||||
def ncclReduce(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
recvbuff: buffer_type,
|
||||
count: int,
|
||||
datatype: int,
|
||||
op: int,
|
||||
root: int,
|
||||
comm: ncclComm_t,
|
||||
stream: cudaStream_t,
|
||||
) -> None:
|
||||
# `datatype` actually should be `ncclDataType_t`
|
||||
# and `op` should be `ncclRedOp_t`
|
||||
# both are aliases of `ctypes.c_int`
|
||||
# when we pass int to a function, it will be converted to `ctypes.c_int`
|
||||
# by ctypes automatically
|
||||
self.NCCL_CHECK(
|
||||
self._funcs["ncclReduce"](
|
||||
sendbuff, recvbuff, count, datatype, op, root, comm, stream
|
||||
)
|
||||
)
|
||||
|
||||
def ncclReduceScatter(
|
||||
self,
|
||||
sendbuff: buffer_type,
|
||||
@@ -499,6 +545,12 @@ class NCCLLibrary:
|
||||
def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
|
||||
|
||||
def ncclGroupStart(self) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
|
||||
|
||||
def ncclGroupEnd(self) -> None:
|
||||
self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
|
||||
|
||||
|
||||
__all__ = [
|
||||
"NCCLLibrary",
|
||||
|
||||
@@ -583,6 +583,39 @@ class GroupCoordinator:
|
||||
torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
|
||||
return output
|
||||
|
||||
def reduce_scatterv(
|
||||
self,
|
||||
input_: torch.Tensor,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
sizes: Optional[List[int]] = None,
|
||||
) -> torch.Tensor:
|
||||
world_size = self.world_size
|
||||
pynccl_comm = self.pynccl_comm
|
||||
|
||||
with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
|
||||
assert (
|
||||
pynccl_comm is not None and not pynccl_comm.disabled
|
||||
), "pynccl is required for reduce_scatterv"
|
||||
|
||||
if sizes is not None:
|
||||
assert len(sizes) == world_size
|
||||
assert input_.shape[0] == sum(sizes)
|
||||
chunk_size = sizes[self.rank_in_group]
|
||||
else:
|
||||
assert input_.shape[0] % world_size == 0
|
||||
chunk_size = input_.shape[0] // world_size
|
||||
output_shape = (chunk_size,) + input_.shape[1:]
|
||||
|
||||
if output is None:
|
||||
output = torch.empty(
|
||||
output_shape, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
else:
|
||||
assert output.shape == output_shape
|
||||
|
||||
pynccl_comm.reduce_scatter(output, input_, sizes=sizes)
|
||||
return output
|
||||
|
||||
def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
|
||||
pynccl_comm = self.pynccl_comm
|
||||
if pynccl_comm is not None and not pynccl_comm.disabled:
|
||||
@@ -673,6 +706,54 @@ class GroupCoordinator:
|
||||
)
|
||||
return output_tensor
|
||||
|
||||
def all_gatherv(
|
||||
self,
|
||||
input_: Union[torch.Tensor, List[torch.Tensor]],
|
||||
sizes: Optional[List[int]] = None,
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
"""
|
||||
Supports varying sizes per rank and input tensor list.
|
||||
`sizes`: a list of len(world_size) with the number of items per rank to gather.
|
||||
"""
|
||||
world_size = self.world_size
|
||||
pynccl_comm = self.pynccl_comm
|
||||
|
||||
with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
|
||||
assert (
|
||||
pynccl_comm is not None and not pynccl_comm.disabled
|
||||
), "pynccl is required for all_gatherv"
|
||||
|
||||
def _all_gather_single(
|
||||
input_: torch.Tensor, sizes: Optional[List[int]] = None
|
||||
):
|
||||
input_size = input_.size()
|
||||
if sizes is not None:
|
||||
assert len(sizes) == world_size
|
||||
assert input_.shape[0] == sizes[self.rank_in_group]
|
||||
output_size = (sum(sizes),) + input_size[1:]
|
||||
# 'sizes' is not needed if all inputs in the same group have the same shape
|
||||
if all(s == sizes[0] for s in sizes):
|
||||
sizes = None
|
||||
else:
|
||||
output_size = (input_size[0] * world_size,) + input_size[1:]
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty(
|
||||
output_size, dtype=input_.dtype, device=input_.device
|
||||
)
|
||||
pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
|
||||
return output_tensor
|
||||
|
||||
if isinstance(input_, torch.Tensor):
|
||||
return _all_gather_single(input_, sizes)
|
||||
|
||||
output_list = []
|
||||
pynccl_comm.group_start()
|
||||
for inp in input_:
|
||||
output_list.append(_all_gather_single(inp, sizes=sizes))
|
||||
pynccl_comm.group_end()
|
||||
|
||||
return output_list
|
||||
|
||||
def gather(
|
||||
self, input_: torch.Tensor, dst: int = 0, dim: int = -1
|
||||
) -> Optional[torch.Tensor]:
|
||||
|
||||
Reference in New Issue
Block a user