import logging
import threading
from typing import Optional

import ray
import ray._private.ray_constants as ray_constants
from ray.air._internal.device_manager.cpu import CPUTorchDeviceManager
from ray.air._internal.device_manager.hpu import HPUTorchDeviceManager
from ray.air._internal.device_manager.npu import NPUTorchDeviceManager
from ray.air._internal.device_manager.mlu import MLUTorchDeviceManager
from ray.air._internal.device_manager.nvidia_gpu import CUDATorchDeviceManager
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager

logger = logging.getLogger(__name__)


DEFAULT_TORCH_DEVICE_MANAGER_CLS = CPUTorchDeviceManager

'''
=============================
Modified by vllm_mlu
=============================
@brief: use MLUTorchDeviceManager when key == "GPU"
'''
SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER = {
    ray_constants.GPU: MLUTorchDeviceManager,
    ray_constants.HPU: HPUTorchDeviceManager,
    ray_constants.NPU: NPUTorchDeviceManager,
}
'''
==================
End of MLU Hijack
==================
'''
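
# Illustrative consequence of the hijack above (a sketch, not part of the
# module logic): looking up the "GPU" resource key now yields the MLU manager
# rather than CUDATorchDeviceManager:
#
#   manager_cls = SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER[ray_constants.GPU]
#   assert manager_cls is MLUTorchDeviceManager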


def register_custom_torch_dist_backend(backend: Optional[str] = None) -> None:
    if backend == "hccl":
        # Habana (HPU) and torch-npu use the same name ("hccl") for their
        # communication backend, so register it for both device managers.
        HPUTorchDeviceManager.register_custom_torch_dist_backend()
        NPUTorchDeviceManager.register_custom_torch_dist_backend()
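
# Illustrative call pattern (a sketch, not prescribed by this module): a
# training worker would typically register the backend before creating the
# torch.distributed process group, e.g.:
#
#   register_custom_torch_dist_backend("hccl")
#   torch.distributed.init_process_group(backend="hccl")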
|
|
|
|
# Lazily-initialized, process-wide singleton; the lock ensures concurrent
# callers create at most one device manager.
_torch_device_manager = None
_torch_device_manager_lock = threading.Lock()
|
|
|
|
def get_torch_device_manager_by_context() -> TorchDeviceManager:
    '''Return the torch device manager matching the accelerator resources
    assigned to the current Ray worker, creating it lazily under a lock.'''
    global _torch_device_manager

    with _torch_device_manager_lock:
        if not _torch_device_manager:
            existing_device_manager_cls = None
            resources = ray.get_runtime_context().get_accelerator_ids()

            # Select the correct accelerator type from the assigned resources.
            for resource_type, resource_value in resources.items():
                device_manager_cls = SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER.get(
                    resource_type, None
                )
                if resource_value and device_manager_cls:
                    # Raise an error when more than one accelerator type is
                    # assigned, since the choice would be ambiguous.
                    if existing_device_manager_cls:
                        raise RuntimeError(
                            "Unable to determine the appropriate DeviceManager "
                            f"for the specified resources {resources}."
                        )
                    else:
                        existing_device_manager_cls = device_manager_cls

            device_manager_cls = (
                existing_device_manager_cls or DEFAULT_TORCH_DEVICE_MANAGER_CLS
            )

            _torch_device_manager = device_manager_cls()

    return _torch_device_manager
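
# Illustrative call pattern (a sketch, not part of this module; get_devices()
# is part of the TorchDeviceManager interface): inside a Ray task or actor
# scheduled with an accelerator resource such as @ray.remote(num_gpus=1), the
# context-based lookup returns the matching manager, which under the hijack
# above is an MLUTorchDeviceManager:
#
#   manager = get_torch_device_manager_by_context()
#   devices = manager.get_devices()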
|
|
|
|
def get_torch_device_manager_by_device_type(device_type: str) -> TorchDeviceManager:
    '''
    =============================
    Modified by vllm_mlu
    =============================
    @brief: use MLUTorchDeviceManager when key == "GPU"
    '''
    if device_type.lower() == ray_constants.GPU.lower() or device_type == "cuda":
        return MLUTorchDeviceManager()
    elif device_type.lower() == ray_constants.NPU.lower():
        return NPUTorchDeviceManager()
    elif device_type.lower() == ray_constants.HPU.lower():
        return HPUTorchDeviceManager()
    elif device_type.lower() == "cpu":
        return CPUTorchDeviceManager()
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    raise RuntimeError(f"Device type {device_type} cannot be recognized.")
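
# Illustrative lookups (a sketch, not part of this module): with the hijack in
# place, both "GPU" and "cuda" resolve to the MLU manager, while the remaining
# device strings resolve to their native managers:
#
#   get_torch_device_manager_by_device_type("cuda")  # -> MLUTorchDeviceManager()
#   get_torch_device_manager_by_device_type("cpu")   # -> CPUTorchDeviceManager()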
|
|
|
|
__all__ = [
    "TorchDeviceManager",
    "CPUTorchDeviceManager",
    "CUDATorchDeviceManager",
    "HPUTorchDeviceManager",
    "NPUTorchDeviceManager",
    "MLUTorchDeviceManager",
    "register_custom_torch_dist_backend",
    "get_torch_device_manager_by_context",
    "get_torch_device_manager_by_device_type",
]