Remove dependency of pynvml on ROCm (#2995)

This commit is contained in:
Chaitanya Sri Krishna Lolla
2025-01-20 10:30:35 +05:30
committed by GitHub
parent 0ffcfdf474
commit 1a820e38a2

View File

@@ -6,7 +6,6 @@ from contextlib import contextmanager
from functools import wraps from functools import wraps
from typing import Callable, List, Optional, TypeVar, Union from typing import Callable, List, Optional, TypeVar, Union
import pynvml
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
@@ -20,6 +19,14 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
from sglang.srt.distributed.parallel_state import in_the_same_node_as from sglang.srt.distributed.parallel_state import in_the_same_node_as
from sglang.srt.utils import cuda_device_count_stateless, is_cuda from sglang.srt.utils import cuda_device_count_stateless, is_cuda
logger = logging.getLogger(__name__)
if is_cuda():
try:
import pynvml
except ImportError as e:
logger.warning("Failed to import pynvml with %r", e)
try: try:
if ops.use_vllm_custom_allreduce: if ops.use_vllm_custom_allreduce:
ops.meta_size() ops.meta_size()