Remove dependency of pynvml on ROCm (#2995)

This commit is contained in:
Chaitanya Sri Krishna Lolla
2025-01-20 10:30:35 +05:30
committed by GitHub
parent 0ffcfdf474
commit 1a820e38a2

View File

@@ -6,7 +6,6 @@ from contextlib import contextmanager
from functools import wraps
from typing import Callable, List, Optional, TypeVar, Union
import pynvml
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
@@ -20,6 +19,14 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.utils import cuda_device_count_stateless, is_cuda
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Attempt the pynvml import only when is_cuda() is true; on any other
# platform the import is skipped entirely.
if is_cuda():
    try:
        import pynvml
    except ImportError as import_err:
        # Best-effort: a missing pynvml is logged as a warning instead of
        # aborting the import of this module.
        logger.warning("Failed to import pynvml with %r", import_err)
try:
if ops.use_vllm_custom_allreduce:
ops.meta_size()