[improve] made timeout configurable (#3803)
@@ -30,6 +30,7 @@ import weakref
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
+from datetime import timedelta
 from multiprocessing import shared_memory
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch
@@ -960,6 +961,7 @@ def init_distributed_environment(
     distributed_init_method: str = "env://",
     local_rank: int = -1,
     backend: str = "nccl",
+    timeout: Optional[int] = None,
 ):
     logger.debug(
         "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s",
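With this change, a caller can opt into a custom process-group timeout by passing a number of seconds; omitting the argument keeps torch's default. A minimal usage sketch (the call site and argument values below are illustrative, not part of the commit):

    # Hypothetical caller: raise the collective-op timeout to one hour.
    init_distributed_environment(
        world_size=4,
        rank=0,
        distributed_init_method="env://",
        local_rank=0,
        backend="nccl",
        timeout=3600,  # seconds; converted to a timedelta internally
    )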
@@ -974,13 +976,20 @@ def init_distributed_environment(
             "distributed_init_method must be provided when initializing "
             "distributed environment"
         )
+        if timeout is not None:
+            assert isinstance(timeout, int), "timeout must be an integer"
+            assert timeout > 0, "timeout must be positive"
+            timeout = timedelta(seconds=timeout)
+
         # this backend is used for WORLD
         torch.distributed.init_process_group(
             backend=backend,
             init_method=distributed_init_method,
             world_size=world_size,
             rank=rank,
+            timeout=timeout,
         )
+
         # set the local rank
         # local_rank is not available in torch ProcessGroup,
         # see https://github.com/pytorch/pytorch/issues/122816
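The wrapping into timedelta matters because torch.distributed.init_process_group expects its timeout as a datetime.timedelta, not a raw number. A self-contained sketch of the same validate-and-convert pattern the hunk adds (the helper name normalize_timeout is hypothetical, not from the commit):

    from datetime import timedelta
    from typing import Optional

    def normalize_timeout(timeout: Optional[int]) -> Optional[timedelta]:
        # Accept positive integer seconds, or None to keep torch's
        # default timeout; mirrors the checks added in this commit.
        if timeout is None:
            return None
        assert isinstance(timeout, int), "timeout must be an integer"
        assert timeout > 0, "timeout must be positive"
        return timedelta(seconds=timeout)

    print(normalize_timeout(1800))  # 0:30:00
    print(normalize_timeout(None))  # None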