add ascend support
This commit is contained in:
24
main.py
24
main.py
@@ -54,8 +54,12 @@ def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16):
|
||||
use_safetensors=True,
|
||||
)
|
||||
# 设备放置
|
||||
if device == "cuda" and torch.cuda.is_available():
|
||||
pipe.to("cuda")
|
||||
if device == "cuda":
|
||||
if torch.cuda.is_available():
|
||||
pipe.to("cuda")
|
||||
elif torch.npu.is_available():
|
||||
pipe.to("npu")
|
||||
|
||||
try:
|
||||
pipe.enable_attention_slicing()
|
||||
except Exception:
|
||||
@@ -92,17 +96,17 @@ def generate_one(pipe: DiffusionPipeline, cfg: dict, out_dir: Path, index: int):
|
||||
height = cfg.get("height", None)
|
||||
|
||||
# 随机数生成器(与管线设备一致)
|
||||
gen = None
|
||||
try:
|
||||
device_str = str(getattr(pipe, "device", "cuda" if torch.cuda.is_available() else "cpu"))
|
||||
except Exception:
|
||||
device_str = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if seed is not None:
|
||||
gen = torch.Generator(device=device_str).manual_seed(int(seed))
|
||||
# gen = None
|
||||
# try:
|
||||
# device_str = str(getattr(pipe, "device", "cuda" if torch.cuda.is_available() else "npu" if torch.npu.is_available() else "cpu"))
|
||||
# except Exception:
|
||||
# device_str = "cuda" if torch.cuda.is_available() else "npu" if torch.npu.is_available() else "cpu"
|
||||
# if seed is not None:
|
||||
# gen = torch.Generator(device=device_str).manual_seed(int(seed))
|
||||
|
||||
call_kwargs = dict(
|
||||
prompt=prompt,
|
||||
generator=gen,
|
||||
# generator=gen,
|
||||
)
|
||||
if width is not None and height is not None:
|
||||
call_kwargs.update({"width": int(width), "height": int(height)})
|
||||
|
||||
Reference in New Issue
Block a user