add input config
This commit is contained in:
23
main.py
23
main.py
@@ -85,8 +85,8 @@ def generate_one(pipe: DiffusionPipeline, cfg: dict, out_dir: Path, index: int):
|
|||||||
"""
|
"""
|
||||||
prompt = cfg["prompt"]
|
prompt = cfg["prompt"]
|
||||||
negative_prompt = cfg.get("negative_prompt", None)
|
negative_prompt = cfg.get("negative_prompt", None)
|
||||||
steps = int(cfg.get("num_inference_steps", 20))
|
steps = int(cfg.get("num_inference_steps", 0))
|
||||||
guidance = float(cfg.get("guidance_scale", 7.5))
|
guidance = float(cfg.get("guidance_scale", 0))
|
||||||
seed = cfg.get("seed", None)
|
seed = cfg.get("seed", None)
|
||||||
width = cfg.get("width", None)
|
width = cfg.get("width", None)
|
||||||
height = cfg.get("height", None)
|
height = cfg.get("height", None)
|
||||||
@@ -102,13 +102,16 @@ def generate_one(pipe: DiffusionPipeline, cfg: dict, out_dir: Path, index: int):
|
|||||||
|
|
||||||
call_kwargs = dict(
|
call_kwargs = dict(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
|
||||||
num_inference_steps=steps,
|
|
||||||
guidance_scale=guidance,
|
|
||||||
generator=gen,
|
generator=gen,
|
||||||
)
|
)
|
||||||
if width is not None and height is not None:
|
if width is not None and height is not None:
|
||||||
call_kwargs.update({"width": int(width), "height": int(height)})
|
call_kwargs.update({"width": int(width), "height": int(height)})
|
||||||
|
if negative_prompt is not None:
|
||||||
|
call_kwargs.update({"negative_prompt": negative_prompt})
|
||||||
|
if guidance > 0:
|
||||||
|
call_kwargs.update({"guidance_scale": guidance})
|
||||||
|
if steps > 0:
|
||||||
|
call_kwargs.update({"num_inference_steps": steps})
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
images = pipe(**call_kwargs).images
|
images = pipe(**call_kwargs).images
|
||||||
@@ -145,6 +148,9 @@ def main():
|
|||||||
parser.add_argument("--outdir", required=True, help="图片输出目录")
|
parser.add_argument("--outdir", required=True, help="图片输出目录")
|
||||||
parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="推理设备")
|
parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="推理设备")
|
||||||
parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"], help="推理精度")
|
parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"], help="推理精度")
|
||||||
|
parser.add_argument("--negative_prompt", default=None, help="negative_prompt")
|
||||||
|
parser.add_argument("--num_inference_steps", default=0, help="num_inference_steps")
|
||||||
|
parser.add_argument("--guidance_scale", default=0, help="guidance_scale")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
model_path = args.model
|
model_path = args.model
|
||||||
@@ -166,6 +172,13 @@ def main():
|
|||||||
records = []
|
records = []
|
||||||
total_start = time.time()
|
total_start = time.time()
|
||||||
for i, cfg in enumerate(prompts, 1):
|
for i, cfg in enumerate(prompts, 1):
|
||||||
|
if args.negative_prompt:
|
||||||
|
cfg["negative_prompt"] = args.negative_prompt
|
||||||
|
if args.num_inference_steps:
|
||||||
|
cfg["num_inference_steps"] = args.num_inference_steps
|
||||||
|
if args.guidance_scale:
|
||||||
|
cfg["guidance_scale"] = args.guidance_scale
|
||||||
|
|
||||||
out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i)
|
out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i)
|
||||||
print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
|
print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
|
||||||
records.append(detail)
|
records.append(detail)
|
||||||
|
|||||||
2
test.sh
2
test.sh
@@ -1 +1 @@
|
|||||||
python3 main.py --model "/mnt/contest_ceph/zhanghao/models/stable-diffusion-v1-5" --json "dataset.json" --results "results.json" --outdir "output" --device cuda --dtype fp16
|
python3 main.py --model "/mnt/contest_ceph/zhanghao/models/stable-diffusion-v1-5" --json "dataset.json" --results "results.json" --outdir "output" --device cuda --dtype fp16 --num_inference_steps 20
|
||||||
|
|||||||
Reference in New Issue
Block a user