Remove dependency of pynvml on ROCm (#2995)
This commit is contained in:
committed by
GitHub
parent
0ffcfdf474
commit
1a820e38a2
@@ -6,7 +6,6 @@ from contextlib import contextmanager
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Callable, List, Optional, TypeVar, Union
|
from typing import Callable, List, Optional, TypeVar, Union
|
||||||
|
|
||||||
import pynvml
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
@@ -20,6 +19,14 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
|
|||||||
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
||||||
from sglang.srt.utils import cuda_device_count_stateless, is_cuda
|
from sglang.srt.utils import cuda_device_count_stateless, is_cuda
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if is_cuda():
|
||||||
|
try:
|
||||||
|
import pynvml
|
||||||
|
except ImportError as e:
|
||||||
|
logger.warning("Failed to import pynvml with %r", e)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if ops.use_vllm_custom_allreduce:
|
if ops.use_vllm_custom_allreduce:
|
||||||
ops.meta_size()
|
ops.meta_size()
|
||||||
|
|||||||
Reference in New Issue
Block a user