#!/usr/bin/env python3 # -*- coding: utf-8 -*- import argparse import json import os import re import time from datetime import datetime from pathlib import Path import torch from diffusers import DiffusionPipeline def safe_stem(text: str, maxlen: int = 60) -> str: """将提示词转为安全的文件名片段。""" text = re.sub(r"\s+", "_", text.strip()) text = re.sub(r"[^A-Za-z0-9_\-]+", "", text) return (text[:maxlen] or "image").strip("_") def load_prompts(json_path: Path): """ 支持两种 JSON 结构: 1) ["prompt 1", "prompt 2", ...] 2) [{"prompt": "...", "negative_prompt": "...", "num_inference_steps": 30, "guidance_scale": 7.5, "seed": 42, "width": 512, "height": 512}, ...] """ with open(json_path, "r", encoding="utf-8") as f: data = json.load(f) prompts = [] if isinstance(data, list): if all(isinstance(x, str) for x in data): for s in data: prompts.append({"prompt": s}) elif all(isinstance(x, dict) for x in data): for obj in data: if "prompt" not in obj: raise ValueError("每个对象都需要包含 'prompt' 字段") prompts.append(obj) else: raise ValueError("JSON 列表元素需全为字符串或全为对象。") else: raise ValueError("JSON 顶层必须是列表。") return prompts def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16): pipe = DiffusionPipeline.from_pretrained( model_path, torch_dtype=dtype, use_safetensors=True, ) # 设备放置 if device == "cuda" and torch.cuda.is_available(): pipe.to("cuda") try: pipe.enable_attention_slicing() except Exception: pass # 对大模型友好;若已放到 CUDA,会按需处理 try: pipe.enable_model_cpu_offload() except Exception: pass else: pipe.to("cpu") pipe.set_progress_bar_config(disable=True) return pipe def generate_one(pipe: DiffusionPipeline, cfg: dict, out_dir: Path, index: int): """ 依据 cfg 生成一张图并返回 (保存路径, 耗时秒, 详细参数) 支持字段: - prompt (必需) - negative_prompt (可选) - num_inference_steps (默认 20) - guidance_scale (默认 7.5) - seed (可选) - width, height (可选) """ prompt = cfg["prompt"] negative_prompt = cfg.get("negative_prompt", None) steps = int(cfg.get("num_inference_steps", 20)) guidance = float(cfg.get("guidance_scale", 7.5)) seed = cfg.get("seed", None) width = cfg.get("width", None) height = cfg.get("height", None) # 随机数生成器(与管线设备一致) gen = None try: device_str = str(getattr(pipe, "device", "cuda" if torch.cuda.is_available() else "cpu")) except Exception: device_str = "cuda" if torch.cuda.is_available() else "cpu" if seed is not None: gen = torch.Generator(device=device_str).manual_seed(int(seed)) call_kwargs = dict( prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=steps, guidance_scale=guidance, generator=gen, ) if width is not None and height is not None: call_kwargs.update({"width": int(width), "height": int(height)}) start = time.time() images = pipe(**call_kwargs).images elapsed = time.time() - start stamp = datetime.now().strftime("%Y%m%d-%H%M%S") stem = safe_stem(prompt) filename = f"{index:03d}_{stem}_{stamp}.png" out_path = out_dir / filename images[0].save(out_path) detail = { "index": index, "filename": filename, "elapsed_seconds": round(elapsed, 6), "prompt": prompt, "negative_prompt": negative_prompt, "num_inference_steps": steps, "guidance_scale": guidance, "seed": seed, "width": width, "height": height, } return out_path, elapsed, detail def main(): parser = argparse.ArgumentParser( description="Stable Diffusion 基准与批量生成脚本(JSON 结果)" ) parser.add_argument("--model", required=True, help="模型路径或模型名(本地目录或 HF 仓库名)") parser.add_argument("--json", required=True, help="测试文本 JSON 文件路径") parser.add_argument("--results", required=True, help="结果 JSON 文件输出路径(*.json)") parser.add_argument("--outdir", required=True, help="图片输出目录") parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="推理设备") parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"], help="推理精度") args = parser.parse_args() model_path = args.model json_path = Path(args.json) results_path = Path(args.results) out_dir = Path(args.outdir) out_dir.mkdir(parents=True, exist_ok=True) results_path.parent.mkdir(parents=True, exist_ok=True) dtype = torch.float16 if args.dtype == "fp16" else torch.float32 prompts = load_prompts(json_path) if not prompts: raise ValueError("测试列表为空。") pipe = build_pipeline(model_path=model_path, device=args.device, dtype=dtype) records = [] total_start = time.time() for i, cfg in enumerate(prompts, 1): out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i) print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s") records.append(detail) total_elapsed = round(time.time() - total_start, 6) avg_latency = total_elapsed / len(records) if records else 0 # 结果 JSON 结构 result_obj = { "timestamp": datetime.now().isoformat(timespec="seconds"), "model": model_path, "device": str(getattr(pipe, "device", "unknown")), "dtype": "fp16" if dtype == torch.float16 else "fp32", "count": len(records), "total_elapsed_seconds": total_elapsed, "avg_latency": avg_latency, "cases": records } with open(results_path, "w", encoding="utf-8") as f: json.dump(result_obj, f, ensure_ascii=False, indent=2) print(f"\nAll done. images: {len(records)}, total_elapsed: {total_elapsed:.3f}s, avg_latency: {avg_latency:.3f}") print(f"Results JSON: {results_path}") print(f"Images dir : {out_dir.resolve()}") if __name__ == "__main__": main()