diff --git a/main.py b/main.py index b0cc860..a5b06c9 100644 --- a/main.py +++ b/main.py @@ -7,6 +7,7 @@ import time from pathlib import Path import numpy as np from sentence_transformers import SentenceTransformer +import torch def parse_args(): p = argparse.ArgumentParser( @@ -24,11 +25,15 @@ def parse_args(): def auto_device(user_device: str | None) -> str: if user_device: + if user_device == "cuda" and not torch.cuda.is_available(): + if torch.npu.is_available(): + return "npu" return user_device try: - import torch if torch.cuda.is_available(): return "cuda" + if torch.npu.is_available(): + return "npu" except Exception: pass return "cpu"