Make torch TP composable with torch.compile (#2352)
This commit is contained in:
@@ -54,11 +54,7 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
|
|||||||
)._prepare_output_fn(
|
)._prepare_output_fn(
|
||||||
output_layouts, use_local_output, mod, outputs, device_mesh
|
output_layouts, use_local_output, mod, outputs, device_mesh
|
||||||
)
|
)
|
||||||
# wait for the output to be ready
|
return torch.distributed._functional_collectives.wait_tensor(outputs)
|
||||||
if isinstance(outputs, AsyncCollectiveTensor):
|
|
||||||
return outputs.wait()
|
|
||||||
else:
|
|
||||||
return outputs
|
|
||||||
|
|
||||||
|
|
||||||
def tensor_parallel(
|
def tensor_parallel(
|
||||||
|
|||||||
Reference in New Issue
Block a user