Make torch TP composable with torch.compile (#2352)

This commit is contained in:
Ke Wen
2024-12-04 17:26:00 -08:00
committed by GitHub
parent 18ea841f40
commit d693ec0427

View File

@@ -54,11 +54,7 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
)._prepare_output_fn( )._prepare_output_fn(
output_layouts, use_local_output, mod, outputs, device_mesh output_layouts, use_local_output, mod, outputs, device_mesh
) )
# wait for the output to be ready return torch.distributed._functional_collectives.wait_tensor(outputs)
if isinstance(outputs, AsyncCollectiveTensor):
return outputs.wait()
else:
return outputs
def tensor_parallel( def tensor_parallel(