[improve] made timeout configurable (#3803)

This commit is contained in:
Shenggui Li
2025-02-25 16:26:08 +08:00
committed by GitHub
parent 7036d6fc67
commit c0bb9eb3b3
5 changed files with 26 additions and 1 deletions

View File

@@ -81,3 +81,9 @@ Overall, with these optimizations, we have achieved up to a 7x acceleration in o
- **Weight**: Per-128x128-block quantization for better numerical stability. - **Weight**: Per-128x128-block quantization for better numerical stability.
**Usage**: turn on by default for DeepSeek V3 models. **Usage**: turn on by default for DeepSeek V3 models.
## FAQ
**Question**: What should I do if model loading takes too long and NCCL timeout occurs?
Answer: You can try to add `--dist-timeout 3600` when launching the model, this allows for 1-hour timeout.i

View File

@@ -30,6 +30,7 @@ import weakref
from collections import namedtuple from collections import namedtuple
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from datetime import timedelta
from multiprocessing import shared_memory from multiprocessing import shared_memory
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from unittest.mock import patch from unittest.mock import patch
@@ -960,6 +961,7 @@ def init_distributed_environment(
distributed_init_method: str = "env://", distributed_init_method: str = "env://",
local_rank: int = -1, local_rank: int = -1,
backend: str = "nccl", backend: str = "nccl",
timeout: Optional[int] = None,
): ):
logger.debug( logger.debug(
"world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s",
@@ -974,13 +976,20 @@ def init_distributed_environment(
"distributed_init_method must be provided when initializing " "distributed_init_method must be provided when initializing "
"distributed environment" "distributed environment"
) )
if timeout is not None:
assert isinstance(timeout, (int)), "timeout must be a number"
assert timeout > 0, "timeout must be positive"
timeout = timedelta(seconds=timeout)
# this backend is used for WORLD # this backend is used for WORLD
torch.distributed.init_process_group( torch.distributed.init_process_group(
backend=backend, backend=backend,
init_method=distributed_init_method, init_method=distributed_init_method,
world_size=world_size, world_size=world_size,
rank=rank, rank=rank,
timeout=timeout,
) )
# set the local rank # set the local rank
# local_rank is not available in torch ProcessGroup, # local_rank is not available in torch ProcessGroup,
# see https://github.com/pytorch/pytorch/issues/122816 # see https://github.com/pytorch/pytorch/issues/122816

View File

@@ -259,6 +259,7 @@ class ModelRunner:
rank=self.tp_rank, rank=self.tp_rank,
local_rank=self.gpu_id, local_rank=self.gpu_id,
distributed_init_method=dist_init_method, distributed_init_method=dist_init_method,
timeout=self.server_args.dist_timeout,
) )
initialize_model_parallel(tensor_model_parallel_size=self.tp_size) initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
initialize_dp_attention( initialize_dp_attention(

View File

@@ -79,6 +79,7 @@ class ServerArgs:
random_seed: Optional[int] = None random_seed: Optional[int] = None
constrained_json_whitespace_pattern: Optional[str] = None constrained_json_whitespace_pattern: Optional[str] = None
watchdog_timeout: float = 300 watchdog_timeout: float = 300
dist_timeout: Optional[int] = None # timeout for torch.distributed
download_dir: Optional[str] = None download_dir: Optional[str] = None
base_gpu_id: int = 0 base_gpu_id: int = 0
@@ -534,6 +535,12 @@ class ServerArgs:
default=ServerArgs.watchdog_timeout, default=ServerArgs.watchdog_timeout,
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.", help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
) )
parser.add_argument(
"--dist-timeout",
type=int,
default=ServerArgs.dist_timeout,
help="Set timeout for torch.distributed initialization.",
)
parser.add_argument( parser.add_argument(
"--download-dir", "--download-dir",
type=str, type=str,

View File

@@ -503,7 +503,9 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
ret_code = run_with_timeout( ret_code = run_with_timeout(
run_one_file, args=(filename,), timeout=timeout_per_file run_one_file, args=(filename,), timeout=timeout_per_file
) )
assert ret_code == 0 assert (
ret_code == 0
), f"expected return code 0, but {filename} returned {ret_code}"
except TimeoutError: except TimeoutError:
kill_process_tree(process.pid) kill_process_tree(process.pid)
time.sleep(5) time.sleep(5)