From 33c57acb1b555be46e7bfe90ba7ffe97fd59a24f Mon Sep 17 00:00:00 2001 From: ZHANG Hao Date: Thu, 28 Aug 2025 12:49:39 +0800 Subject: [PATCH] add input config --- main.py | 23 ++++++++++++++++++----- test.sh | 2 +- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index d8a0c08..be84335 100644 --- a/main.py +++ b/main.py @@ -85,8 +85,8 @@ def generate_one(pipe: DiffusionPipeline, cfg: dict, out_dir: Path, index: int): """ prompt = cfg["prompt"] negative_prompt = cfg.get("negative_prompt", None) - steps = int(cfg.get("num_inference_steps", 20)) - guidance = float(cfg.get("guidance_scale", 7.5)) + steps = int(cfg.get("num_inference_steps", 0)) + guidance = float(cfg.get("guidance_scale", 0)) seed = cfg.get("seed", None) width = cfg.get("width", None) height = cfg.get("height", None) @@ -102,13 +102,16 @@ def generate_one(pipe: DiffusionPipeline, cfg: dict, out_dir: Path, index: int): call_kwargs = dict( prompt=prompt, - negative_prompt=negative_prompt, - num_inference_steps=steps, - guidance_scale=guidance, generator=gen, ) if width is not None and height is not None: call_kwargs.update({"width": int(width), "height": int(height)}) + if negative_prompt is not None: + call_kwargs.update({"negative_prompt": negative_prompt}) + if guidance > 0: + call_kwargs.update({"guidance_scale": guidance}) + if steps > 0: + call_kwargs.update({"num_inference_steps": steps}) start = time.time() images = pipe(**call_kwargs).images @@ -145,6 +148,9 @@ def main(): parser.add_argument("--outdir", required=True, help="图片输出目录") parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="推理设备") parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"], help="推理精度") + parser.add_argument("--negative_prompt", default=None, help="negative_prompt") + parser.add_argument("--num_inference_steps", type=int, default=0, help="num_inference_steps") + parser.add_argument("--guidance_scale", type=float, default=0, help="guidance_scale") args = 
parser.parse_args() model_path = args.model @@ -166,6 +172,13 @@ def main(): records = [] total_start = time.time() for i, cfg in enumerate(prompts, 1): + if args.negative_prompt: + cfg["negative_prompt"] = args.negative_prompt + if args.num_inference_steps: + cfg["num_inference_steps"] = args.num_inference_steps + if args.guidance_scale: + cfg["guidance_scale"] = args.guidance_scale + out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i) print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s") records.append(detail) diff --git a/test.sh b/test.sh index 7bad55f..a934539 100755 --- a/test.sh +++ b/test.sh @@ -1 +1 @@ -python3 main.py --model "/mnt/contest_ceph/zhanghao/models/stable-diffusion-v1-5" --json "dataset.json" --results "results.json" --outdir "output" --device cuda --dtype fp16 +python3 main.py --model "/mnt/contest_ceph/zhanghao/models/stable-diffusion-v1-5" --json "dataset.json" --results "results.json" --outdir "output" --device cuda --dtype fp16 --num_inference_steps 20