From 1a820e38a2fcc6d0e0324605bb39baec23d81f8d Mon Sep 17 00:00:00 2001 From: Chaitanya Sri Krishna Lolla Date: Mon, 20 Jan 2025 10:30:35 +0530 Subject: [PATCH] Remove dependency of pynvml on ROCm (#2995) --- .../device_communicators/custom_all_reduce.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index 28aa9d481..d4506b9f0 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -6,7 +6,6 @@ from contextlib import contextmanager from functools import wraps from typing import Callable, List, Optional, TypeVar, Union -import pynvml import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -20,6 +19,14 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import from sglang.srt.distributed.parallel_state import in_the_same_node_as from sglang.srt.utils import cuda_device_count_stateless, is_cuda +logger = logging.getLogger(__name__) + +if is_cuda(): + try: + import pynvml + except ImportError as e: + logger.warning("Failed to import pynvml with %r", e) + try: if ops.use_vllm_custom_allreduce: ops.meta_size()