diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
index 28aa9d481..d4506b9f0 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -6,7 +6,6 @@ from contextlib import contextmanager
 from functools import wraps
 from typing import Callable, List, Optional, TypeVar, Union
 
-import pynvml
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -20,6 +19,14 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
 from sglang.srt.utils import cuda_device_count_stateless, is_cuda
 
+logger = logging.getLogger(__name__)
+
+if is_cuda():
+    try:
+        import pynvml
+    except ImportError as e:
+        logger.warning("Failed to import pynvml with %r", e)
+
 try:
     if ops.use_vllm_custom_allreduce:
         ops.meta_size()