add ascend support

This commit is contained in:
root
2025-09-03 10:16:24 +08:00
parent 33c57acb1b
commit 31c16831b1
4 changed files with 26 additions and 12 deletions

24
main.py
View File

@@ -54,8 +54,12 @@ def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16):
use_safetensors=True,
)
# 设备放置
if device == "cuda" and torch.cuda.is_available():
pipe.to("cuda")
if device == "cuda":
if torch.cuda.is_available():
pipe.to("cuda")
elif torch.npu.is_available():
pipe.to("npu")
try:
pipe.enable_attention_slicing()
except Exception:
@@ -92,17 +96,17 @@ def generate_one(pipe: DiffusionPipeline, cfg: dict, out_dir: Path, index: int):
height = cfg.get("height", None)
# 随机数生成器(与管线设备一致)
gen = None
try:
device_str = str(getattr(pipe, "device", "cuda" if torch.cuda.is_available() else "cpu"))
except Exception:
device_str = "cuda" if torch.cuda.is_available() else "cpu"
if seed is not None:
gen = torch.Generator(device=device_str).manual_seed(int(seed))
# gen = None
# try:
# device_str = str(getattr(pipe, "device", "cuda" if torch.cuda.is_available() else "npu" if torch.npu.is_available() else "cpu"))
# except Exception:
# device_str = "cuda" if torch.cuda.is_available() else "npu" if torch.npu.is_available() else "cpu"
# if seed is not None:
# gen = torch.Generator(device=device_str).manual_seed(int(seed))
call_kwargs = dict(
prompt=prompt,
generator=gen,
# generator=gen,
)
if width is not None and height is not None:
call_kwargs.update({"width": int(width), "height": int(height)})