Make torch TP composable with torch.compile (#2352)

This commit is contained in:
Ke Wen
2024-12-04 17:26:00 -08:00
committed by GitHub
parent 18ea841f40
commit d693ec0427

View File

@@ -54,11 +54,7 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
)._prepare_output_fn(
output_layouts, use_local_output, mod, outputs, device_mesh
)
# wait for the output to be ready
if isinstance(outputs, AsyncCollectiveTensor):
return outputs.wait()
else:
return outputs
return torch.distributed._functional_collectives.wait_tensor(outputs)
def tensor_parallel(