Add graph runner support with torch compile on CPU (#7843)
This commit is contained in:
@@ -64,6 +64,9 @@ class GraphCaptureContext:
|
||||
|
||||
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
|
||||
|
||||
# Store ReduceOp.SUM as a plain int instead of the enum object, to support torch.compile.
|
||||
REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
|
||||
|
||||
|
||||
def _split_tensor_dict(
|
||||
tensor_dict: Dict[str, Union[torch.Tensor, Any]]
|
||||
@@ -489,9 +492,7 @@ class GroupCoordinator:
|
||||
|
||||
if input_.is_cpu:
|
||||
if is_shm_available(input_.dtype, self.world_size, self.local_size):
|
||||
torch.ops.sgl_kernel.shm_allreduce(
|
||||
input_, torch.distributed.ReduceOp.SUM
|
||||
)
|
||||
torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)
|
||||
else:
|
||||
torch.distributed.all_reduce(input_, group=self.device_group)
|
||||
return input_
|
||||
|
||||
Reference in New Issue
Block a user