fix cuda to npu

This commit is contained in:
2025-09-03 11:56:13 +08:00
parent 66156a20fa
commit 0010e9586b

View File

@@ -7,6 +7,7 @@ import time
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
import torch
def parse_args(): def parse_args():
p = argparse.ArgumentParser( p = argparse.ArgumentParser(
@@ -24,11 +25,15 @@ def parse_args():
def auto_device(user_device: str | None) -> str: def auto_device(user_device: str | None) -> str:
if user_device: if user_device:
if user_device == "cuda" and not torch.cuda.is_available():
if torch.npu.is_available():
return "npu"
return user_device return user_device
try: try:
import torch
if torch.cuda.is_available(): if torch.cuda.is_available():
return "cuda" return "cuda"
if torch.npu.is_available():
return "npu"
except Exception: except Exception:
pass pass
return "cpu" return "cpu"