Make torch TP composable with torch.compile (#2352)
This commit is contained in:
@@ -54,11 +54,7 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
|
||||
)._prepare_output_fn(
|
||||
output_layouts, use_local_output, mod, outputs, device_mesh
|
||||
)
|
||||
# wait for the output to be ready
|
||||
if isinstance(outputs, AsyncCollectiveTensor):
|
||||
return outputs.wait()
|
||||
else:
|
||||
return outputs
|
||||
return torch.distributed._functional_collectives.wait_tensor(outputs)
|
||||
|
||||
|
||||
def tensor_parallel(
|
||||
|
||||
Reference in New Issue
Block a user