fix cuda to npu

This commit is contained in:
2025-09-03 11:56:13 +08:00
parent 66156a20fa
commit 0010e9586b

View File

@@ -7,6 +7,7 @@ import time
from pathlib import Path
import numpy as np
from sentence_transformers import SentenceTransformer
import torch
def parse_args():
p = argparse.ArgumentParser(
@@ -24,11 +25,15 @@ def parse_args():
def auto_device(user_device: str | None) -> str:
if user_device:
if user_device == "cuda" and not torch.cuda.is_available():
if torch.npu.is_available():
return "npu"
return user_device
try:
import torch
if torch.cuda.is_available():
return "cuda"
if torch.npu.is_available():
return "npu"
except Exception:
pass
return "cpu"