diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index acc65a0f7..ee4ac8685 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -37,6 +37,7 @@ import time
 import warnings
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
+from importlib.util import find_spec
 from io import BytesIO
 from multiprocessing import Pool
 from multiprocessing.reduction import ForkingPickler
@@ -1051,6 +1052,65 @@ def get_device_name(device_id: int = 0) -> str:
         return torch.hpu.get_device_name(device_id)
 
 
+@lru_cache(maxsize=1)
+def is_habana_available() -> bool:
+    return find_spec("habana_frameworks") is not None
+
+
+@lru_cache(maxsize=8)
+def get_device(device_id: Optional[int] = None) -> str:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        if device_id is None:
+            return "cuda"
+        return "cuda:{}".format(device_id)
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        if device_id is None:
+            return "xpu"
+        return "xpu:{}".format(device_id)
+
+    if is_habana_available():
+        try:
+            import habana_frameworks.torch.hpu
+
+            if torch.hpu.is_available():
+                if device_id is None:
+                    return "hpu"
+                return "hpu:{}".format(device_id)
+        except ImportError as e:
+            raise ImportError(
+                "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
+            ) from e
+
+    raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")
+
+
+@lru_cache(maxsize=1)
+def get_device_count() -> int:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        try:
+            return torch.cuda.device_count()
+        except RuntimeError:
+            return 0
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        try:
+            return torch.xpu.device_count()
+        except RuntimeError:
+            return 0
+
+    if is_habana_available():
+        try:
+            import habana_frameworks.torch.hpu
+
+            if torch.hpu.is_available():
+                return torch.hpu.device_count()
+        except (ImportError, RuntimeError):
+            return 0
+
+    return 0  # No accelerators available
+
+
 def get_device_core_count(device_id: int = 0) -> int:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_properties(device_id).multi_processor_count
@@ -1069,11 +1129,12 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
     )
     major, minor = int(major), int(minor)
 
-    # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
-    # Update this once the support is available.
     if hasattr(torch, "hpu") and torch.hpu.is_available():
         try:
-            major, minor = torch.hpu.get_device_capability(device_id)
+            # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
+            # Update this once the support is available.
+            # major, minor = torch.hpu.get_device_capability(device_id)
+            major, minor = None, None
         except Exception as e:
             raise RuntimeError(
                 f"An error occurred while getting device capability of hpu: {e}."
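
For reference, a minimal usage sketch of the helpers this diff adds (`is_habana_available`, `get_device`, `get_device_count`). The output depends on which accelerator stack is installed locally, and the CPU fallback shown here is illustrative, not part of the patch:

```python
from sglang.srt.utils import get_device, get_device_count, is_habana_available

print(f"Habana frameworks installed: {is_habana_available()}")
print(f"Visible accelerator count:   {get_device_count()}")

try:
    # get_device() returns "cuda", "xpu", or "hpu"; an explicit index yields e.g. "cuda:0".
    print(f"Default device: {get_device()}")
    print(f"First device:   {get_device(0)}")
except RuntimeError:
    # Raised by get_device() when no CUDA, XPU, or HPU accelerator is found.
    print("No accelerator available; falling back to CPU.")
```

Note that both `get_device` and `get_device_count` are wrapped in `lru_cache`, so results are computed once per process; devices hot-plugged after the first call will not be reflected.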