diff --git a/python/sglang/srt/model_parallel.py b/python/sglang/srt/model_parallel.py index 6817bce02..778347b8e 100644 --- a/python/sglang/srt/model_parallel.py +++ b/python/sglang/srt/model_parallel.py @@ -54,11 +54,7 @@ class RowwiseParallelMaybeWait(RowwiseParallel): )._prepare_output_fn( output_layouts, use_local_output, mod, outputs, device_mesh ) - # wait for the output to be ready - if isinstance(outputs, AsyncCollectiveTensor): - return outputs.wait() - else: - return outputs + return torch.distributed._functional_collectives.wait_tensor(outputs) def tensor_parallel(