add input config: expose negative_prompt, num_inference_steps and guidance_scale as CLI options that override per-prompt JSON values

This commit is contained in:
2025-08-28 12:49:39 +08:00
parent cdefc1873e
commit 33c57acb1b
2 changed files with 19 additions and 6 deletions

23
main.py
View File

@@ -85,8 +85,8 @@ def generate_one(pipe: DiffusionPipeline, cfg: dict, out_dir: Path, index: int):
"""
prompt = cfg["prompt"]
negative_prompt = cfg.get("negative_prompt", None)
steps = int(cfg.get("num_inference_steps", 20))
guidance = float(cfg.get("guidance_scale", 7.5))
steps = int(cfg.get("num_inference_steps", 0))
guidance = float(cfg.get("guidance_scale", 0))
seed = cfg.get("seed", None)
width = cfg.get("width", None)
height = cfg.get("height", None)
@@ -102,13 +102,16 @@ def generate_one(pipe: DiffusionPipeline, cfg: dict, out_dir: Path, index: int):
call_kwargs = dict(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=steps,
guidance_scale=guidance,
generator=gen,
)
if width is not None and height is not None:
call_kwargs.update({"width": int(width), "height": int(height)})
if negative_prompt is not None:
call_kwargs.update({"negative_prompt": negative_prompt})
if guidance > 0:
call_kwargs.update({"guidance_scale": guidance})
if steps > 0:
call_kwargs.update({"num_inference_steps": steps})
start = time.time()
images = pipe(**call_kwargs).images
@@ -145,6 +148,9 @@ def main():
parser.add_argument("--outdir", required=True, help="图片输出目录")
parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="推理设备")
parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"], help="推理精度")
parser.add_argument("--negative_prompt", default=None, help="negative_prompt")
parser.add_argument("--num_inference_steps", default=0, help="num_inference_steps")
parser.add_argument("--guidance_scale", default=0, help="guidance_scale")
args = parser.parse_args()
model_path = args.model
@@ -166,6 +172,13 @@ def main():
records = []
total_start = time.time()
for i, cfg in enumerate(prompts, 1):
if args.negative_prompt:
cfg["negative_prompt"] = args.negative_prompt
if args.num_inference_steps:
cfg["num_inference_steps"] = args.num_inference_steps
if args.guidance_scale:
cfg["guidance_scale"] = args.guidance_scale
out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i)
print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
records.append(detail)

View File

@@ -1 +1 @@
python3 main.py --model "/mnt/contest_ceph/zhanghao/models/stable-diffusion-v1-5" --json "dataset.json" --results "results.json" --outdir "output" --device cuda --dtype fp16
python3 main.py --model "/mnt/contest_ceph/zhanghao/models/stable-diffusion-v1-5" --json "dataset.json" --results "results.json" --outdir "output" --device cuda --dtype fp16 --num_inference_steps 20