forked from EngineX-Ascend/enginex-ascend-910-diffusers
add input config
This commit is contained in:
23
main.py
23
main.py
@@ -85,8 +85,8 @@ def generate_one(pipe: DiffusionPipeline, cfg: dict, out_dir: Path, index: int):
|
||||
"""
|
||||
prompt = cfg["prompt"]
|
||||
negative_prompt = cfg.get("negative_prompt", None)
|
||||
steps = int(cfg.get("num_inference_steps", 20))
|
||||
guidance = float(cfg.get("guidance_scale", 7.5))
|
||||
steps = int(cfg.get("num_inference_steps", 0))
|
||||
guidance = float(cfg.get("guidance_scale", 0))
|
||||
seed = cfg.get("seed", None)
|
||||
width = cfg.get("width", None)
|
||||
height = cfg.get("height", None)
|
||||
@@ -102,13 +102,16 @@ def generate_one(pipe: DiffusionPipeline, cfg: dict, out_dir: Path, index: int):
|
||||
|
||||
call_kwargs = dict(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_inference_steps=steps,
|
||||
guidance_scale=guidance,
|
||||
generator=gen,
|
||||
)
|
||||
if width is not None and height is not None:
|
||||
call_kwargs.update({"width": int(width), "height": int(height)})
|
||||
if negative_prompt is not None:
|
||||
call_kwargs.update({"negative_prompt": negative_prompt})
|
||||
if guidance > 0:
|
||||
call_kwargs.update({"guidance_scale": guidance})
|
||||
if steps > 0:
|
||||
call_kwargs.update({"num_inference_steps": steps})
|
||||
|
||||
start = time.time()
|
||||
images = pipe(**call_kwargs).images
|
||||
@@ -145,6 +148,9 @@ def main():
|
||||
parser.add_argument("--outdir", required=True, help="图片输出目录")
|
||||
parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="推理设备")
|
||||
parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"], help="推理精度")
|
||||
parser.add_argument("--negative_prompt", default=None, help="negative_prompt")
|
||||
parser.add_argument("--num_inference_steps", default=0, help="num_inference_steps")
|
||||
parser.add_argument("--guidance_scale", default=0, help="guidance_scale")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = args.model
|
||||
@@ -166,6 +172,13 @@ def main():
|
||||
records = []
|
||||
total_start = time.time()
|
||||
for i, cfg in enumerate(prompts, 1):
|
||||
if args.negative_prompt:
|
||||
cfg["negative_prompt"] = args.negative_prompt
|
||||
if args.num_inference_steps:
|
||||
cfg["num_inference_steps"] = args.num_inference_steps
|
||||
if args.guidance_scale:
|
||||
cfg["guidance_scale"] = args.guidance_scale
|
||||
|
||||
out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i)
|
||||
print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
|
||||
records.append(detail)
|
||||
|
||||
2
test.sh
2
test.sh
@@ -1 +1 @@
|
||||
python3 main.py --model "/mnt/contest_ceph/zhanghao/models/stable-diffusion-v1-5" --json "dataset.json" --results "results.json" --outdir "output" --device cuda --dtype fp16
|
||||
python3 main.py --model "/mnt/contest_ceph/zhanghao/models/stable-diffusion-v1-5" --json "dataset.json" --results "results.json" --outdir "output" --device cuda --dtype fp16 --num_inference_steps 20
|
||||
|
||||
Reference in New Issue
Block a user