fix cuda to npu
This commit is contained in:
7
main.py
7
main.py
@@ -7,6 +7,7 @@ import time
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import torch
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser(
|
||||
@@ -24,11 +25,15 @@ def parse_args():
|
||||
|
||||
def auto_device(user_device: str | None) -> str:
|
||||
if user_device:
|
||||
if user_device == "cuda" and not torch.cuda.is_available():
|
||||
if torch.npu.is_available():
|
||||
return "npu"
|
||||
return user_device
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
if torch.npu.is_available():
|
||||
return "npu"
|
||||
except Exception:
|
||||
pass
|
||||
return "cpu"
|
||||
|
||||
Reference in New Issue
Block a user