fix cuda to npu
This commit is contained in:
7
main.py
7
main.py
@@ -7,6 +7,7 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
|
import torch
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
p = argparse.ArgumentParser(
|
p = argparse.ArgumentParser(
|
||||||
@@ -24,11 +25,15 @@ def parse_args():
|
|||||||
|
|
||||||
def auto_device(user_device: str | None) -> str:
|
def auto_device(user_device: str | None) -> str:
|
||||||
if user_device:
|
if user_device:
|
||||||
|
if user_device == "cuda" and not torch.cuda.is_available():
|
||||||
|
if torch.npu.is_available():
|
||||||
|
return "npu"
|
||||||
return user_device
|
return user_device
|
||||||
try:
|
try:
|
||||||
import torch
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
return "cuda"
|
return "cuda"
|
||||||
|
if torch.npu.is_available():
|
||||||
|
return "npu"
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
return "cpu"
|
return "cpu"
|
||||||
|
|||||||
Reference in New Issue
Block a user