diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
index 813ae0122..8d81e47a1 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -18,7 +18,7 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
     gpu_p2p_access_check,
 )
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
-from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip
 
 logger = logging.getLogger(__name__)
 
@@ -217,7 +217,7 @@ class CustomAllreduce:
             if cuda_visible_devices:
                 device_ids = list(map(int, cuda_visible_devices.split(",")))
             else:
-                device_ids = list(range(cuda_device_count_stateless()))
+                device_ids = list(range(torch.cuda.device_count()))
 
             physical_device_id = device_ids[device.index]
             tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu")
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py
index 4073491aa..86121ac97 100644
--- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py
+++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py
@@ -11,11 +11,11 @@ import tempfile
 from itertools import product
 from typing import Dict, List, Optional, Sequence
 
+import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 
 from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
-from sglang.srt.utils import cuda_device_count_stateless
 
 logger = logging.getLogger(__name__)
 
@@ -218,7 +218,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
 
     is_distributed = dist.is_initialized()
 
-    num_dev = cuda_device_count_stateless()
+    num_dev = torch.cuda.device_count()
     cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
     if cuda_visible_devices is None:
         cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index ba229b1ce..b43fe4273 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -263,7 +263,7 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
     When distributed is True, the available memory is the minimum available memory of all GPUs.
     """
     if device == "cuda":
-        num_gpus = cuda_device_count_stateless()
+        num_gpus = torch.cuda.device_count()
         assert gpu_id < num_gpus
 
         if torch.cuda.current_device() != gpu_id:
@@ -1416,47 +1416,6 @@ def disable_request_logging() -> bool:
     return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
 
 
-@lru_cache(maxsize=8)
-def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
-    # Note: cuda_visible_devices is not used, but we keep it as an argument for
-    # LRU Cache purposes.
-
-    # Code below is based on
-    # https://github.com/pytorch/pytorch/blob/
-    # c1cd946818442aca8c7f812b16d187ce1586c3bc/
-    # torch/cuda/__init__.py#L831C1-L831C17
-    import torch.version
-
-    if not torch.cuda._is_compiled():
-        return 0
-    if is_hip():
-        # ROCm uses amdsmi instead of nvml for stateless device count
-        # This requires a sufficiently modern version of Torch 2.4.0
-        raw_count = (
-            torch.cuda._device_count_amdsmi()
-            if (hasattr(torch.cuda, "_device_count_amdsmi"))
-            else -1
-        )
-    else:
-        raw_count = torch.cuda._device_count_nvml()
-    r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count
-    return r
-
-
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/utils.py
-def cuda_device_count_stateless() -> int:
-    """Get number of CUDA devices, caching based on the value of
-    CUDA_VISIBLE_DEVICES at the time of call.
-
-    This should be used instead of torch.cuda.device_count()
-    unless CUDA_VISIBLE_DEVICES has already been set to the desired
-    value."""
-
-    # This can be removed and simply replaced with torch.cuda.get_device_count
-    # after https://github.com/pytorch/pytorch/pull/122815 is released.
-    return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
-
-
 def dataclass_to_string_truncated(
     data, max_length=2048, skip_names: Optional[Set[str]] = None
 ):
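Note: this diff removes the vendored cuda_device_count_stateless() helper from sglang.srt.utils (which, per its removed docstring, cached the device count keyed on the CUDA_VISIBLE_DEVICES value at call time) and has every former call site use torch.cuda.device_count() directly. A minimal sketch of the call-site change, assuming CUDA_VISIBLE_DEVICES is already set to its final value before the first call (the case the removed docstring described as safe for torch.cuda.device_count()):

    # Before: helper cached the count per CUDA_VISIBLE_DEVICES at call time.
    # from sglang.srt.utils import cuda_device_count_stateless
    # num_dev = cuda_device_count_stateless()

    # After: plain torch API, as used throughout this diff.
    import torch

    num_dev = torch.cuda.device_count()
    device_ids = list(range(num_dev))  # e.g. mapping ranks to physical devices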