[fix] remove cuda_device_count_stateless (#5060)
This commit is contained in:
@@ -18,7 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
|
||||
gpu_p2p_access_check,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import in_the_same_node_as
|
||||
from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -217,7 +217,7 @@ class CustomAllreduce:
|
||||
if cuda_visible_devices:
|
||||
device_ids = list(map(int, cuda_visible_devices.split(",")))
|
||||
else:
|
||||
device_ids = list(range(cuda_device_count_stateless()))
|
||||
device_ids = list(range(torch.cuda.device_count()))
|
||||
|
||||
physical_device_id = device_ids[device.index]
|
||||
tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
|
||||
|
||||
@@ -11,11 +11,11 @@ import tempfile
|
||||
from itertools import product
|
||||
from typing import Dict, List, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
|
||||
from sglang.srt.utils import cuda_device_count_stateless
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -218,7 +218,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
|
||||
|
||||
is_distributed = dist.is_initialized()
|
||||
|
||||
num_dev = cuda_device_count_stateless()
|
||||
num_dev = torch.cuda.device_count()
|
||||
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
|
||||
if cuda_visible_devices is None:
|
||||
cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
|
||||
|
||||
@@ -263,7 +263,7 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
|
||||
When distributed is True, the available memory is the minimum available memory of all GPUs.
|
||||
"""
|
||||
if device == "cuda":
|
||||
num_gpus = cuda_device_count_stateless()
|
||||
num_gpus = torch.cuda.device_count()
|
||||
assert gpu_id < num_gpus
|
||||
|
||||
if torch.cuda.current_device() != gpu_id:
|
||||
@@ -1416,47 +1416,6 @@ def disable_request_logging() -> bool:
|
||||
return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
|
||||
|
||||
|
||||
@lru_cache(maxsize=8)
|
||||
def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
|
||||
# Note: cuda_visible_devices is not used, but we keep it as an argument for
|
||||
# LRU Cache purposes.
|
||||
|
||||
# Code below is based on
|
||||
# https://github.com/pytorch/pytorch/blob/
|
||||
# c1cd946818442aca8c7f812b16d187ce1586c3bc/
|
||||
# torch/cuda/__init__.py#L831C1-L831C17
|
||||
import torch.version
|
||||
|
||||
if not torch.cuda._is_compiled():
|
||||
return 0
|
||||
if is_hip():
|
||||
# ROCm uses amdsmi instead of nvml for stateless device count
|
||||
# This requires a sufficiently modern version of Torch 2.4.0
|
||||
raw_count = (
|
||||
torch.cuda._device_count_amdsmi()
|
||||
if (hasattr(torch.cuda, "_device_count_amdsmi"))
|
||||
else -1
|
||||
)
|
||||
else:
|
||||
raw_count = torch.cuda._device_count_nvml()
|
||||
r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
|
||||
return r
|
||||
|
||||
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/utils.py
|
||||
def cuda_device_count_stateless() -> int:
|
||||
"""Get number of CUDA devices, caching based on the value of
|
||||
CUDA_VISIBLE_DEVICES at the time of call.
|
||||
|
||||
This should be used instead of torch.cuda.device_count()
|
||||
unless CUDA_VISIBLE_DEVICES has already been set to the desired
|
||||
value."""
|
||||
|
||||
# This can be removed and simply replaced with torch.cuda.get_device_count
|
||||
# after https://github.com/pytorch/pytorch/pull/122815 is released.
|
||||
return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
|
||||
|
||||
|
||||
def dataclass_to_string_truncated(
|
||||
data, max_length=2048, skip_names: Optional[Set[str]] = None
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user