[improve] made timeout configurable (#3803)

This commit is contained in:
Shenggui Li
2025-02-25 16:26:08 +08:00
committed by GitHub
parent 7036d6fc67
commit c0bb9eb3b3
5 changed files with 26 additions and 1 deletion

View File

@@ -30,6 +30,7 @@ import weakref
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch
@@ -960,6 +961,7 @@ def init_distributed_environment(
distributed_init_method: str = "env://",
local_rank: int = -1,
backend: str = "nccl",
timeout: Optional[int] = None,
):
logger.debug(
"world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s",
@@ -974,13 +976,20 @@ def init_distributed_environment(
"distributed_init_method must be provided when initializing "
"distributed environment"
)
if timeout is not None:
assert isinstance(timeout, (int)), "timeout must be a number"
assert timeout > 0, "timeout must be positive"
timeout = timedelta(seconds=timeout)
# this backend is used for WORLD
torch.distributed.init_process_group(
backend=backend,
init_method=distributed_init_method,
world_size=world_size,
rank=rank,
timeout=timeout,
)
# set the local rank
# local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816