[fix] put cpu in the first priority in get_device() (#7752)
This commit is contained in:
@@ -1443,6 +1443,15 @@ def is_habana_available() -> bool:
|
|||||||
|
|
||||||
@lru_cache(maxsize=8)
|
@lru_cache(maxsize=8)
|
||||||
def get_device(device_id: Optional[int] = None) -> str:
|
def get_device(device_id: Optional[int] = None) -> str:
|
||||||
|
if is_cpu():
|
||||||
|
if cpu_has_amx_support():
|
||||||
|
logger.info("Intel AMX is detected, using CPU with Intel AMX support.")
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"CPU device enabled, using torch native backend, low performance expected."
|
||||||
|
)
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
if hasattr(torch, "cuda") and torch.cuda.is_available():
|
||||||
if device_id is None:
|
if device_id is None:
|
||||||
return "cuda"
|
return "cuda"
|
||||||
@@ -1471,15 +1480,6 @@ def get_device(device_id: Optional[int] = None) -> str:
|
|||||||
"Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
|
"Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_cpu():
|
|
||||||
if cpu_has_amx_support():
|
|
||||||
logger.info("Intel AMX is detected, using CPU with Intel AMX support.")
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"CPU device enabled, using torch native backend, low performance expected."
|
|
||||||
)
|
|
||||||
return "cpu"
|
|
||||||
|
|
||||||
raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")
|
raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user