Remove dependency of pynvml on ROCm (#2995)
This commit is contained in:
committed by
GitHub
parent
0ffcfdf474
commit
1a820e38a2
@@ -6,7 +6,6 @@ from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Callable, List, Optional, TypeVar, Union
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ProcessGroup
|
||||
@@ -20,6 +19,14 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
|
||||
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
||||
from sglang.srt.utils import cuda_device_count_stateless, is_cuda
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if is_cuda():
|
||||
try:
|
||||
import pynvml
|
||||
except ImportError as e:
|
||||
logger.warning("Failed to import pynvml with %r", e)
|
||||
|
||||
try:
|
||||
if ops.use_vllm_custom_allreduce:
|
||||
ops.meta_size()
|
||||
|
||||
Reference in New Issue
Block a user