Files
enginex-bi_series-diffusers/main.py
2025-08-20 17:53:15 +08:00

197 lines
6.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import re
import time
from datetime import datetime
from pathlib import Path
import torch
from diffusers import DiffusionPipeline
def safe_stem(text: str, maxlen: int = 60) -> str:
    """Turn a prompt into a filesystem-safe file-name fragment.

    Runs of whitespace become a single underscore, every character outside
    ``[A-Za-z0-9_-]`` is dropped, the result is cut to *maxlen* characters
    and leading/trailing underscores are trimmed.

    Args:
        text: Arbitrary prompt text.
        maxlen: Maximum number of characters kept before underscores are
            trimmed.

    Returns:
        The sanitized fragment, or ``"image"`` when nothing usable remains
        (empty input, or input made only of characters that are dropped).
    """
    text = re.sub(r"\s+", "_", text.strip())
    text = re.sub(r"[^A-Za-z0-9_\-]+", "", text)
    # Trim underscores before applying the fallback, so that input which
    # reduces to nothing but underscores still yields "image" and never "".
    stem = text[:maxlen].strip("_")
    return stem or "image"
def load_prompts(json_path: Path) -> list:
    """Load the list of generation cases from a JSON file.

    Two top-level layouts are supported:

    1. A list of strings: ``["prompt 1", "prompt 2", ...]``. Each string is
       wrapped as ``{"prompt": <string>}``.
    2. A list of objects, each with at least a ``"prompt"`` key, e.g.
       ``{"prompt": "...", "negative_prompt": "...", "num_inference_steps": 30,
       "guidance_scale": 7.5, "seed": 42, "width": 512, "height": 512}``.
       The objects are returned as loaded.

    Args:
        json_path: Path of the UTF-8 encoded JSON file.

    Returns:
        A list of dicts, each containing a string ``"prompt"``. An empty JSON
        list yields an empty list.

    Raises:
        ValueError: If the top level is not a list, the list mixes strings
            and objects (or holds other types), an object lacks ``"prompt"``,
            or a ``"prompt"`` value is not a string.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, list):
        raise ValueError("JSON 顶层必须是列表。")

    if all(isinstance(item, str) for item in data):
        return [{"prompt": text} for text in data]

    if not all(isinstance(item, dict) for item in data):
        raise ValueError("JSON 列表元素需全为字符串或全为对象。")

    for obj in data:
        if "prompt" not in obj:
            raise ValueError("每个对象都需要包含 'prompt' 字段")
        # Reject non-string prompts here, so a bad input file is reported
        # at load time and not in the middle of a run.
        if not isinstance(obj["prompt"], str):
            raise ValueError("'prompt' 字段必须是字符串")
    return list(data)
def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16):
    """Load a diffusion pipeline and place it on the requested device.

    Args:
        model_path: Local directory or hub repository name passed to
            ``DiffusionPipeline.from_pretrained``.
        device: ``"cuda"`` or anything else; CUDA is used only when this is
            ``"cuda"`` and ``torch.cuda.is_available()`` is true, otherwise
            the pipeline is moved to the CPU.
        dtype: Torch dtype forwarded as ``torch_dtype`` when loading.

    Returns:
        The loaded pipeline with its progress bar disabled.
    """
    import warnings  # local import keeps this function self-contained

    pipe = DiffusionPipeline.from_pretrained(
        model_path,
        torch_dtype=dtype,
        use_safetensors=True,
    )

    if device == "cuda" and torch.cuda.is_available():
        pipe.to("cuda")
        # Best-effort memory optimizations: a given pipeline class or
        # installation may not support them, so a failure is reported and
        # the run continues without that optimization.
        for optimization in ("enable_attention_slicing", "enable_model_cpu_offload"):
            try:
                getattr(pipe, optimization)()
            except Exception as exc:
                warnings.warn(
                    f"{optimization}() was skipped: {exc!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )
    else:
        if device == "cuda":
            # Make the fallback visible; it changes what is being measured.
            warnings.warn(
                "CUDA was requested but is not available; running on CPU.",
                RuntimeWarning,
                stacklevel=2,
            )
        pipe.to("cpu")

    pipe.set_progress_bar_config(disable=True)
    return pipe
def generate_one(pipe: DiffusionPipeline, cfg: dict, out_dir: Path, index: int):
    """Generate one image from *cfg*, save it as PNG and report the timing.

    Recognized ``cfg`` keys:
        - ``prompt`` (required)
        - ``negative_prompt`` (optional)
        - ``num_inference_steps`` (default 20)
        - ``guidance_scale`` (default 7.5)
        - ``seed`` (optional; seeds a ``torch.Generator``)
        - ``width``, ``height`` (optional; each is forwarded on its own)

    Args:
        pipe: Callable pipeline whose result exposes ``.images``.
        cfg: Per-case settings as described above.
        out_dir: Existing directory the PNG is written into.
        index: Case number, used as the zero-padded file-name prefix.

    Returns:
        A tuple ``(out_path, elapsed_seconds, detail)`` where ``detail`` is a
        JSON-serializable dict describing the case and its parameters.

    Raises:
        KeyError: If ``cfg`` has no ``"prompt"`` key.
    """
    prompt = cfg["prompt"]
    negative_prompt = cfg.get("negative_prompt")
    steps = int(cfg.get("num_inference_steps", 20))
    guidance = float(cfg.get("guidance_scale", 7.5))
    seed = cfg.get("seed")
    width = cfg.get("width")
    height = cfg.get("height")

    # Seeded generator on the same device the pipeline reports, falling back
    # to a guess based on CUDA availability when that cannot be read.
    gen = None
    if seed is not None:
        fallback_device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            device_str = str(getattr(pipe, "device", fallback_device))
        except Exception:
            device_str = fallback_device
        gen = torch.Generator(device=device_str).manual_seed(int(seed))

    call_kwargs = {
        "prompt": prompt,
        "num_inference_steps": steps,
        "guidance_scale": guidance,
        "generator": gen,
    }
    # Forward optional arguments only when they were supplied, so the
    # pipeline's own defaults apply otherwise and a lone width or height is
    # honoured and not silently dropped.
    if negative_prompt is not None:
        call_kwargs["negative_prompt"] = negative_prompt
    if width is not None:
        call_kwargs["width"] = int(width)
    if height is not None:
        call_kwargs["height"] = int(height)

    # perf_counter is monotonic, so the measured interval cannot be skewed
    # by system clock adjustments.
    start = time.perf_counter()
    images = pipe(**call_kwargs).images
    elapsed = time.perf_counter() - start

    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    stem = safe_stem(prompt)
    filename = f"{index:03d}_{stem}_{stamp}.png"
    out_path = out_dir / filename
    images[0].save(out_path)

    detail = {
        "index": index,
        "filename": filename,
        "elapsed_seconds": round(elapsed, 6),
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_inference_steps": steps,
        "guidance_scale": guidance,
        "seed": seed,
        "width": width,
        "height": height,
    }
    return out_path, elapsed, detail
def _build_arg_parser():
    """Create the command-line parser for the benchmark script."""
    parser = argparse.ArgumentParser(
        description="Stable Diffusion 基准与批量生成脚本JSON 结果)"
    )
    parser.add_argument("--model", required=True, help="模型路径或模型名(本地目录或 HF 仓库名)")
    parser.add_argument("--json", required=True, help="测试文本 JSON 文件路径")
    parser.add_argument("--results", required=True, help="结果 JSON 文件输出路径(*.json")
    parser.add_argument("--outdir", required=True, help="图片输出目录")
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="推理设备")
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"], help="推理精度")
    return parser


def main(argv=None):
    """Run every case from the prompt file and write a JSON summary.

    Creates the image directory and the results file's parent directory,
    loads the prompts, builds the pipeline, generates one image per case and
    finally writes the summary (timestamp, model, device, dtype, count, total
    and average time, per-case details) to the results path.

    Args:
        argv: Optional argument list; ``None`` makes argparse read
            ``sys.argv``.

    Raises:
        ValueError: If the prompt file contains no cases.
    """
    args = _build_arg_parser().parse_args(argv)

    model_path = args.model
    json_path = Path(args.json)
    results_path = Path(args.results)
    out_dir = Path(args.outdir)
    out_dir.mkdir(parents=True, exist_ok=True)
    results_path.parent.mkdir(parents=True, exist_ok=True)

    dtype = torch.float16 if args.dtype == "fp16" else torch.float32

    prompts = load_prompts(json_path)
    if not prompts:
        raise ValueError("测试列表为空。")

    pipe = build_pipeline(model_path=model_path, device=args.device, dtype=dtype)

    records = []
    # Monotonic clock: the total covers generation plus saving and logging.
    total_start = time.perf_counter()
    for i, cfg in enumerate(prompts, 1):
        out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i)
        print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
        records.append(detail)
    total_elapsed = round(time.perf_counter() - total_start, 6)
    # Rounded to the same precision as the other elapsed fields.
    avg_latency = round(total_elapsed / len(records), 6) if records else 0

    # Layout of the results JSON.
    result_obj = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "model": model_path,
        "device": str(getattr(pipe, "device", "unknown")),
        "dtype": "fp16" if dtype == torch.float16 else "fp32",
        "count": len(records),
        "total_elapsed_seconds": total_elapsed,
        "avg_latency": avg_latency,
        "cases": records,
    }
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(result_obj, f, ensure_ascii=False, indent=2)

    print(f"\nAll done. images: {len(records)}, total_elapsed: {total_elapsed:.3f}s, avg_latency: {avg_latency:.3f}")
    print(f"Results JSON: {results_path}")
    print(f"Images dir : {out_dir.resolve()}")
# Run the benchmark only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()