# Files
# 2026-04-24 09:58:03 +08:00
#
# 115 lines
# 3.8 KiB
# Python
import logging
import threading
from typing import Optional
import ray
import ray._private.ray_constants as ray_constants
from ray.air._internal.device_manager.cpu import CPUTorchDeviceManager
from ray.air._internal.device_manager.hpu import HPUTorchDeviceManager
from ray.air._internal.device_manager.npu import NPUTorchDeviceManager
from ray.air._internal.device_manager.mlu import MLUTorchDeviceManager
from ray.air._internal.device_manager.nvidia_gpu import CUDATorchDeviceManager
from ray.air._internal.device_manager.torch_device_manager import TorchDeviceManager
logger = logging.getLogger(__name__)
# Fallback manager class used when no supported accelerator resource is assigned.
DEFAULT_TORCH_DEVICE_MANAGER_CLS = CPUTorchDeviceManager
'''
=============================
Modify by vllm_mlu
=============================
@brief: use MLUTorchDeviceManager when key="GPU"
'''
# Maps a Ray accelerator resource name to the TorchDeviceManager class that
# drives it.
# NOTE(vllm_mlu hijack): upstream binds ray_constants.GPU to
# CUDATorchDeviceManager; this patched copy substitutes MLUTorchDeviceManager
# so "GPU" resources are handled as MLU devices instead.
SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER = {
    ray_constants.GPU: MLUTorchDeviceManager,
    ray_constants.HPU: HPUTorchDeviceManager,
    ray_constants.NPU: NPUTorchDeviceManager,
}
'''
==================
End of MLU Hijack
==================
'''
def register_custom_torch_dist_backend(backend: Optional[str] = None) -> None:
    """Register vendor-specific torch.distributed backends for ``backend``.

    Only the "hccl" backend needs registration here; any other value is a
    no-op.
    """
    if backend != "hccl":
        return
    # Habana (HPU) and torch-npu name their communication backend identically
    # ("hccl"), so both managers register for it.
    HPUTorchDeviceManager.register_custom_torch_dist_backend()
    NPUTorchDeviceManager.register_custom_torch_dist_backend()
# Process-wide singleton device manager, created lazily on first use and
# guarded by a lock so concurrent first calls race safely.
_torch_device_manager = None
_torch_device_manager_lock = threading.Lock()


def get_torch_device_manager_by_context() -> TorchDeviceManager:
    """Return the singleton TorchDeviceManager for this worker's resources.

    Inspects the accelerator resources assigned to the current Ray runtime
    context and lazily instantiates the matching manager class, falling back
    to ``DEFAULT_TORCH_DEVICE_MANAGER_CLS`` (CPU) when no supported
    accelerator resource is present.

    Returns:
        The cached TorchDeviceManager instance.

    Raises:
        RuntimeError: If resources for more than one supported accelerator
            type are assigned, since a single manager cannot be chosen.
    """
    global _torch_device_manager
    with _torch_device_manager_lock:
        # Use an explicit `is None` check: the sentinel is None, and we must
        # not depend on the truthiness of an arbitrary manager instance.
        if _torch_device_manager is None:
            existing_device_manager_cls = None
            resources = ray.get_runtime_context().get_accelerator_ids()
            # Select the correct accelerator type from the assigned resources.
            for resource_type, resource_value in resources.items():
                device_manager_cls = SUPPORTED_ACCELERATOR_TORCH_DEVICE_MANAGER.get(
                    resource_type, None
                )
                if resource_value and device_manager_cls:
                    # Raise when multiple accelerator types are specified.
                    if existing_device_manager_cls:
                        raise RuntimeError(
                            "Unable to determine the appropriate DeviceManager "
                            f"for the specified resources {resources}."
                        )
                    existing_device_manager_cls = device_manager_cls
            device_manager_cls = (
                existing_device_manager_cls or DEFAULT_TORCH_DEVICE_MANAGER_CLS
            )
            _torch_device_manager = device_manager_cls()
    return _torch_device_manager
def get_torch_device_manager_by_device_type(device_type: str):
    """Return a fresh TorchDeviceManager for an explicit device-type string.

    NOTE(vllm_mlu hijack): upstream returns ``CUDATorchDeviceManager`` for
    "GPU"/"cuda"; this patched copy returns ``MLUTorchDeviceManager`` instead.

    Raises:
        RuntimeError: If ``device_type`` is not a recognized device type.
    """
    normalized = device_type.lower()
    # "cuda" is matched case-sensitively, mirroring the original dispatch.
    if normalized == ray_constants.GPU.lower() or device_type == "cuda":
        return MLUTorchDeviceManager()
    if normalized == ray_constants.NPU.lower():
        return NPUTorchDeviceManager()
    if normalized == ray_constants.HPU.lower():
        return HPUTorchDeviceManager()
    if normalized == "cpu":
        return CPUTorchDeviceManager()
    raise RuntimeError(f"Device type {device_type} cannot be recognized.")
# Public API of this module.  BUG FIX: `__all__` must contain *strings* naming
# the exported symbols; the original listed the objects themselves, which makes
# `from <module> import *` raise a TypeError ("Item in __all__ must be str").
__all__ = [
    "TorchDeviceManager",
    "CPUTorchDeviceManager",
    "CUDATorchDeviceManager",
    "HPUTorchDeviceManager",
    "NPUTorchDeviceManager",
    "MLUTorchDeviceManager",
    "register_custom_torch_dist_backend",
    "get_torch_device_manager_by_context",
    "get_torch_device_manager_by_device_type",
]